blob: 859f797b4a057478d44fa85bfc369e2b588dedcf [file] [log] [blame]
Don Newton379ae252019-04-01 12:17:06 -04001// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package command
8
9import (
10 "errors"
11
12 "context"
13
14 "fmt"
15
16 "github.com/mongodb/mongo-go-driver/bson"
17 "github.com/mongodb/mongo-go-driver/bson/bsontype"
18 "github.com/mongodb/mongo-go-driver/bson/primitive"
19 "github.com/mongodb/mongo-go-driver/mongo/readconcern"
20 "github.com/mongodb/mongo-go-driver/mongo/writeconcern"
21 "github.com/mongodb/mongo-go-driver/x/bsonx"
22 "github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore"
23 "github.com/mongodb/mongo-go-driver/x/mongo/driver/session"
24 "github.com/mongodb/mongo-go-driver/x/network/description"
25 "github.com/mongodb/mongo-go-driver/x/network/result"
26 "github.com/mongodb/mongo-go-driver/x/network/wiremessage"
27)
28
29// WriteBatch represents a single batch for a write operation.
30type WriteBatch struct {
31 *Write
32 numDocs int
33}
34
35// DecodeError attempts to decode the wiremessage as an error
36func DecodeError(wm wiremessage.WireMessage) error {
37 var rdr bson.Raw
38 switch msg := wm.(type) {
39 case wiremessage.Msg:
40 for _, section := range msg.Sections {
41 switch converted := section.(type) {
42 case wiremessage.SectionBody:
43 rdr = converted.Document
44 }
45 }
46 case wiremessage.Reply:
47 if msg.ResponseFlags&wiremessage.QueryFailure != wiremessage.QueryFailure {
48 return nil
49 }
50 rdr = msg.Documents[0]
51 }
52
53 err := rdr.Validate()
54 if err != nil {
55 return nil
56 }
57
58 extractedError := extractError(rdr)
59
60 // If parsed successfully return the error
61 if _, ok := extractedError.(Error); ok {
62 return err
63 }
64
65 return nil
66}
67
68// helper method to extract an error from a reader if there is one; first returned item is the
69// error if it exists, the second holds parsing errors
70func extractError(rdr bson.Raw) error {
71 var errmsg, codeName string
72 var code int32
73 var labels []string
74 elems, err := rdr.Elements()
75 if err != nil {
76 return err
77 }
78
79 for _, elem := range elems {
80 switch elem.Key() {
81 case "ok":
82 switch elem.Value().Type {
83 case bson.TypeInt32:
84 if elem.Value().Int32() == 1 {
85 return nil
86 }
87 case bson.TypeInt64:
88 if elem.Value().Int64() == 1 {
89 return nil
90 }
91 case bson.TypeDouble:
92 if elem.Value().Double() == 1 {
93 return nil
94 }
95 }
96 case "errmsg":
97 if str, okay := elem.Value().StringValueOK(); okay {
98 errmsg = str
99 }
100 case "codeName":
101 if str, okay := elem.Value().StringValueOK(); okay {
102 codeName = str
103 }
104 case "code":
105 if c, okay := elem.Value().Int32OK(); okay {
106 code = c
107 }
108 case "errorLabels":
109 if arr, okay := elem.Value().ArrayOK(); okay {
110 elems, err := arr.Elements()
111 if err != nil {
112 continue
113 }
114 for _, elem := range elems {
115 if str, ok := elem.Value().StringValueOK(); ok {
116 labels = append(labels, str)
117 }
118 }
119
120 }
121 }
122 }
123
124 if errmsg == "" {
125 errmsg = "command failed"
126 }
127
128 return Error{
129 Code: code,
130 Message: errmsg,
131 Name: codeName,
132 Labels: labels,
133 }
134}
135
136func responseClusterTime(response bson.Raw) bson.Raw {
137 clusterTime, err := response.LookupErr("$clusterTime")
138 if err != nil {
139 // $clusterTime not included by the server
140 return nil
141 }
142 idx, doc := bsoncore.AppendDocumentStart(nil)
143 doc = bsoncore.AppendHeader(doc, clusterTime.Type, "$clusterTime")
144 doc = append(doc, clusterTime.Value...)
145 doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
146 return doc
147}
148
149func updateClusterTimes(sess *session.Client, clock *session.ClusterClock, response bson.Raw) error {
150 clusterTime := responseClusterTime(response)
151 if clusterTime == nil {
152 return nil
153 }
154
155 if sess != nil {
156 err := sess.AdvanceClusterTime(clusterTime)
157 if err != nil {
158 return err
159 }
160 }
161
162 if clock != nil {
163 clock.AdvanceClusterTime(clusterTime)
164 }
165
166 return nil
167}
168
169func updateOperationTime(sess *session.Client, response bson.Raw) error {
170 if sess == nil {
171 return nil
172 }
173
174 opTimeElem, err := response.LookupErr("operationTime")
175 if err != nil {
176 // operationTime not included by the server
177 return nil
178 }
179
180 t, i := opTimeElem.Timestamp()
181 return sess.AdvanceOperationTime(&primitive.Timestamp{
182 T: t,
183 I: i,
184 })
185}
186
187func marshalCommand(cmd bsonx.Doc) (bson.Raw, error) {
188 if cmd == nil {
189 return bson.Raw{5, 0, 0, 0, 0}, nil
190 }
191
192 return cmd.MarshalBSON()
193}
194
195// adds session related fields to a BSON doc representing a command
196func addSessionFields(cmd bsonx.Doc, desc description.SelectedServer, client *session.Client) (bsonx.Doc, error) {
197 if client == nil || !description.SessionsSupported(desc.WireVersion) || desc.SessionTimeoutMinutes == 0 {
198 return cmd, nil
199 }
200
201 if client.Terminated {
202 return cmd, session.ErrSessionEnded
203 }
204
205 if _, err := cmd.LookupElementErr("lsid"); err != nil {
206 cmd = cmd.Delete("lsid")
207 }
208
209 cmd = append(cmd, bsonx.Elem{"lsid", bsonx.Document(client.SessionID)})
210
211 if client.TransactionRunning() ||
212 client.RetryingCommit {
213 cmd = addTransaction(cmd, client)
214 }
215
216 client.ApplyCommand() // advance the state machine based on a command executing
217
218 return cmd, nil
219}
220
221// if in a transaction, add the transaction fields
222func addTransaction(cmd bsonx.Doc, client *session.Client) bsonx.Doc {
223 cmd = append(cmd, bsonx.Elem{"txnNumber", bsonx.Int64(client.TxnNumber)})
224 if client.TransactionStarting() {
225 // When starting transaction, always transition to the next state, even on error
226 cmd = append(cmd, bsonx.Elem{"startTransaction", bsonx.Boolean(true)})
227 }
228 return append(cmd, bsonx.Elem{"autocommit", bsonx.Boolean(false)})
229}
230
231func addClusterTime(cmd bsonx.Doc, desc description.SelectedServer, sess *session.Client, clock *session.ClusterClock) bsonx.Doc {
232 if (clock == nil && sess == nil) || !description.SessionsSupported(desc.WireVersion) {
233 return cmd
234 }
235
236 var clusterTime bson.Raw
237 if clock != nil {
238 clusterTime = clock.GetClusterTime()
239 }
240
241 if sess != nil {
242 if clusterTime == nil {
243 clusterTime = sess.ClusterTime
244 } else {
245 clusterTime = session.MaxClusterTime(clusterTime, sess.ClusterTime)
246 }
247 }
248
249 if clusterTime == nil {
250 return cmd
251 }
252
253 d, err := bsonx.ReadDoc(clusterTime)
254 if err != nil {
255 return cmd // broken clusterTime
256 }
257
258 cmd = cmd.Delete("$clusterTime")
259
260 return append(cmd, d...)
261}
262
263// add a read concern to a BSON doc representing a command
264func addReadConcern(cmd bsonx.Doc, desc description.SelectedServer, rc *readconcern.ReadConcern, sess *session.Client) (bsonx.Doc, error) {
265 // Starting transaction's read concern overrides all others
266 if sess != nil && sess.TransactionStarting() && sess.CurrentRc != nil {
267 rc = sess.CurrentRc
268 }
269
270 // start transaction must append afterclustertime IF causally consistent and operation time exists
271 if rc == nil && sess != nil && sess.TransactionStarting() && sess.Consistent && sess.OperationTime != nil {
272 rc = readconcern.New()
273 }
274
275 if rc == nil {
276 return cmd, nil
277 }
278
279 t, data, err := rc.MarshalBSONValue()
280 if err != nil {
281 return cmd, err
282 }
283
284 var rcDoc bsonx.Doc
285 err = rcDoc.UnmarshalBSONValue(t, data)
286 if err != nil {
287 return cmd, err
288 }
289 if description.SessionsSupported(desc.WireVersion) && sess != nil && sess.Consistent && sess.OperationTime != nil {
290 rcDoc = append(rcDoc, bsonx.Elem{"afterClusterTime", bsonx.Timestamp(sess.OperationTime.T, sess.OperationTime.I)})
291 }
292
293 cmd = cmd.Delete("readConcern")
294
295 if len(rcDoc) != 0 {
296 cmd = append(cmd, bsonx.Elem{"readConcern", bsonx.Document(rcDoc)})
297 }
298 return cmd, nil
299}
300
301// add a write concern to a BSON doc representing a command
302func addWriteConcern(cmd bsonx.Doc, wc *writeconcern.WriteConcern) (bsonx.Doc, error) {
303 if wc == nil {
304 return cmd, nil
305 }
306
307 t, data, err := wc.MarshalBSONValue()
308 if err != nil {
309 if err == writeconcern.ErrEmptyWriteConcern {
310 return cmd, nil
311 }
312 return cmd, err
313 }
314
315 var xval bsonx.Val
316 err = xval.UnmarshalBSONValue(t, data)
317 if err != nil {
318 return cmd, err
319 }
320
321 // delete if doc already has write concern
322 cmd = cmd.Delete("writeConcern")
323
324 return append(cmd, bsonx.Elem{Key: "writeConcern", Value: xval}), nil
325}
326
327// Get the error labels from a command response
328func getErrorLabels(rdr *bson.Raw) ([]string, error) {
329 var labels []string
330 labelsElem, err := rdr.LookupErr("errorLabels")
331 if err != bsoncore.ErrElementNotFound {
332 return nil, err
333 }
334 if labelsElem.Type == bsontype.Array {
335 labelsIt, err := labelsElem.Array().Elements()
336 if err != nil {
337 return nil, err
338 }
339 for _, elem := range labelsIt {
340 labels = append(labels, elem.Value().StringValue())
341 }
342 }
343 return labels, nil
344}
345
346// Remove command arguments for insert, update, and delete commands from the BSON document so they can be encoded
347// as a Section 1 payload in OP_MSG
348func opmsgRemoveArray(cmd bsonx.Doc) (bsonx.Doc, bsonx.Arr, string) {
349 var array bsonx.Arr
350 var id string
351
352 keys := []string{"documents", "updates", "deletes"}
353
354 for _, key := range keys {
355 val, err := cmd.LookupErr(key)
356 if err != nil {
357 continue
358 }
359
360 array = val.Array()
361 cmd = cmd.Delete(key)
362 id = key
363 break
364 }
365
366 return cmd, array, id
367}
368
369// Add the $db and $readPreference keys to the command
370// If the command has no read preference, pass nil for rpDoc
371func opmsgAddGlobals(cmd bsonx.Doc, dbName string, rpDoc bsonx.Doc) (bson.Raw, error) {
372 cmd = append(cmd, bsonx.Elem{"$db", bsonx.String(dbName)})
373 if rpDoc != nil {
374 cmd = append(cmd, bsonx.Elem{"$readPreference", bsonx.Document(rpDoc)})
375 }
376
377 return cmd.MarshalBSON() // bsonx.Doc.MarshalBSON never returns an error.
378}
379
380func opmsgCreateDocSequence(arr bsonx.Arr, identifier string) (wiremessage.SectionDocumentSequence, error) {
381 docSequence := wiremessage.SectionDocumentSequence{
382 PayloadType: wiremessage.DocumentSequence,
383 Identifier: identifier,
384 Documents: make([]bson.Raw, 0, len(arr)),
385 }
386
387 for _, val := range arr {
388 d, _ := val.Document().MarshalBSON()
389 docSequence.Documents = append(docSequence.Documents, d)
390 }
391
392 docSequence.Size = int32(docSequence.PayloadLen())
393 return docSequence, nil
394}
395
396func splitBatches(docs []bsonx.Doc, maxCount, targetBatchSize int) ([][]bsonx.Doc, error) {
397 batches := [][]bsonx.Doc{}
398
399 if targetBatchSize > reservedCommandBufferBytes {
400 targetBatchSize -= reservedCommandBufferBytes
401 }
402
403 if maxCount <= 0 {
404 maxCount = 1
405 }
406
407 startAt := 0
408splitInserts:
409 for {
410 size := 0
411 batch := []bsonx.Doc{}
412 assembleBatch:
413 for idx := startAt; idx < len(docs); idx++ {
414 raw, _ := docs[idx].MarshalBSON()
415
416 if len(raw) > targetBatchSize {
417 return nil, ErrDocumentTooLarge
418 }
419 if size+len(raw) > targetBatchSize {
420 break assembleBatch
421 }
422
423 size += len(raw)
424 batch = append(batch, docs[idx])
425 startAt++
426 if len(batch) == maxCount {
427 break assembleBatch
428 }
429 }
430 batches = append(batches, batch)
431 if startAt == len(docs) {
432 break splitInserts
433 }
434 }
435
436 return batches, nil
437}
438
439func encodeBatch(
440 docs []bsonx.Doc,
441 opts []bsonx.Elem,
442 cmdKind WriteCommandKind,
443 collName string,
444) (bsonx.Doc, error) {
445 var cmdName string
446 var docString string
447
448 switch cmdKind {
449 case InsertCommand:
450 cmdName = "insert"
451 docString = "documents"
452 case UpdateCommand:
453 cmdName = "update"
454 docString = "updates"
455 case DeleteCommand:
456 cmdName = "delete"
457 docString = "deletes"
458 }
459
460 cmd := bsonx.Doc{{cmdName, bsonx.String(collName)}}
461
462 vals := make(bsonx.Arr, 0, len(docs))
463 for _, doc := range docs {
464 vals = append(vals, bsonx.Document(doc))
465 }
466 cmd = append(cmd, bsonx.Elem{docString, bsonx.Array(vals)})
467 cmd = append(cmd, opts...)
468
469 return cmd, nil
470}
471
472// converts batches of Write Commands to wire messages
473func batchesToWireMessage(batches []*WriteBatch, desc description.SelectedServer) ([]wiremessage.WireMessage, error) {
474 wms := make([]wiremessage.WireMessage, len(batches))
475 for _, cmd := range batches {
476 wm, err := cmd.Encode(desc)
477 if err != nil {
478 return nil, err
479 }
480
481 wms = append(wms, wm)
482 }
483
484 return wms, nil
485}
486
487// Roundtrips the write batches, returning the result structs (as interface),
488// the write batches that weren't round tripped and any errors
489func roundTripBatches(
490 ctx context.Context,
491 desc description.SelectedServer,
492 rw wiremessage.ReadWriter,
493 batches []*WriteBatch,
494 continueOnError bool,
495 sess *session.Client,
496 cmdKind WriteCommandKind,
497) (interface{}, []*WriteBatch, error) {
498 var res interface{}
499 var upsertIndex int64 // the operation index for the upserted IDs map
500
501 // hold onto txnNumber, reset it when loop exits to ensure reuse of same
502 // transaction number if retry is needed
503 var txnNumber int64
504 if sess != nil && sess.RetryWrite {
505 txnNumber = sess.TxnNumber
506 }
507 for j, cmd := range batches {
508 rdr, err := cmd.RoundTrip(ctx, desc, rw)
509 if err != nil {
510 if sess != nil && sess.RetryWrite {
511 sess.TxnNumber = txnNumber + int64(j)
512 }
513 return res, batches, err
514 }
515
516 // TODO can probably DRY up this code
517 switch cmdKind {
518 case InsertCommand:
519 if res == nil {
520 res = result.Insert{}
521 }
522
523 conv, _ := res.(result.Insert)
524 insertCmd := &Insert{}
525 r, err := insertCmd.decode(desc, rdr).Result()
526 if err != nil {
527 return res, batches, err
528 }
529
530 conv.WriteErrors = append(conv.WriteErrors, r.WriteErrors...)
531
532 if r.WriteConcernError != nil {
533 conv.WriteConcernError = r.WriteConcernError
534 if sess != nil && sess.RetryWrite {
535 sess.TxnNumber = txnNumber
536 return conv, batches, nil // report writeconcernerror for retry
537 }
538 }
539
540 conv.N += r.N
541
542 if !continueOnError && len(conv.WriteErrors) > 0 {
543 return conv, batches, nil
544 }
545
546 res = conv
547 case UpdateCommand:
548 if res == nil {
549 res = result.Update{}
550 }
551
552 conv, _ := res.(result.Update)
553 updateCmd := &Update{}
554 r, err := updateCmd.decode(desc, rdr).Result()
555 if err != nil {
556 return conv, batches, err
557 }
558
559 conv.WriteErrors = append(conv.WriteErrors, r.WriteErrors...)
560
561 if r.WriteConcernError != nil {
562 conv.WriteConcernError = r.WriteConcernError
563 if sess != nil && sess.RetryWrite {
564 sess.TxnNumber = txnNumber
565 return conv, batches, nil // report writeconcernerror for retry
566 }
567 }
568
569 conv.MatchedCount += r.MatchedCount
570 conv.ModifiedCount += r.ModifiedCount
571 for _, upsert := range r.Upserted {
572 conv.Upserted = append(conv.Upserted, result.Upsert{
573 Index: upsert.Index + upsertIndex,
574 ID: upsert.ID,
575 })
576 }
577
578 if !continueOnError && len(conv.WriteErrors) > 0 {
579 return conv, batches, nil
580 }
581
582 res = conv
583 upsertIndex += int64(cmd.numDocs)
584 case DeleteCommand:
585 if res == nil {
586 res = result.Delete{}
587 }
588
589 conv, _ := res.(result.Delete)
590 deleteCmd := &Delete{}
591 r, err := deleteCmd.decode(desc, rdr).Result()
592 if err != nil {
593 return conv, batches, err
594 }
595
596 conv.WriteErrors = append(conv.WriteErrors, r.WriteErrors...)
597
598 if r.WriteConcernError != nil {
599 conv.WriteConcernError = r.WriteConcernError
600 if sess != nil && sess.RetryWrite {
601 sess.TxnNumber = txnNumber
602 return conv, batches, nil // report writeconcernerror for retry
603 }
604 }
605
606 conv.N += r.N
607
608 if !continueOnError && len(conv.WriteErrors) > 0 {
609 return conv, batches, nil
610 }
611
612 res = conv
613 }
614
615 // Increment txnNumber for each batch
616 if sess != nil && sess.RetryWrite {
617 sess.IncrementTxnNumber()
618 batches = batches[1:] // if batch encoded successfully, remove it from the slice
619 }
620 }
621
622 if sess != nil && sess.RetryWrite {
623 // if retryable write succeeded, transaction number will be incremented one extra time,
624 // so we decrement it here
625 sess.TxnNumber--
626 }
627
628 return res, batches, nil
629}
630
631// get the firstBatch, cursor ID, and namespace from a bson.Raw
632func getCursorValues(result bson.Raw) ([]bson.RawValue, Namespace, int64, error) {
633 cur, err := result.LookupErr("cursor")
634 if err != nil {
635 return nil, Namespace{}, 0, err
636 }
637 if cur.Type != bson.TypeEmbeddedDocument {
638 return nil, Namespace{}, 0, fmt.Errorf("cursor should be an embedded document but it is a BSON %s", cur.Type)
639 }
640
641 elems, err := cur.Document().Elements()
642 if err != nil {
643 return nil, Namespace{}, 0, err
644 }
645
646 var ok bool
647 var arr bson.Raw
648 var namespace Namespace
649 var cursorID int64
650
651 for _, elem := range elems {
652 switch elem.Key() {
653 case "firstBatch":
654 arr, ok = elem.Value().ArrayOK()
655 if !ok {
656 return nil, Namespace{}, 0, fmt.Errorf("firstBatch should be an array but it is a BSON %s", elem.Value().Type)
657 }
658 if err != nil {
659 return nil, Namespace{}, 0, err
660 }
661 case "ns":
662 if elem.Value().Type != bson.TypeString {
663 return nil, Namespace{}, 0, fmt.Errorf("namespace should be a string but it is a BSON %s", elem.Value().Type)
664 }
665 namespace = ParseNamespace(elem.Value().StringValue())
666 err = namespace.Validate()
667 if err != nil {
668 return nil, Namespace{}, 0, err
669 }
670 case "id":
671 cursorID, ok = elem.Value().Int64OK()
672 if !ok {
673 return nil, Namespace{}, 0, fmt.Errorf("id should be an int64 but it is a BSON %s", elem.Value().Type)
674 }
675 }
676 }
677
678 vals, err := arr.Values()
679 if err != nil {
680 return nil, Namespace{}, 0, err
681 }
682
683 return vals, namespace, cursorID, nil
684}
685
686func getBatchSize(opts []bsonx.Elem) int32 {
687 for _, opt := range opts {
688 if opt.Key == "batchSize" {
689 return opt.Value.Int32()
690 }
691 }
692
693 return 0
694}
695
696// ErrUnacknowledgedWrite is returned from functions that have an unacknowledged
697// write concern.
698var ErrUnacknowledgedWrite = errors.New("unacknowledged write")
699
700// WriteCommandKind is the type of command represented by a Write
701type WriteCommandKind int8
702
703// These constants represent the valid types of write commands.
704const (
705 InsertCommand WriteCommandKind = iota
706 UpdateCommand
707 DeleteCommand
708)