Don Newton | 379ae25 | 2019-04-01 12:17:06 -0400 | [diff] [blame^] | 1 | // Copyright (C) MongoDB, Inc. 2017-present. |
| 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may |
| 4 | // not use this file except in compliance with the License. You may obtain |
| 5 | // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | |
| 7 | package driver |
| 8 | |
| 9 | import ( |
| 10 | "context" |
| 11 | |
| 12 | "time" |
| 13 | |
| 14 | "errors" |
| 15 | |
| 16 | "github.com/mongodb/mongo-go-driver/bson/bsoncodec" |
| 17 | "github.com/mongodb/mongo-go-driver/mongo/options" |
| 18 | "github.com/mongodb/mongo-go-driver/mongo/readpref" |
| 19 | "github.com/mongodb/mongo-go-driver/x/bsonx" |
| 20 | "github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore" |
| 21 | "github.com/mongodb/mongo-go-driver/x/mongo/driver/session" |
| 22 | "github.com/mongodb/mongo-go-driver/x/mongo/driver/topology" |
| 23 | "github.com/mongodb/mongo-go-driver/x/mongo/driver/uuid" |
| 24 | "github.com/mongodb/mongo-go-driver/x/network/command" |
| 25 | "github.com/mongodb/mongo-go-driver/x/network/connection" |
| 26 | "github.com/mongodb/mongo-go-driver/x/network/description" |
| 27 | "github.com/mongodb/mongo-go-driver/x/network/wiremessage" |
| 28 | ) |
| 29 | |
| 30 | // Find handles the full cycle dispatch and execution of a find command against the provided |
| 31 | // topology. |
| 32 | func Find( |
| 33 | ctx context.Context, |
| 34 | cmd command.Find, |
| 35 | topo *topology.Topology, |
| 36 | selector description.ServerSelector, |
| 37 | clientID uuid.UUID, |
| 38 | pool *session.Pool, |
| 39 | registry *bsoncodec.Registry, |
| 40 | opts ...*options.FindOptions, |
| 41 | ) (*BatchCursor, error) { |
| 42 | |
| 43 | ss, err := topo.SelectServer(ctx, selector) |
| 44 | if err != nil { |
| 45 | return nil, err |
| 46 | } |
| 47 | |
| 48 | desc := ss.Description() |
| 49 | conn, err := ss.Connection(ctx) |
| 50 | if err != nil { |
| 51 | return nil, err |
| 52 | } |
| 53 | defer conn.Close() |
| 54 | |
| 55 | if desc.WireVersion.Max < 4 { |
| 56 | return legacyFind(ctx, cmd, registry, ss, conn, opts...) |
| 57 | } |
| 58 | |
| 59 | rp, err := getReadPrefBasedOnTransaction(cmd.ReadPref, cmd.Session) |
| 60 | if err != nil { |
| 61 | return nil, err |
| 62 | } |
| 63 | cmd.ReadPref = rp |
| 64 | |
| 65 | // If no explicit session and deployment supports sessions, start implicit session. |
| 66 | if cmd.Session == nil && topo.SupportsSessions() { |
| 67 | cmd.Session, err = session.NewClientSession(pool, clientID, session.Implicit) |
| 68 | if err != nil { |
| 69 | return nil, err |
| 70 | } |
| 71 | } |
| 72 | |
| 73 | fo := options.MergeFindOptions(opts...) |
| 74 | if fo.AllowPartialResults != nil { |
| 75 | cmd.Opts = append(cmd.Opts, bsonx.Elem{"allowPartialResults", bsonx.Boolean(*fo.AllowPartialResults)}) |
| 76 | } |
| 77 | if fo.BatchSize != nil { |
| 78 | elem := bsonx.Elem{"batchSize", bsonx.Int32(*fo.BatchSize)} |
| 79 | cmd.Opts = append(cmd.Opts, elem) |
| 80 | cmd.CursorOpts = append(cmd.CursorOpts, elem) |
| 81 | |
| 82 | if fo.Limit != nil && *fo.BatchSize != 0 && *fo.Limit <= int64(*fo.BatchSize) { |
| 83 | cmd.Opts = append(cmd.Opts, bsonx.Elem{"singleBatch", bsonx.Boolean(true)}) |
| 84 | } |
| 85 | } |
| 86 | if fo.Collation != nil { |
| 87 | if desc.WireVersion.Max < 5 { |
| 88 | return nil, ErrCollation |
| 89 | } |
| 90 | cmd.Opts = append(cmd.Opts, bsonx.Elem{"collation", bsonx.Document(fo.Collation.ToDocument())}) |
| 91 | } |
| 92 | if fo.Comment != nil { |
| 93 | cmd.Opts = append(cmd.Opts, bsonx.Elem{"comment", bsonx.String(*fo.Comment)}) |
| 94 | } |
| 95 | if fo.CursorType != nil { |
| 96 | switch *fo.CursorType { |
| 97 | case options.Tailable: |
| 98 | cmd.Opts = append(cmd.Opts, bsonx.Elem{"tailable", bsonx.Boolean(true)}) |
| 99 | case options.TailableAwait: |
| 100 | cmd.Opts = append(cmd.Opts, bsonx.Elem{"tailable", bsonx.Boolean(true)}, bsonx.Elem{"awaitData", bsonx.Boolean(true)}) |
| 101 | } |
| 102 | } |
| 103 | if fo.Hint != nil { |
| 104 | hintElem, err := interfaceToElement("hint", fo.Hint, registry) |
| 105 | if err != nil { |
| 106 | return nil, err |
| 107 | } |
| 108 | |
| 109 | cmd.Opts = append(cmd.Opts, hintElem) |
| 110 | } |
| 111 | if fo.Limit != nil { |
| 112 | cmd.Opts = append(cmd.Opts, bsonx.Elem{"limit", bsonx.Int64(*fo.Limit)}) |
| 113 | } |
| 114 | if fo.Max != nil { |
| 115 | maxElem, err := interfaceToElement("max", fo.Max, registry) |
| 116 | if err != nil { |
| 117 | return nil, err |
| 118 | } |
| 119 | |
| 120 | cmd.Opts = append(cmd.Opts, maxElem) |
| 121 | } |
| 122 | if fo.MaxAwaitTime != nil { |
| 123 | // Specified as maxTimeMS on the in the getMore command and not given in initial find command. |
| 124 | cmd.CursorOpts = append(cmd.CursorOpts, bsonx.Elem{"maxTimeMS", bsonx.Int64(int64(*fo.MaxAwaitTime / time.Millisecond))}) |
| 125 | } |
| 126 | if fo.MaxTime != nil { |
| 127 | cmd.Opts = append(cmd.Opts, bsonx.Elem{"maxTimeMS", bsonx.Int64(int64(*fo.MaxTime / time.Millisecond))}) |
| 128 | } |
| 129 | if fo.Min != nil { |
| 130 | minElem, err := interfaceToElement("min", fo.Min, registry) |
| 131 | if err != nil { |
| 132 | return nil, err |
| 133 | } |
| 134 | |
| 135 | cmd.Opts = append(cmd.Opts, minElem) |
| 136 | } |
| 137 | if fo.NoCursorTimeout != nil { |
| 138 | cmd.Opts = append(cmd.Opts, bsonx.Elem{"noCursorTimeout", bsonx.Boolean(*fo.NoCursorTimeout)}) |
| 139 | } |
| 140 | if fo.OplogReplay != nil { |
| 141 | cmd.Opts = append(cmd.Opts, bsonx.Elem{"oplogReplay", bsonx.Boolean(*fo.OplogReplay)}) |
| 142 | } |
| 143 | if fo.Projection != nil { |
| 144 | projElem, err := interfaceToElement("projection", fo.Projection, registry) |
| 145 | if err != nil { |
| 146 | return nil, err |
| 147 | } |
| 148 | |
| 149 | cmd.Opts = append(cmd.Opts, projElem) |
| 150 | } |
| 151 | if fo.ReturnKey != nil { |
| 152 | cmd.Opts = append(cmd.Opts, bsonx.Elem{"returnKey", bsonx.Boolean(*fo.ReturnKey)}) |
| 153 | } |
| 154 | if fo.ShowRecordID != nil { |
| 155 | cmd.Opts = append(cmd.Opts, bsonx.Elem{"showRecordId", bsonx.Boolean(*fo.ShowRecordID)}) |
| 156 | } |
| 157 | if fo.Skip != nil { |
| 158 | cmd.Opts = append(cmd.Opts, bsonx.Elem{"skip", bsonx.Int64(*fo.Skip)}) |
| 159 | } |
| 160 | if fo.Snapshot != nil { |
| 161 | cmd.Opts = append(cmd.Opts, bsonx.Elem{"snapshot", bsonx.Boolean(*fo.Snapshot)}) |
| 162 | } |
| 163 | if fo.Sort != nil { |
| 164 | sortElem, err := interfaceToElement("sort", fo.Sort, registry) |
| 165 | if err != nil { |
| 166 | return nil, err |
| 167 | } |
| 168 | |
| 169 | cmd.Opts = append(cmd.Opts, sortElem) |
| 170 | } |
| 171 | |
| 172 | res, err := cmd.RoundTrip(ctx, desc, conn) |
| 173 | if err != nil { |
| 174 | closeImplicitSession(cmd.Session) |
| 175 | return nil, err |
| 176 | } |
| 177 | |
| 178 | return NewBatchCursor(bsoncore.Document(res), cmd.Session, cmd.Clock, ss.Server, cmd.CursorOpts...) |
| 179 | } |
| 180 | |
| 181 | // legacyFind handles the dispatch and execution of a find operation against a pre-3.2 server. |
| 182 | func legacyFind( |
| 183 | ctx context.Context, |
| 184 | cmd command.Find, |
| 185 | registry *bsoncodec.Registry, |
| 186 | ss *topology.SelectedServer, |
| 187 | conn connection.Connection, |
| 188 | opts ...*options.FindOptions, |
| 189 | ) (*BatchCursor, error) { |
| 190 | query := wiremessage.Query{ |
| 191 | FullCollectionName: cmd.NS.DB + "." + cmd.NS.Collection, |
| 192 | } |
| 193 | |
| 194 | fo := options.MergeFindOptions(opts...) |
| 195 | optsDoc, err := createLegacyOptionsDoc(fo, registry) |
| 196 | if err != nil { |
| 197 | return nil, err |
| 198 | } |
| 199 | if fo.Projection != nil { |
| 200 | projDoc, err := interfaceToDocument(fo.Projection, registry) |
| 201 | if err != nil { |
| 202 | return nil, err |
| 203 | } |
| 204 | |
| 205 | projRaw, err := projDoc.MarshalBSON() |
| 206 | if err != nil { |
| 207 | return nil, err |
| 208 | } |
| 209 | query.ReturnFieldsSelector = projRaw |
| 210 | } |
| 211 | if fo.Skip != nil { |
| 212 | query.NumberToSkip = int32(*fo.Skip) |
| 213 | query.SkipSet = true |
| 214 | } |
| 215 | // batch size of 1 not possible with OP_QUERY because the cursor will be closed immediately |
| 216 | if fo.BatchSize != nil && *fo.BatchSize == 1 { |
| 217 | query.NumberToReturn = 2 |
| 218 | } else { |
| 219 | query.NumberToReturn = calculateNumberToReturn(fo) |
| 220 | } |
| 221 | query.Flags = calculateLegacyFlags(fo) |
| 222 | |
| 223 | query.BatchSize = fo.BatchSize |
| 224 | if fo.Limit != nil { |
| 225 | i := int32(*fo.Limit) |
| 226 | query.Limit = &i |
| 227 | } |
| 228 | |
| 229 | // set read preference and/or slaveOK flag |
| 230 | desc := ss.SelectedDescription() |
| 231 | if slaveOkNeeded(cmd.ReadPref, desc) { |
| 232 | query.Flags |= wiremessage.SlaveOK |
| 233 | } |
| 234 | optsDoc = addReadPref(cmd.ReadPref, desc.Server.Kind, optsDoc) |
| 235 | |
| 236 | if cmd.Filter == nil { |
| 237 | cmd.Filter = bsonx.Doc{} |
| 238 | } |
| 239 | |
| 240 | // filter must be wrapped in $query if other $modifiers are used |
| 241 | var queryDoc bsonx.Doc |
| 242 | if len(optsDoc) == 0 { |
| 243 | queryDoc = cmd.Filter |
| 244 | } else { |
| 245 | filterDoc := bsonx.Doc{ |
| 246 | {"$query", bsonx.Document(cmd.Filter)}, |
| 247 | } |
| 248 | // $query should go first |
| 249 | queryDoc = append(filterDoc, optsDoc...) |
| 250 | } |
| 251 | |
| 252 | queryRaw, err := queryDoc.MarshalBSON() |
| 253 | if err != nil { |
| 254 | return nil, err |
| 255 | } |
| 256 | query.Query = queryRaw |
| 257 | |
| 258 | reply, err := roundTripQuery(ctx, query, conn) |
| 259 | if err != nil { |
| 260 | return nil, err |
| 261 | } |
| 262 | |
| 263 | var cursorLimit int32 |
| 264 | var cursorBatchSize int32 |
| 265 | if query.Limit != nil { |
| 266 | cursorLimit = int32(*query.Limit) |
| 267 | if cursorLimit < 0 { |
| 268 | cursorLimit *= -1 |
| 269 | } |
| 270 | } |
| 271 | if query.BatchSize != nil { |
| 272 | cursorBatchSize = int32(*query.BatchSize) |
| 273 | } |
| 274 | |
| 275 | return NewLegacyBatchCursor(cmd.NS, reply.CursorID, reply.Documents, cursorLimit, cursorBatchSize, ss.Server) |
| 276 | } |
| 277 | |
| 278 | func createLegacyOptionsDoc(fo *options.FindOptions, registry *bsoncodec.Registry) (bsonx.Doc, error) { |
| 279 | var optsDoc bsonx.Doc |
| 280 | |
| 281 | if fo.Collation != nil { |
| 282 | return nil, ErrCollation |
| 283 | } |
| 284 | if fo.Comment != nil { |
| 285 | optsDoc = append(optsDoc, bsonx.Elem{"$comment", bsonx.String(*fo.Comment)}) |
| 286 | } |
| 287 | if fo.Hint != nil { |
| 288 | hintElem, err := interfaceToElement("$hint", fo.Hint, registry) |
| 289 | if err != nil { |
| 290 | return nil, err |
| 291 | } |
| 292 | |
| 293 | optsDoc = append(optsDoc, hintElem) |
| 294 | } |
| 295 | if fo.Max != nil { |
| 296 | maxElem, err := interfaceToElement("$max", fo.Max, registry) |
| 297 | if err != nil { |
| 298 | return nil, err |
| 299 | } |
| 300 | |
| 301 | optsDoc = append(optsDoc, maxElem) |
| 302 | } |
| 303 | if fo.MaxTime != nil { |
| 304 | optsDoc = append(optsDoc, bsonx.Elem{"$maxTimeMS", bsonx.Int64(int64(*fo.MaxTime / time.Millisecond))}) |
| 305 | } |
| 306 | if fo.Min != nil { |
| 307 | minElem, err := interfaceToElement("$min", fo.Min, registry) |
| 308 | if err != nil { |
| 309 | return nil, err |
| 310 | } |
| 311 | |
| 312 | optsDoc = append(optsDoc, minElem) |
| 313 | } |
| 314 | if fo.ReturnKey != nil { |
| 315 | optsDoc = append(optsDoc, bsonx.Elem{"$returnKey", bsonx.Boolean(*fo.ReturnKey)}) |
| 316 | } |
| 317 | if fo.ShowRecordID != nil { |
| 318 | optsDoc = append(optsDoc, bsonx.Elem{"$showDiskLoc", bsonx.Boolean(*fo.ShowRecordID)}) |
| 319 | } |
| 320 | if fo.Snapshot != nil { |
| 321 | optsDoc = append(optsDoc, bsonx.Elem{"$snapshot", bsonx.Boolean(*fo.Snapshot)}) |
| 322 | } |
| 323 | if fo.Sort != nil { |
| 324 | sortElem, err := interfaceToElement("$orderby", fo.Sort, registry) |
| 325 | if err != nil { |
| 326 | return nil, err |
| 327 | } |
| 328 | |
| 329 | optsDoc = append(optsDoc, sortElem) |
| 330 | } |
| 331 | |
| 332 | return optsDoc, nil |
| 333 | } |
| 334 | |
| 335 | func calculateLegacyFlags(fo *options.FindOptions) wiremessage.QueryFlag { |
| 336 | var flags wiremessage.QueryFlag |
| 337 | |
| 338 | if fo.AllowPartialResults != nil { |
| 339 | flags |= wiremessage.Partial |
| 340 | } |
| 341 | if fo.CursorType != nil { |
| 342 | switch *fo.CursorType { |
| 343 | case options.Tailable: |
| 344 | flags |= wiremessage.TailableCursor |
| 345 | case options.TailableAwait: |
| 346 | flags |= wiremessage.TailableCursor |
| 347 | flags |= wiremessage.AwaitData |
| 348 | } |
| 349 | } |
| 350 | if fo.NoCursorTimeout != nil { |
| 351 | flags |= wiremessage.NoCursorTimeout |
| 352 | } |
| 353 | if fo.OplogReplay != nil { |
| 354 | flags |= wiremessage.OplogReplay |
| 355 | } |
| 356 | |
| 357 | return flags |
| 358 | } |
| 359 | |
| 360 | // calculate the number to return for the first find query |
| 361 | func calculateNumberToReturn(opts *options.FindOptions) int32 { |
| 362 | var numReturn int32 |
| 363 | var limit int32 |
| 364 | var batchSize int32 |
| 365 | |
| 366 | if opts.Limit != nil { |
| 367 | limit = int32(*opts.Limit) |
| 368 | } |
| 369 | if opts.BatchSize != nil { |
| 370 | batchSize = int32(*opts.BatchSize) |
| 371 | } |
| 372 | |
| 373 | if limit < 0 { |
| 374 | numReturn = limit |
| 375 | } else if limit == 0 { |
| 376 | numReturn = batchSize |
| 377 | } else if limit < batchSize { |
| 378 | numReturn = limit |
| 379 | } else { |
| 380 | numReturn = batchSize |
| 381 | } |
| 382 | |
| 383 | return numReturn |
| 384 | } |
| 385 | |
| 386 | func slaveOkNeeded(rp *readpref.ReadPref, desc description.SelectedServer) bool { |
| 387 | if desc.Kind == description.Single && desc.Server.Kind != description.Mongos { |
| 388 | return true |
| 389 | } |
| 390 | if rp == nil { |
| 391 | // assume primary |
| 392 | return false |
| 393 | } |
| 394 | |
| 395 | return rp.Mode() != readpref.PrimaryMode |
| 396 | } |
| 397 | |
| 398 | func addReadPref(rp *readpref.ReadPref, kind description.ServerKind, query bsonx.Doc) bsonx.Doc { |
| 399 | if !readPrefNeeded(rp, kind) { |
| 400 | return query |
| 401 | } |
| 402 | |
| 403 | doc := createReadPref(rp) |
| 404 | if doc == nil { |
| 405 | return query |
| 406 | } |
| 407 | |
| 408 | return query.Append("$readPreference", bsonx.Document(doc)) |
| 409 | } |
| 410 | |
| 411 | func readPrefNeeded(rp *readpref.ReadPref, kind description.ServerKind) bool { |
| 412 | if kind != description.Mongos || rp == nil { |
| 413 | return false |
| 414 | } |
| 415 | |
| 416 | // simple Primary or SecondaryPreferred is communicated via slaveOk to Mongos. |
| 417 | if rp.Mode() == readpref.PrimaryMode || rp.Mode() == readpref.SecondaryPreferredMode { |
| 418 | if _, ok := rp.MaxStaleness(); !ok && len(rp.TagSets()) == 0 { |
| 419 | return false |
| 420 | } |
| 421 | } |
| 422 | |
| 423 | return true |
| 424 | } |
| 425 | |
| 426 | func createReadPref(rp *readpref.ReadPref) bsonx.Doc { |
| 427 | if rp == nil { |
| 428 | return nil |
| 429 | } |
| 430 | |
| 431 | doc := bsonx.Doc{} |
| 432 | |
| 433 | switch rp.Mode() { |
| 434 | case readpref.PrimaryMode: |
| 435 | doc = append(doc, bsonx.Elem{"mode", bsonx.String("primary")}) |
| 436 | case readpref.PrimaryPreferredMode: |
| 437 | doc = append(doc, bsonx.Elem{"mode", bsonx.String("primaryPreferred")}) |
| 438 | case readpref.SecondaryPreferredMode: |
| 439 | doc = append(doc, bsonx.Elem{"mode", bsonx.String("secondaryPreferred")}) |
| 440 | case readpref.SecondaryMode: |
| 441 | doc = append(doc, bsonx.Elem{"mode", bsonx.String("secondary")}) |
| 442 | case readpref.NearestMode: |
| 443 | doc = append(doc, bsonx.Elem{"mode", bsonx.String("nearest")}) |
| 444 | } |
| 445 | |
| 446 | sets := make([]bsonx.Val, 0, len(rp.TagSets())) |
| 447 | for _, ts := range rp.TagSets() { |
| 448 | if len(ts) == 0 { |
| 449 | continue |
| 450 | } |
| 451 | set := bsonx.Doc{} |
| 452 | for _, t := range ts { |
| 453 | set = append(set, bsonx.Elem{t.Name, bsonx.String(t.Value)}) |
| 454 | } |
| 455 | sets = append(sets, bsonx.Document(set)) |
| 456 | } |
| 457 | if len(sets) > 0 { |
| 458 | doc = append(doc, bsonx.Elem{"tags", bsonx.Array(sets)}) |
| 459 | } |
| 460 | if d, ok := rp.MaxStaleness(); ok { |
| 461 | doc = append(doc, bsonx.Elem{"maxStalenessSeconds", bsonx.Int32(int32(d.Seconds()))}) |
| 462 | } |
| 463 | |
| 464 | return doc |
| 465 | } |
| 466 | |
| 467 | func roundTripQuery(ctx context.Context, query wiremessage.Query, conn connection.Connection) (wiremessage.Reply, error) { |
| 468 | err := conn.WriteWireMessage(ctx, query) |
| 469 | if err != nil { |
| 470 | if _, ok := err.(command.Error); ok { |
| 471 | return wiremessage.Reply{}, err |
| 472 | } |
| 473 | return wiremessage.Reply{}, command.Error{ |
| 474 | Message: err.Error(), |
| 475 | Labels: []string{command.NetworkError}, |
| 476 | } |
| 477 | } |
| 478 | |
| 479 | wm, err := conn.ReadWireMessage(ctx) |
| 480 | if err != nil { |
| 481 | if _, ok := err.(command.Error); ok { |
| 482 | return wiremessage.Reply{}, err |
| 483 | } |
| 484 | // Connection errors are transient |
| 485 | return wiremessage.Reply{}, command.Error{ |
| 486 | Message: err.Error(), |
| 487 | Labels: []string{command.NetworkError}, |
| 488 | } |
| 489 | } |
| 490 | |
| 491 | reply, ok := wm.(wiremessage.Reply) |
| 492 | if !ok { |
| 493 | return wiremessage.Reply{}, errors.New("did not receive OP_REPLY response") |
| 494 | } |
| 495 | |
| 496 | err = validateOpReply(reply) |
| 497 | if err != nil { |
| 498 | return wiremessage.Reply{}, err |
| 499 | } |
| 500 | |
| 501 | return reply, nil |
| 502 | } |
| 503 | |
| 504 | func validateOpReply(reply wiremessage.Reply) error { |
| 505 | if int(reply.NumberReturned) != len(reply.Documents) { |
| 506 | return command.NewCommandResponseError(command.ReplyDocumentMismatch, nil) |
| 507 | } |
| 508 | |
| 509 | if reply.ResponseFlags&wiremessage.QueryFailure == wiremessage.QueryFailure { |
| 510 | return command.QueryFailureError{ |
| 511 | Message: "query failure", |
| 512 | Response: reply.Documents[0], |
| 513 | } |
| 514 | } |
| 515 | |
| 516 | return nil |
| 517 | } |