blob: 9c31b6ec89bff91509dd2bb3eb69adf051da5ffd [file] [log] [blame]
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package description
8
9import (
10 "fmt"
11 "math"
12 "time"
13
14 "github.com/mongodb/mongo-go-driver/mongo/readpref"
15 "github.com/mongodb/mongo-go-driver/tag"
16)
17
18// ServerSelector is an interface implemented by types that can select a server given a
19// topology description.
20type ServerSelector interface {
21 SelectServer(Topology, []Server) ([]Server, error)
22}
23
24// ServerSelectorFunc is a function that can be used as a ServerSelector.
25type ServerSelectorFunc func(Topology, []Server) ([]Server, error)
26
27// SelectServer implements the ServerSelector interface.
28func (ssf ServerSelectorFunc) SelectServer(t Topology, s []Server) ([]Server, error) {
29 return ssf(t, s)
30}
31
32type compositeSelector struct {
33 selectors []ServerSelector
34}
35
36// CompositeSelector combines multiple selectors into a single selector.
37func CompositeSelector(selectors []ServerSelector) ServerSelector {
38 return &compositeSelector{selectors: selectors}
39}
40
41func (cs *compositeSelector) SelectServer(t Topology, candidates []Server) ([]Server, error) {
42 var err error
43 for _, sel := range cs.selectors {
44 candidates, err = sel.SelectServer(t, candidates)
45 if err != nil {
46 return nil, err
47 }
48 }
49 return candidates, nil
50}
51
52type latencySelector struct {
53 latency time.Duration
54}
55
56// LatencySelector creates a ServerSelector which selects servers based on their latency.
57func LatencySelector(latency time.Duration) ServerSelector {
58 return &latencySelector{latency: latency}
59}
60
61func (ls *latencySelector) SelectServer(t Topology, candidates []Server) ([]Server, error) {
62 if ls.latency < 0 {
63 return candidates, nil
64 }
65
66 switch len(candidates) {
67 case 0, 1:
68 return candidates, nil
69 default:
70 min := time.Duration(math.MaxInt64)
71 for _, candidate := range candidates {
72 if candidate.AverageRTTSet {
73 if candidate.AverageRTT < min {
74 min = candidate.AverageRTT
75 }
76 }
77 }
78
79 if min == math.MaxInt64 {
80 return candidates, nil
81 }
82
83 max := min + ls.latency
84
85 var result []Server
86 for _, candidate := range candidates {
87 if candidate.AverageRTTSet {
88 if candidate.AverageRTT <= max {
89 result = append(result, candidate)
90 }
91 }
92 }
93
94 return result, nil
95 }
96}
97
98// WriteSelector selects all the writable servers.
99func WriteSelector() ServerSelector {
100 return ServerSelectorFunc(func(t Topology, candidates []Server) ([]Server, error) {
101 switch t.Kind {
102 case Single:
103 return candidates, nil
104 default:
105 result := []Server{}
106 for _, candidate := range candidates {
107 switch candidate.Kind {
108 case Mongos, RSPrimary, Standalone:
109 result = append(result, candidate)
110 }
111 }
112 return result, nil
113 }
114 })
115}
116
117// ReadPrefSelector selects servers based on the provided read preference.
118func ReadPrefSelector(rp *readpref.ReadPref) ServerSelector {
119 return ServerSelectorFunc(func(t Topology, candidates []Server) ([]Server, error) {
120 if _, set := rp.MaxStaleness(); set {
121 for _, s := range candidates {
122 if s.Kind != Unknown {
123 if err := MaxStalenessSupported(s.WireVersion); err != nil {
124 return nil, err
125 }
126 }
127 }
128 }
129
130 switch t.Kind {
131 case Single:
132 return candidates, nil
133 case ReplicaSetNoPrimary, ReplicaSetWithPrimary:
134 return selectForReplicaSet(rp, t, candidates)
135 case Sharded:
136 return selectByKind(candidates, Mongos), nil
137 }
138
139 return nil, nil
140 })
141}
142
143func selectForReplicaSet(rp *readpref.ReadPref, t Topology, candidates []Server) ([]Server, error) {
144 if err := verifyMaxStaleness(rp, t); err != nil {
145 return nil, err
146 }
147
148 switch rp.Mode() {
149 case readpref.PrimaryMode:
150 return selectByKind(candidates, RSPrimary), nil
151 case readpref.PrimaryPreferredMode:
152 selected := selectByKind(candidates, RSPrimary)
153
154 if len(selected) == 0 {
155 selected = selectSecondaries(rp, candidates)
156 return selectByTagSet(selected, rp.TagSets()), nil
157 }
158
159 return selected, nil
160 case readpref.SecondaryPreferredMode:
161 selected := selectSecondaries(rp, candidates)
162 selected = selectByTagSet(selected, rp.TagSets())
163 if len(selected) > 0 {
164 return selected, nil
165 }
166 return selectByKind(candidates, RSPrimary), nil
167 case readpref.SecondaryMode:
168 selected := selectSecondaries(rp, candidates)
169 return selectByTagSet(selected, rp.TagSets()), nil
170 case readpref.NearestMode:
171 selected := selectByKind(candidates, RSPrimary)
172 selected = append(selected, selectSecondaries(rp, candidates)...)
173 return selectByTagSet(selected, rp.TagSets()), nil
174 }
175
176 return nil, fmt.Errorf("unsupported mode: %d", rp.Mode())
177}
178
179func selectSecondaries(rp *readpref.ReadPref, candidates []Server) []Server {
180 secondaries := selectByKind(candidates, RSSecondary)
181 if len(secondaries) == 0 {
182 return secondaries
183 }
184 if maxStaleness, set := rp.MaxStaleness(); set {
185 primaries := selectByKind(candidates, RSPrimary)
186 if len(primaries) == 0 {
187 baseTime := secondaries[0].LastWriteTime
188 for i := 1; i < len(secondaries); i++ {
189 if secondaries[i].LastWriteTime.After(baseTime) {
190 baseTime = secondaries[i].LastWriteTime
191 }
192 }
193
194 var selected []Server
195 for _, secondary := range secondaries {
196 estimatedStaleness := baseTime.Sub(secondary.LastWriteTime) + secondary.HeartbeatInterval
197 if estimatedStaleness <= maxStaleness {
198 selected = append(selected, secondary)
199 }
200 }
201
202 return selected
203 }
204
205 primary := primaries[0]
206
207 var selected []Server
208 for _, secondary := range secondaries {
209 estimatedStaleness := secondary.LastUpdateTime.Sub(secondary.LastWriteTime) - primary.LastUpdateTime.Sub(primary.LastWriteTime) + secondary.HeartbeatInterval
210 if estimatedStaleness <= maxStaleness {
211 selected = append(selected, secondary)
212 }
213 }
214 return selected
215 }
216
217 return secondaries
218}
219
220func selectByTagSet(candidates []Server, tagSets []tag.Set) []Server {
221 if len(tagSets) == 0 {
222 return candidates
223 }
224
225 for _, ts := range tagSets {
226 var results []Server
227 for _, s := range candidates {
228 if len(s.Tags) > 0 && s.Tags.ContainsAll(ts) {
229 results = append(results, s)
230 }
231 }
232
233 if len(results) > 0 {
234 return results
235 }
236 }
237
238 return []Server{}
239}
240
241func selectByKind(candidates []Server, kind ServerKind) []Server {
242 var result []Server
243 for _, s := range candidates {
244 if s.Kind == kind {
245 result = append(result, s)
246 }
247 }
248
249 return result
250}
251
252func verifyMaxStaleness(rp *readpref.ReadPref, t Topology) error {
253 maxStaleness, set := rp.MaxStaleness()
254 if !set {
255 return nil
256 }
257
258 if maxStaleness < 90*time.Second {
259 return fmt.Errorf("max staleness (%s) must be greater than or equal to 90s", maxStaleness)
260 }
261
262 if len(t.Servers) < 1 {
263 // Maybe we should return an error here instead?
264 return nil
265 }
266
267 // we'll assume all candidates have the same heartbeat interval.
268 s := t.Servers[0]
269 idleWritePeriod := 10 * time.Second
270
271 if maxStaleness < s.HeartbeatInterval+idleWritePeriod {
272 return fmt.Errorf(
273 "max staleness (%s) must be greater than or equal to the heartbeat interval (%s) plus idle write period (%s)",
274 maxStaleness, s.HeartbeatInterval, idleWritePeriod,
275 )
276 }
277
278 return nil
279}