Scott Baker | 611f6bd | 2019-10-18 13:45:19 -0700 | [diff] [blame] | 1 | package dnsutils |
| 2 | |
| 3 | import ( |
| 4 | "math/rand" |
| 5 | "net" |
| 6 | "sort" |
| 7 | ) |
| 8 | |
| 9 | // OrderedSRV returns a count of the results and a map keyed on the order they should be used. |
| 10 | // This based on the records' priority and randomised selection based on their relative weighting. |
| 11 | // The function's inputs are the same as those for net.LookupSRV |
| 12 | // To use in the correct order: |
| 13 | // |
| 14 | // count, orderedSRV, err := OrderedSRV(service, proto, name) |
| 15 | // i := 1 |
| 16 | // for i <= count { |
| 17 | // srv := orderedSRV[i] |
| 18 | // // Do something such as dial this SRV. If fails move on the the next or break if it succeeds. |
| 19 | // i += 1 |
| 20 | // } |
| 21 | func OrderedSRV(service, proto, name string) (int, map[int]*net.SRV, error) { |
| 22 | _, addrs, err := net.LookupSRV(service, proto, name) |
| 23 | if err != nil { |
| 24 | return 0, make(map[int]*net.SRV), err |
| 25 | } |
| 26 | index, osrv := orderSRV(addrs) |
| 27 | return index, osrv, nil |
| 28 | } |
| 29 | |
| 30 | func orderSRV(addrs []*net.SRV) (int, map[int]*net.SRV) { |
| 31 | // Initialise the ordered map |
| 32 | var o int |
| 33 | osrv := make(map[int]*net.SRV) |
| 34 | |
| 35 | prioMap := make(map[int][]*net.SRV, 0) |
| 36 | for _, srv := range addrs { |
| 37 | prioMap[int(srv.Priority)] = append(prioMap[int(srv.Priority)], srv) |
| 38 | } |
| 39 | |
| 40 | priorities := make([]int, 0) |
| 41 | for p := range prioMap { |
| 42 | priorities = append(priorities, p) |
| 43 | } |
| 44 | |
| 45 | var count int |
| 46 | sort.Ints(priorities) |
| 47 | for _, p := range priorities { |
| 48 | tos := weightedOrder(prioMap[p]) |
| 49 | for i, s := range tos { |
| 50 | count += 1 |
| 51 | osrv[o+i] = s |
| 52 | } |
| 53 | o += len(tos) |
| 54 | } |
| 55 | return count, osrv |
| 56 | } |
| 57 | |
// weightedOrder arranges srvs, which are expected to share a single priority,
// in a randomised order driven by their weights. It uses the running-sum
// selection described in RFC 2782. The result maps the 1-based positions
// 1..len(srvs) to records.
//
// Let S be the total weight of the records not yet placed. For each position,
// a record with weight w > 0 is chosen with probability w/(S+1) while
// zero-weight records remain, and w/S otherwise. Zero-weight records share a
// single 1/(S+1) chance, so they are rarely chosen ahead of weighted records.
// Once only zero-weight records remain, they are taken in random order.
//
// Each position costs exactly one random draw, so the running time is bounded.
// srvs itself is not modified.
func weightedOrder(srvs []*net.SRV) map[int]*net.SRV {
	// Build a private working list with the zero-weight records first, as
	// RFC 2782 requires. Shuffling them makes "the first zero-weight record"
	// a uniformly random choice among them.
	remaining := make([]*net.SRV, 0, len(srvs))
	for _, s := range srvs {
		if s.Weight == 0 {
			remaining = append(remaining, s)
		}
	}
	zeros := len(remaining)
	rand.Shuffle(zeros, func(i, j int) {
		remaining[i], remaining[j] = remaining[j], remaining[i]
	})
	var total int
	for _, s := range srvs {
		if s.Weight > 0 {
			remaining = append(remaining, s)
			total += int(s.Weight)
		}
	}

	osrv := make(map[int]*net.SRV, len(remaining))
	for pos := 1; len(remaining) > 0; pos++ {
		var r int
		if zeros > 0 {
			// r == 0 selects the first (zero-weight) record. Any r in
			// [1, total] passes over the zeros, whose running sum stays 0.
			r = rand.Intn(total + 1)
		} else {
			// Only weighted records remain, so total > 0 here. Drawing from
			// [1, total] makes each pick exactly proportional to its weight.
			r = rand.Intn(total) + 1
		}

		// Select the first record whose running weight sum reaches r. The
		// last record's running sum equals total >= r, so it is the fallback.
		i, running := 0, 0
		for ; i < len(remaining)-1; i++ {
			running += int(remaining[i].Weight)
			if running >= r {
				break
			}
		}

		s := remaining[i]
		osrv[pos] = s
		total -= int(s.Weight)
		if s.Weight == 0 {
			zeros--
		}
		// Remove while preserving order so zero-weight records stay in front.
		remaining = append(remaining[:i], remaining[i+1:]...)
	}
	return osrv
}