blob: 0287ca70720371b399b89edc537ea5dd99689483 [file] [log] [blame]
Don Newton379ae252019-04-01 12:17:06 -04001// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package driver
8
9import (
10 "context"
11
12 "time"
13
14 "errors"
15
16 "github.com/mongodb/mongo-go-driver/bson/bsoncodec"
17 "github.com/mongodb/mongo-go-driver/mongo/options"
18 "github.com/mongodb/mongo-go-driver/mongo/readpref"
19 "github.com/mongodb/mongo-go-driver/x/bsonx"
20 "github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore"
21 "github.com/mongodb/mongo-go-driver/x/mongo/driver/session"
22 "github.com/mongodb/mongo-go-driver/x/mongo/driver/topology"
23 "github.com/mongodb/mongo-go-driver/x/mongo/driver/uuid"
24 "github.com/mongodb/mongo-go-driver/x/network/command"
25 "github.com/mongodb/mongo-go-driver/x/network/connection"
26 "github.com/mongodb/mongo-go-driver/x/network/description"
27 "github.com/mongodb/mongo-go-driver/x/network/wiremessage"
28)
29
30// Find handles the full cycle dispatch and execution of a find command against the provided
31// topology.
32func Find(
33 ctx context.Context,
34 cmd command.Find,
35 topo *topology.Topology,
36 selector description.ServerSelector,
37 clientID uuid.UUID,
38 pool *session.Pool,
39 registry *bsoncodec.Registry,
40 opts ...*options.FindOptions,
41) (*BatchCursor, error) {
42
43 ss, err := topo.SelectServer(ctx, selector)
44 if err != nil {
45 return nil, err
46 }
47
48 desc := ss.Description()
49 conn, err := ss.Connection(ctx)
50 if err != nil {
51 return nil, err
52 }
53 defer conn.Close()
54
55 if desc.WireVersion.Max < 4 {
56 return legacyFind(ctx, cmd, registry, ss, conn, opts...)
57 }
58
59 rp, err := getReadPrefBasedOnTransaction(cmd.ReadPref, cmd.Session)
60 if err != nil {
61 return nil, err
62 }
63 cmd.ReadPref = rp
64
65 // If no explicit session and deployment supports sessions, start implicit session.
66 if cmd.Session == nil && topo.SupportsSessions() {
67 cmd.Session, err = session.NewClientSession(pool, clientID, session.Implicit)
68 if err != nil {
69 return nil, err
70 }
71 }
72
73 fo := options.MergeFindOptions(opts...)
74 if fo.AllowPartialResults != nil {
75 cmd.Opts = append(cmd.Opts, bsonx.Elem{"allowPartialResults", bsonx.Boolean(*fo.AllowPartialResults)})
76 }
77 if fo.BatchSize != nil {
78 elem := bsonx.Elem{"batchSize", bsonx.Int32(*fo.BatchSize)}
79 cmd.Opts = append(cmd.Opts, elem)
80 cmd.CursorOpts = append(cmd.CursorOpts, elem)
81
82 if fo.Limit != nil && *fo.BatchSize != 0 && *fo.Limit <= int64(*fo.BatchSize) {
83 cmd.Opts = append(cmd.Opts, bsonx.Elem{"singleBatch", bsonx.Boolean(true)})
84 }
85 }
86 if fo.Collation != nil {
87 if desc.WireVersion.Max < 5 {
88 return nil, ErrCollation
89 }
90 cmd.Opts = append(cmd.Opts, bsonx.Elem{"collation", bsonx.Document(fo.Collation.ToDocument())})
91 }
92 if fo.Comment != nil {
93 cmd.Opts = append(cmd.Opts, bsonx.Elem{"comment", bsonx.String(*fo.Comment)})
94 }
95 if fo.CursorType != nil {
96 switch *fo.CursorType {
97 case options.Tailable:
98 cmd.Opts = append(cmd.Opts, bsonx.Elem{"tailable", bsonx.Boolean(true)})
99 case options.TailableAwait:
100 cmd.Opts = append(cmd.Opts, bsonx.Elem{"tailable", bsonx.Boolean(true)}, bsonx.Elem{"awaitData", bsonx.Boolean(true)})
101 }
102 }
103 if fo.Hint != nil {
104 hintElem, err := interfaceToElement("hint", fo.Hint, registry)
105 if err != nil {
106 return nil, err
107 }
108
109 cmd.Opts = append(cmd.Opts, hintElem)
110 }
111 if fo.Limit != nil {
112 cmd.Opts = append(cmd.Opts, bsonx.Elem{"limit", bsonx.Int64(*fo.Limit)})
113 }
114 if fo.Max != nil {
115 maxElem, err := interfaceToElement("max", fo.Max, registry)
116 if err != nil {
117 return nil, err
118 }
119
120 cmd.Opts = append(cmd.Opts, maxElem)
121 }
122 if fo.MaxAwaitTime != nil {
123 // Specified as maxTimeMS on the in the getMore command and not given in initial find command.
124 cmd.CursorOpts = append(cmd.CursorOpts, bsonx.Elem{"maxTimeMS", bsonx.Int64(int64(*fo.MaxAwaitTime / time.Millisecond))})
125 }
126 if fo.MaxTime != nil {
127 cmd.Opts = append(cmd.Opts, bsonx.Elem{"maxTimeMS", bsonx.Int64(int64(*fo.MaxTime / time.Millisecond))})
128 }
129 if fo.Min != nil {
130 minElem, err := interfaceToElement("min", fo.Min, registry)
131 if err != nil {
132 return nil, err
133 }
134
135 cmd.Opts = append(cmd.Opts, minElem)
136 }
137 if fo.NoCursorTimeout != nil {
138 cmd.Opts = append(cmd.Opts, bsonx.Elem{"noCursorTimeout", bsonx.Boolean(*fo.NoCursorTimeout)})
139 }
140 if fo.OplogReplay != nil {
141 cmd.Opts = append(cmd.Opts, bsonx.Elem{"oplogReplay", bsonx.Boolean(*fo.OplogReplay)})
142 }
143 if fo.Projection != nil {
144 projElem, err := interfaceToElement("projection", fo.Projection, registry)
145 if err != nil {
146 return nil, err
147 }
148
149 cmd.Opts = append(cmd.Opts, projElem)
150 }
151 if fo.ReturnKey != nil {
152 cmd.Opts = append(cmd.Opts, bsonx.Elem{"returnKey", bsonx.Boolean(*fo.ReturnKey)})
153 }
154 if fo.ShowRecordID != nil {
155 cmd.Opts = append(cmd.Opts, bsonx.Elem{"showRecordId", bsonx.Boolean(*fo.ShowRecordID)})
156 }
157 if fo.Skip != nil {
158 cmd.Opts = append(cmd.Opts, bsonx.Elem{"skip", bsonx.Int64(*fo.Skip)})
159 }
160 if fo.Snapshot != nil {
161 cmd.Opts = append(cmd.Opts, bsonx.Elem{"snapshot", bsonx.Boolean(*fo.Snapshot)})
162 }
163 if fo.Sort != nil {
164 sortElem, err := interfaceToElement("sort", fo.Sort, registry)
165 if err != nil {
166 return nil, err
167 }
168
169 cmd.Opts = append(cmd.Opts, sortElem)
170 }
171
172 res, err := cmd.RoundTrip(ctx, desc, conn)
173 if err != nil {
174 closeImplicitSession(cmd.Session)
175 return nil, err
176 }
177
178 return NewBatchCursor(bsoncore.Document(res), cmd.Session, cmd.Clock, ss.Server, cmd.CursorOpts...)
179}
180
181// legacyFind handles the dispatch and execution of a find operation against a pre-3.2 server.
182func legacyFind(
183 ctx context.Context,
184 cmd command.Find,
185 registry *bsoncodec.Registry,
186 ss *topology.SelectedServer,
187 conn connection.Connection,
188 opts ...*options.FindOptions,
189) (*BatchCursor, error) {
190 query := wiremessage.Query{
191 FullCollectionName: cmd.NS.DB + "." + cmd.NS.Collection,
192 }
193
194 fo := options.MergeFindOptions(opts...)
195 optsDoc, err := createLegacyOptionsDoc(fo, registry)
196 if err != nil {
197 return nil, err
198 }
199 if fo.Projection != nil {
200 projDoc, err := interfaceToDocument(fo.Projection, registry)
201 if err != nil {
202 return nil, err
203 }
204
205 projRaw, err := projDoc.MarshalBSON()
206 if err != nil {
207 return nil, err
208 }
209 query.ReturnFieldsSelector = projRaw
210 }
211 if fo.Skip != nil {
212 query.NumberToSkip = int32(*fo.Skip)
213 query.SkipSet = true
214 }
215 // batch size of 1 not possible with OP_QUERY because the cursor will be closed immediately
216 if fo.BatchSize != nil && *fo.BatchSize == 1 {
217 query.NumberToReturn = 2
218 } else {
219 query.NumberToReturn = calculateNumberToReturn(fo)
220 }
221 query.Flags = calculateLegacyFlags(fo)
222
223 query.BatchSize = fo.BatchSize
224 if fo.Limit != nil {
225 i := int32(*fo.Limit)
226 query.Limit = &i
227 }
228
229 // set read preference and/or slaveOK flag
230 desc := ss.SelectedDescription()
231 if slaveOkNeeded(cmd.ReadPref, desc) {
232 query.Flags |= wiremessage.SlaveOK
233 }
234 optsDoc = addReadPref(cmd.ReadPref, desc.Server.Kind, optsDoc)
235
236 if cmd.Filter == nil {
237 cmd.Filter = bsonx.Doc{}
238 }
239
240 // filter must be wrapped in $query if other $modifiers are used
241 var queryDoc bsonx.Doc
242 if len(optsDoc) == 0 {
243 queryDoc = cmd.Filter
244 } else {
245 filterDoc := bsonx.Doc{
246 {"$query", bsonx.Document(cmd.Filter)},
247 }
248 // $query should go first
249 queryDoc = append(filterDoc, optsDoc...)
250 }
251
252 queryRaw, err := queryDoc.MarshalBSON()
253 if err != nil {
254 return nil, err
255 }
256 query.Query = queryRaw
257
258 reply, err := roundTripQuery(ctx, query, conn)
259 if err != nil {
260 return nil, err
261 }
262
263 var cursorLimit int32
264 var cursorBatchSize int32
265 if query.Limit != nil {
266 cursorLimit = int32(*query.Limit)
267 if cursorLimit < 0 {
268 cursorLimit *= -1
269 }
270 }
271 if query.BatchSize != nil {
272 cursorBatchSize = int32(*query.BatchSize)
273 }
274
275 return NewLegacyBatchCursor(cmd.NS, reply.CursorID, reply.Documents, cursorLimit, cursorBatchSize, ss.Server)
276}
277
278func createLegacyOptionsDoc(fo *options.FindOptions, registry *bsoncodec.Registry) (bsonx.Doc, error) {
279 var optsDoc bsonx.Doc
280
281 if fo.Collation != nil {
282 return nil, ErrCollation
283 }
284 if fo.Comment != nil {
285 optsDoc = append(optsDoc, bsonx.Elem{"$comment", bsonx.String(*fo.Comment)})
286 }
287 if fo.Hint != nil {
288 hintElem, err := interfaceToElement("$hint", fo.Hint, registry)
289 if err != nil {
290 return nil, err
291 }
292
293 optsDoc = append(optsDoc, hintElem)
294 }
295 if fo.Max != nil {
296 maxElem, err := interfaceToElement("$max", fo.Max, registry)
297 if err != nil {
298 return nil, err
299 }
300
301 optsDoc = append(optsDoc, maxElem)
302 }
303 if fo.MaxTime != nil {
304 optsDoc = append(optsDoc, bsonx.Elem{"$maxTimeMS", bsonx.Int64(int64(*fo.MaxTime / time.Millisecond))})
305 }
306 if fo.Min != nil {
307 minElem, err := interfaceToElement("$min", fo.Min, registry)
308 if err != nil {
309 return nil, err
310 }
311
312 optsDoc = append(optsDoc, minElem)
313 }
314 if fo.ReturnKey != nil {
315 optsDoc = append(optsDoc, bsonx.Elem{"$returnKey", bsonx.Boolean(*fo.ReturnKey)})
316 }
317 if fo.ShowRecordID != nil {
318 optsDoc = append(optsDoc, bsonx.Elem{"$showDiskLoc", bsonx.Boolean(*fo.ShowRecordID)})
319 }
320 if fo.Snapshot != nil {
321 optsDoc = append(optsDoc, bsonx.Elem{"$snapshot", bsonx.Boolean(*fo.Snapshot)})
322 }
323 if fo.Sort != nil {
324 sortElem, err := interfaceToElement("$orderby", fo.Sort, registry)
325 if err != nil {
326 return nil, err
327 }
328
329 optsDoc = append(optsDoc, sortElem)
330 }
331
332 return optsDoc, nil
333}
334
335func calculateLegacyFlags(fo *options.FindOptions) wiremessage.QueryFlag {
336 var flags wiremessage.QueryFlag
337
338 if fo.AllowPartialResults != nil {
339 flags |= wiremessage.Partial
340 }
341 if fo.CursorType != nil {
342 switch *fo.CursorType {
343 case options.Tailable:
344 flags |= wiremessage.TailableCursor
345 case options.TailableAwait:
346 flags |= wiremessage.TailableCursor
347 flags |= wiremessage.AwaitData
348 }
349 }
350 if fo.NoCursorTimeout != nil {
351 flags |= wiremessage.NoCursorTimeout
352 }
353 if fo.OplogReplay != nil {
354 flags |= wiremessage.OplogReplay
355 }
356
357 return flags
358}
359
360// calculate the number to return for the first find query
361func calculateNumberToReturn(opts *options.FindOptions) int32 {
362 var numReturn int32
363 var limit int32
364 var batchSize int32
365
366 if opts.Limit != nil {
367 limit = int32(*opts.Limit)
368 }
369 if opts.BatchSize != nil {
370 batchSize = int32(*opts.BatchSize)
371 }
372
373 if limit < 0 {
374 numReturn = limit
375 } else if limit == 0 {
376 numReturn = batchSize
377 } else if limit < batchSize {
378 numReturn = limit
379 } else {
380 numReturn = batchSize
381 }
382
383 return numReturn
384}
385
386func slaveOkNeeded(rp *readpref.ReadPref, desc description.SelectedServer) bool {
387 if desc.Kind == description.Single && desc.Server.Kind != description.Mongos {
388 return true
389 }
390 if rp == nil {
391 // assume primary
392 return false
393 }
394
395 return rp.Mode() != readpref.PrimaryMode
396}
397
398func addReadPref(rp *readpref.ReadPref, kind description.ServerKind, query bsonx.Doc) bsonx.Doc {
399 if !readPrefNeeded(rp, kind) {
400 return query
401 }
402
403 doc := createReadPref(rp)
404 if doc == nil {
405 return query
406 }
407
408 return query.Append("$readPreference", bsonx.Document(doc))
409}
410
411func readPrefNeeded(rp *readpref.ReadPref, kind description.ServerKind) bool {
412 if kind != description.Mongos || rp == nil {
413 return false
414 }
415
416 // simple Primary or SecondaryPreferred is communicated via slaveOk to Mongos.
417 if rp.Mode() == readpref.PrimaryMode || rp.Mode() == readpref.SecondaryPreferredMode {
418 if _, ok := rp.MaxStaleness(); !ok && len(rp.TagSets()) == 0 {
419 return false
420 }
421 }
422
423 return true
424}
425
426func createReadPref(rp *readpref.ReadPref) bsonx.Doc {
427 if rp == nil {
428 return nil
429 }
430
431 doc := bsonx.Doc{}
432
433 switch rp.Mode() {
434 case readpref.PrimaryMode:
435 doc = append(doc, bsonx.Elem{"mode", bsonx.String("primary")})
436 case readpref.PrimaryPreferredMode:
437 doc = append(doc, bsonx.Elem{"mode", bsonx.String("primaryPreferred")})
438 case readpref.SecondaryPreferredMode:
439 doc = append(doc, bsonx.Elem{"mode", bsonx.String("secondaryPreferred")})
440 case readpref.SecondaryMode:
441 doc = append(doc, bsonx.Elem{"mode", bsonx.String("secondary")})
442 case readpref.NearestMode:
443 doc = append(doc, bsonx.Elem{"mode", bsonx.String("nearest")})
444 }
445
446 sets := make([]bsonx.Val, 0, len(rp.TagSets()))
447 for _, ts := range rp.TagSets() {
448 if len(ts) == 0 {
449 continue
450 }
451 set := bsonx.Doc{}
452 for _, t := range ts {
453 set = append(set, bsonx.Elem{t.Name, bsonx.String(t.Value)})
454 }
455 sets = append(sets, bsonx.Document(set))
456 }
457 if len(sets) > 0 {
458 doc = append(doc, bsonx.Elem{"tags", bsonx.Array(sets)})
459 }
460 if d, ok := rp.MaxStaleness(); ok {
461 doc = append(doc, bsonx.Elem{"maxStalenessSeconds", bsonx.Int32(int32(d.Seconds()))})
462 }
463
464 return doc
465}
466
467func roundTripQuery(ctx context.Context, query wiremessage.Query, conn connection.Connection) (wiremessage.Reply, error) {
468 err := conn.WriteWireMessage(ctx, query)
469 if err != nil {
470 if _, ok := err.(command.Error); ok {
471 return wiremessage.Reply{}, err
472 }
473 return wiremessage.Reply{}, command.Error{
474 Message: err.Error(),
475 Labels: []string{command.NetworkError},
476 }
477 }
478
479 wm, err := conn.ReadWireMessage(ctx)
480 if err != nil {
481 if _, ok := err.(command.Error); ok {
482 return wiremessage.Reply{}, err
483 }
484 // Connection errors are transient
485 return wiremessage.Reply{}, command.Error{
486 Message: err.Error(),
487 Labels: []string{command.NetworkError},
488 }
489 }
490
491 reply, ok := wm.(wiremessage.Reply)
492 if !ok {
493 return wiremessage.Reply{}, errors.New("did not receive OP_REPLY response")
494 }
495
496 err = validateOpReply(reply)
497 if err != nil {
498 return wiremessage.Reply{}, err
499 }
500
501 return reply, nil
502}
503
504func validateOpReply(reply wiremessage.Reply) error {
505 if int(reply.NumberReturned) != len(reply.Documents) {
506 return command.NewCommandResponseError(command.ReplyDocumentMismatch, nil)
507 }
508
509 if reply.ResponseFlags&wiremessage.QueryFailure == wiremessage.QueryFailure {
510 return command.QueryFailureError{
511 Message: "query failure",
512 Response: reply.Documents[0],
513 }
514 }
515
516 return nil
517}