blob: 15ea912d100d0da212c33972fcb148aa800a181e [file] [log] [blame]
package dnsutils
import (
"math/rand"
"net"
"sort"
)
// OrderedSRV returns a count of the results and a map keyed on the order they should be used.
// This based on the records' priority and randomised selection based on their relative weighting.
// The function's inputs are the same as those for net.LookupSRV
// To use in the correct order:
//
// count, orderedSRV, err := OrderedSRV(service, proto, name)
// i := 1
// for i <= count {
// srv := orderedSRV[i]
// // Do something such as dial this SRV. If fails move on the the next or break if it succeeds.
// i += 1
// }
func OrderedSRV(service, proto, name string) (int, map[int]*net.SRV, error) {
_, addrs, err := net.LookupSRV(service, proto, name)
if err != nil {
return 0, make(map[int]*net.SRV), err
}
index, osrv := orderSRV(addrs)
return index, osrv, nil
}
func orderSRV(addrs []*net.SRV) (int, map[int]*net.SRV) {
// Initialise the ordered map
var o int
osrv := make(map[int]*net.SRV)
prioMap := make(map[int][]*net.SRV, 0)
for _, srv := range addrs {
prioMap[int(srv.Priority)] = append(prioMap[int(srv.Priority)], srv)
}
priorities := make([]int, 0)
for p := range prioMap {
priorities = append(priorities, p)
}
var count int
sort.Ints(priorities)
for _, p := range priorities {
tos := weightedOrder(prioMap[p])
for i, s := range tos {
count += 1
osrv[o+i] = s
}
o += len(tos)
}
return count, osrv
}
func weightedOrder(srvs []*net.SRV) map[int]*net.SRV {
// Get the total weight
var tw int
for _, s := range srvs {
tw += int(s.Weight)
}
// Initialise the ordered map
o := 1
osrv := make(map[int]*net.SRV)
// Whilst there are still entries to be ordered
l := len(srvs)
for l > 0 {
i := rand.Intn(l)
s := srvs[i]
var rw int
if tw > 0 {
// Greater the weight the more likely this will be zero or less
rw = rand.Intn(tw) - int(s.Weight)
}
if rw <= 0 {
// Put entry in position
osrv[o] = s
if len(srvs) > 1 {
// Remove the entry from the source slice by swapping with the last entry and truncating
srvs[len(srvs)-1], srvs[i] = srvs[i], srvs[len(srvs)-1]
srvs = srvs[:len(srvs)-1]
l = len(srvs)
} else {
l = 0
}
o += 1
tw = tw - int(s.Weight)
}
}
return osrv
}