blob: 15ea912d100d0da212c33972fcb148aa800a181e [file] [log] [blame]
Takahiro Suzukid7bf8202020-12-17 20:21:59 +09001package dnsutils
2
3import (
4 "math/rand"
5 "net"
6 "sort"
7)
8
9// OrderedSRV returns a count of the results and a map keyed on the order they should be used.
10// This based on the records' priority and randomised selection based on their relative weighting.
11// The function's inputs are the same as those for net.LookupSRV
12// To use in the correct order:
13//
14// count, orderedSRV, err := OrderedSRV(service, proto, name)
15// i := 1
16// for i <= count {
17// srv := orderedSRV[i]
18// // Do something such as dial this SRV. If fails move on the the next or break if it succeeds.
19// i += 1
20// }
21func OrderedSRV(service, proto, name string) (int, map[int]*net.SRV, error) {
22 _, addrs, err := net.LookupSRV(service, proto, name)
23 if err != nil {
24 return 0, make(map[int]*net.SRV), err
25 }
26 index, osrv := orderSRV(addrs)
27 return index, osrv, nil
28}
29
30func orderSRV(addrs []*net.SRV) (int, map[int]*net.SRV) {
31 // Initialise the ordered map
32 var o int
33 osrv := make(map[int]*net.SRV)
34
35 prioMap := make(map[int][]*net.SRV, 0)
36 for _, srv := range addrs {
37 prioMap[int(srv.Priority)] = append(prioMap[int(srv.Priority)], srv)
38 }
39
40 priorities := make([]int, 0)
41 for p := range prioMap {
42 priorities = append(priorities, p)
43 }
44
45 var count int
46 sort.Ints(priorities)
47 for _, p := range priorities {
48 tos := weightedOrder(prioMap[p])
49 for i, s := range tos {
50 count += 1
51 osrv[o+i] = s
52 }
53 o += len(tos)
54 }
55 return count, osrv
56}
57
// weightedOrder returns srvs in a random order that follows the weighted
// selection algorithm of RFC 2782. At each step the next record is chosen from
// those not yet placed, with probability proportional to its weight.
//
// Records with weight 0 are never chosen while any record with a positive
// weight remains, which matches the behaviour of Go's net package. Once only
// zero-weight records are left, they are chosen uniformly at random.
//
// The returned map is keyed 1..len(srvs) in the order the records should be
// tried. srvs itself is not modified.
func weightedOrder(srvs []*net.SRV) map[int]*net.SRV {
	// Work on a copy so the caller's slice is left untouched.
	remaining := make([]*net.SRV, len(srvs))
	copy(remaining, srvs)

	sum := 0
	for _, s := range remaining {
		sum += int(s.Weight)
	}

	osrv := make(map[int]*net.SRV, len(remaining))
	for pos := 1; len(remaining) > 0; pos++ {
		pick := 0
		if sum > 0 {
			// Choose the first record whose running weight total exceeds a
			// uniform draw from [0, sum). This gives each record a chance of
			// weight/sum. A zero-weight record cannot satisfy the test.
			n := rand.Intn(sum)
			running := 0
			for i, s := range remaining {
				running += int(s.Weight)
				if running > n {
					pick = i
					break
				}
			}
		} else {
			// Only zero-weight records remain: choose uniformly.
			pick = rand.Intn(len(remaining))
		}

		chosen := remaining[pick]
		osrv[pos] = chosen
		sum -= int(chosen.Weight)

		// Remove the chosen record by swapping it with the last and truncating.
		// The order of the rest does not affect the selection probabilities.
		last := len(remaining) - 1
		remaining[pick] = remaining[last]
		remaining = remaining[:last]
	}
	return osrv
}