blob: 859f797b4a057478d44fa85bfc369e2b588dedcf [file] [log] [blame]
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"errors"
"context"
"fmt"
"github.com/mongodb/mongo-go-driver/bson"
"github.com/mongodb/mongo-go-driver/bson/bsontype"
"github.com/mongodb/mongo-go-driver/bson/primitive"
"github.com/mongodb/mongo-go-driver/mongo/readconcern"
"github.com/mongodb/mongo-go-driver/mongo/writeconcern"
"github.com/mongodb/mongo-go-driver/x/bsonx"
"github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore"
"github.com/mongodb/mongo-go-driver/x/mongo/driver/session"
"github.com/mongodb/mongo-go-driver/x/network/description"
"github.com/mongodb/mongo-go-driver/x/network/result"
"github.com/mongodb/mongo-go-driver/x/network/wiremessage"
)
// WriteBatch represents a single batch for a write operation. It embeds the
// *Write command used to send the batch and records how many documents the
// batch contains.
type WriteBatch struct {
// Write is the embedded write command; its methods (for example Encode and
// RoundTrip, which this file calls on batches) are promoted onto WriteBatch.
*Write
// numDocs is the number of documents in this batch. roundTripBatches uses it
// to offset upserted indexes across successive update batches.
numDocs int
}
// DecodeError attempts to decode the wiremessage as a command error.
//
// For a wiremessage.Msg the document of its SectionBody section is inspected
// (the last one wins if several are present). For a wiremessage.Reply the first
// returned document is inspected, but only when the QueryFailure response flag
// is set; a reply without that flag, or without any documents, yields nil.
//
// If the selected document fails BSON validation the function returns nil.
// Otherwise it returns the Error produced by extractError when the document
// describes a failed command, and nil when it does not.
func DecodeError(wm wiremessage.WireMessage) error {
	var rdr bson.Raw
	switch msg := wm.(type) {
	case wiremessage.Msg:
		for _, section := range msg.Sections {
			if body, ok := section.(wiremessage.SectionBody); ok {
				rdr = body.Document
			}
		}
	case wiremessage.Reply:
		if msg.ResponseFlags&wiremessage.QueryFailure != wiremessage.QueryFailure {
			return nil
		}
		// A failure reply is not guaranteed to carry a document; indexing an
		// empty slice would panic.
		if len(msg.Documents) == 0 {
			return nil
		}
		rdr = msg.Documents[0]
	}

	// A document that cannot be validated cannot be decoded as an error.
	if err := rdr.Validate(); err != nil {
		return nil
	}

	// Return the decoded command error itself. Returning the (nil) validation
	// error here would make this function always report "no error".
	if cmdErr, ok := extractError(rdr).(Error); ok {
		return cmdErr
	}
	return nil
}
// extractError walks a command response document and converts it into an error.
//
// It returns:
//   - the error from rdr.Elements() when the document cannot be iterated,
//   - nil as soon as an "ok" element equal to 1 (int32, int64 or double) is seen,
//   - otherwise an Error populated from the "errmsg", "codeName", "code" and
//     "errorLabels" elements seen so far. A missing or empty "errmsg" is
//     replaced by "command failed".
//
// Elements of an unexpected BSON type are ignored rather than reported.
func extractError(rdr bson.Raw) error {
	elems, err := rdr.Elements()
	if err != nil {
		return err
	}

	var cmdErr Error
	for _, elem := range elems {
		switch elem.Key() {
		case "ok":
			val := elem.Value()
			succeeded := false
			switch val.Type {
			case bson.TypeInt32:
				succeeded = val.Int32() == 1
			case bson.TypeInt64:
				succeeded = val.Int64() == 1
			case bson.TypeDouble:
				succeeded = val.Double() == 1
			}
			if succeeded {
				return nil
			}
		case "errmsg":
			if msg, isString := elem.Value().StringValueOK(); isString {
				cmdErr.Message = msg
			}
		case "codeName":
			if name, isString := elem.Value().StringValueOK(); isString {
				cmdErr.Name = name
			}
		case "code":
			if num, isInt32 := elem.Value().Int32OK(); isInt32 {
				cmdErr.Code = num
			}
		case "errorLabels":
			arr, isArray := elem.Value().ArrayOK()
			if !isArray {
				break
			}
			// An unreadable label array is skipped; the remaining fields are
			// still collected.
			if labelElems, labelErr := arr.Elements(); labelErr == nil {
				for _, labelElem := range labelElems {
					if label, isString := labelElem.Value().StringValueOK(); isString {
						cmdErr.Labels = append(cmdErr.Labels, label)
					}
				}
			}
		}
	}

	if cmdErr.Message == "" {
		cmdErr.Message = "command failed"
	}
	return cmdErr
}
// responseClusterTime looks up the "$clusterTime" element of a server response
// and re-wraps it as a standalone document containing only that element. It
// returns nil when the lookup fails, e.g. because the server did not include
// "$clusterTime".
func responseClusterTime(response bson.Raw) bson.Raw {
	ct, err := response.LookupErr("$clusterTime")
	if err != nil {
		return nil
	}

	start, buf := bsoncore.AppendDocumentStart(nil)
	buf = bsoncore.AppendHeader(buf, ct.Type, "$clusterTime")
	buf = append(buf, ct.Value...)
	// The error from AppendDocumentEnd is discarded, as in the original code.
	wrapped, _ := bsoncore.AppendDocumentEnd(buf, start)
	return wrapped
}
// updateClusterTimes advances the cluster time of the session and of the
// cluster clock using the "$clusterTime" found in response. Either sess or
// clock may be nil, in which case it is skipped. It is a no-op when the
// response carries no cluster time. An error from advancing the session is
// returned immediately and the clock is then left untouched.
func updateClusterTimes(sess *session.Client, clock *session.ClusterClock, response bson.Raw) error {
	ct := responseClusterTime(response)
	if ct == nil {
		return nil
	}

	if sess != nil {
		if err := sess.AdvanceClusterTime(ct); err != nil {
			return err
		}
	}
	if clock != nil {
		clock.AdvanceClusterTime(ct)
	}
	return nil
}
// updateOperationTime advances the session's operation time using the
// "operationTime" element of response.
//
// It returns nil without doing anything when sess is nil or when the response
// has no "operationTime" element. If the element is present but is not a BSON
// timestamp an error is returned instead of panicking on the type mismatch.
// Otherwise the result of sess.AdvanceOperationTime is returned.
func updateOperationTime(sess *session.Client, response bson.Raw) error {
	if sess == nil {
		return nil
	}

	opTime, err := response.LookupErr("operationTime")
	if err != nil {
		// operationTime not included by the server
		return nil
	}

	// Timestamp() panics on any other BSON type; the value comes from the
	// server, so check it rather than trust it.
	t, i, ok := opTime.TimestampOK()
	if !ok {
		return fmt.Errorf("operationTime should be a timestamp but it is a BSON %s", opTime.Type)
	}

	return sess.AdvanceOperationTime(&primitive.Timestamp{
		T: t,
		I: i,
	})
}
// marshalCommand serializes cmd to raw BSON. A nil cmd is encoded as the
// five-byte empty document (length 5 followed by the terminating zero byte).
func marshalCommand(cmd bsonx.Doc) (bson.Raw, error) {
	if cmd != nil {
		return cmd.MarshalBSON()
	}
	return bson.Raw{5, 0, 0, 0, 0}, nil
}
// addSessionFields adds session related fields to a BSON doc representing a
// command.
//
// The command is returned unchanged when client is nil, when the selected
// server does not support sessions, or when it reports a session timeout of
// zero. If the session has already been terminated, session.ErrSessionEnded is
// returned together with the unchanged command.
//
// Otherwise any existing "lsid" element is replaced with the session's ID,
// transaction fields are added when a transaction is running or a commit is
// being retried, and the session is notified via ApplyCommand that a command
// is executing.
func addSessionFields(cmd bsonx.Doc, desc description.SelectedServer, client *session.Client) (bsonx.Doc, error) {
	if client == nil || !description.SessionsSupported(desc.WireVersion) || desc.SessionTimeoutMinutes == 0 {
		return cmd, nil
	}

	if client.Terminated {
		return cmd, session.ErrSessionEnded
	}

	// Remove any lsid already present so the command never carries two of them.
	// Delete is called unconditionally, matching how "$clusterTime",
	// "readConcern" and "writeConcern" are replaced elsewhere in this file.
	cmd = cmd.Delete("lsid")
	cmd = append(cmd, bsonx.Elem{Key: "lsid", Value: bsonx.Document(client.SessionID)})

	if client.TransactionRunning() || client.RetryingCommit {
		cmd = addTransaction(cmd, client)
	}

	client.ApplyCommand() // advance the state machine based on a command executing

	return cmd, nil
}
// addTransaction appends the transaction fields for client's session to cmd:
// "txnNumber", then "startTransaction: true" only while the session reports
// that the transaction is starting, and finally "autocommit: false".
func addTransaction(cmd bsonx.Doc, client *session.Client) bsonx.Doc {
	txnFields := bsonx.Doc{
		{Key: "txnNumber", Value: bsonx.Int64(client.TxnNumber)},
	}
	if client.TransactionStarting() {
		txnFields = append(txnFields, bsonx.Elem{Key: "startTransaction", Value: bsonx.Boolean(true)})
	}
	txnFields = append(txnFields, bsonx.Elem{Key: "autocommit", Value: bsonx.Boolean(false)})

	return append(cmd, txnFields...)
}
// addClusterTime replaces the "$clusterTime" element of cmd with the most
// recent cluster time known to the cluster clock and the session.
//
// cmd is returned unchanged when both clock and sess are nil, when the server
// does not support sessions, when no cluster time is known, or when the known
// cluster time cannot be parsed as a document. When both sources have a value,
// session.MaxClusterTime picks between them.
func addClusterTime(cmd bsonx.Doc, desc description.SelectedServer, sess *session.Client, clock *session.ClusterClock) bsonx.Doc {
	haveSource := clock != nil || sess != nil
	if !haveSource || !description.SessionsSupported(desc.WireVersion) {
		return cmd
	}

	var latest bson.Raw
	if clock != nil {
		latest = clock.GetClusterTime()
	}
	switch {
	case sess == nil:
		// only the clock can contribute
	case latest == nil:
		latest = sess.ClusterTime
	default:
		latest = session.MaxClusterTime(latest, sess.ClusterTime)
	}
	if latest == nil {
		return cmd
	}

	ctDoc, err := bsonx.ReadDoc(latest)
	if err != nil {
		return cmd // broken clusterTime
	}

	return append(cmd.Delete("$clusterTime"), ctDoc...)
}
// addReadConcern adds a "readConcern" element to a BSON doc representing a
// command.
//
// While the session reports a starting transaction, the session's CurrentRc (if
// set) overrides rc; if there is still no read concern but the session is
// causally consistent and has an operation time, an empty read concern is
// created so that "afterClusterTime" can be attached. With no read concern at
// all, cmd is returned unchanged.
//
// "afterClusterTime" is appended when the server supports sessions and the
// session is consistent with a known operation time. Any existing
// "readConcern" element is removed, and the new one is only added if the
// resulting document is non-empty. Marshalling errors are returned alongside
// the unmodified cmd.
func addReadConcern(cmd bsonx.Doc, desc description.SelectedServer, rc *readconcern.ReadConcern, sess *session.Client) (bsonx.Doc, error) {
	if sess != nil && sess.TransactionStarting() {
		switch {
		case sess.CurrentRc != nil:
			// Starting transaction's read concern overrides all others
			rc = sess.CurrentRc
		case rc == nil && sess.Consistent && sess.OperationTime != nil:
			// start transaction must append afterclustertime IF causally
			// consistent and operation time exists
			rc = readconcern.New()
		}
	}
	if rc == nil {
		return cmd, nil
	}

	bsonType, raw, err := rc.MarshalBSONValue()
	if err != nil {
		return cmd, err
	}

	var rcDoc bsonx.Doc
	if err = rcDoc.UnmarshalBSONValue(bsonType, raw); err != nil {
		return cmd, err
	}

	if description.SessionsSupported(desc.WireVersion) && sess != nil && sess.Consistent && sess.OperationTime != nil {
		opTime := sess.OperationTime
		rcDoc = append(rcDoc, bsonx.Elem{Key: "afterClusterTime", Value: bsonx.Timestamp(opTime.T, opTime.I)})
	}

	cmd = cmd.Delete("readConcern")
	if len(rcDoc) == 0 {
		return cmd, nil
	}
	return append(cmd, bsonx.Elem{Key: "readConcern", Value: bsonx.Document(rcDoc)}), nil
}
// addWriteConcern sets the "writeConcern" element on a BSON doc representing a
// command. cmd is returned unchanged when wc is nil or when marshalling reports
// writeconcern.ErrEmptyWriteConcern. Any other marshalling or unmarshalling
// error is returned together with the unmodified cmd. An existing
// "writeConcern" element is replaced.
func addWriteConcern(cmd bsonx.Doc, wc *writeconcern.WriteConcern) (bsonx.Doc, error) {
	if wc == nil {
		return cmd, nil
	}

	bsonType, raw, err := wc.MarshalBSONValue()
	switch err {
	case nil:
		// marshalled fine; continue below
	case writeconcern.ErrEmptyWriteConcern:
		return cmd, nil
	default:
		return cmd, err
	}

	var wcVal bsonx.Val
	if err := wcVal.UnmarshalBSONValue(bsonType, raw); err != nil {
		return cmd, err
	}

	// delete if doc already has write concern
	withoutWC := cmd.Delete("writeConcern")
	return append(withoutWC, bsonx.Elem{Key: "writeConcern", Value: wcVal}), nil
}
// getErrorLabels returns the string entries of the "errorLabels" array in a
// command response.
//
// It returns (nil, nil) when rdr is nil, when the response has no
// "errorLabels" element, or when that element is not an array. Any other
// lookup failure, or a failure to iterate the array, is returned as the error.
// Array entries that are not strings are skipped.
func getErrorLabels(rdr *bson.Raw) ([]string, error) {
	if rdr == nil {
		return nil, nil
	}

	labelsElem, err := rdr.LookupErr("errorLabels")
	if err == bsoncore.ErrElementNotFound {
		// No labels on this response; that is not an error.
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	if labelsElem.Type != bsontype.Array {
		return nil, nil
	}

	labelElems, err := labelsElem.Array().Elements()
	if err != nil {
		return nil, err
	}

	var labels []string
	for _, elem := range labelElems {
		// StringValue() would panic on a non-string entry; skip those instead,
		// as extractError does.
		if label, ok := elem.Value().StringValueOK(); ok {
			labels = append(labels, label)
		}
	}
	return labels, nil
}
// opmsgRemoveArray removes the command arguments for insert, update, and delete
// commands from the BSON document so they can be encoded as a Section 1 payload
// in OP_MSG.
//
// The keys "documents", "updates" and "deletes" are tried in that order; the
// first one present is removed from cmd and returned as an array together with
// its key. If none is present, cmd is returned unchanged with a nil array and
// an empty identifier.
func opmsgRemoveArray(cmd bsonx.Doc) (bsonx.Doc, bsonx.Arr, string) {
	for _, field := range [...]string{"documents", "updates", "deletes"} {
		val, err := cmd.LookupErr(field)
		if err != nil {
			continue
		}
		// Read the array before deleting the key, preserving the original order
		// of operations.
		payload := val.Array()
		return cmd.Delete(field), payload, field
	}
	return cmd, nil, ""
}
// opmsgAddGlobals appends the "$db" key and, when rpDoc is non-nil, the
// "$readPreference" key to the command and returns the marshalled result.
// Pass nil for rpDoc if the command has no read preference.
func opmsgAddGlobals(cmd bsonx.Doc, dbName string, rpDoc bsonx.Doc) (bson.Raw, error) {
	globals := bsonx.Doc{{Key: "$db", Value: bsonx.String(dbName)}}
	if rpDoc != nil {
		globals = append(globals, bsonx.Elem{Key: "$readPreference", Value: bsonx.Document(rpDoc)})
	}

	full := append(cmd, globals...)
	return full.MarshalBSON()
}
// opmsgCreateDocSequence builds a document-sequence section named identifier
// from the documents held in arr. Each array value is marshalled as a
// document (marshalling errors are discarded), and the section's Size is set
// from its PayloadLen. The returned error is always nil.
func opmsgCreateDocSequence(arr bsonx.Arr, identifier string) (wiremessage.SectionDocumentSequence, error) {
	rawDocs := make([]bson.Raw, 0, len(arr))
	for i := range arr {
		raw, _ := arr[i].Document().MarshalBSON()
		rawDocs = append(rawDocs, raw)
	}

	seq := wiremessage.SectionDocumentSequence{
		PayloadType: wiremessage.DocumentSequence,
		Identifier:  identifier,
		Documents:   rawDocs,
	}
	seq.Size = int32(seq.PayloadLen())
	return seq, nil
}
// splitBatches partitions docs, in order, into batches that each hold at most
// maxCount documents and whose marshalled sizes sum to at most the target size.
//
// If targetBatchSize exceeds reservedCommandBufferBytes, that reserve is
// subtracted from it first. A maxCount of zero or less is treated as 1. A
// single document larger than the (adjusted) target size aborts the split with
// ErrDocumentTooLarge. For empty input the result is one empty batch.
func splitBatches(docs []bsonx.Doc, maxCount, targetBatchSize int) ([][]bsonx.Doc, error) {
	if targetBatchSize > reservedCommandBufferBytes {
		targetBatchSize -= reservedCommandBufferBytes
	}
	if maxCount <= 0 {
		maxCount = 1
	}

	batches := [][]bsonx.Doc{}
	next := 0 // index of the first document not yet placed in a batch
	for {
		batch := []bsonx.Doc{}
		batchBytes := 0

		for next < len(docs) && len(batch) < maxCount {
			raw, _ := docs[next].MarshalBSON()
			docLen := len(raw)

			if docLen > targetBatchSize {
				return nil, ErrDocumentTooLarge
			}
			if batchBytes+docLen > targetBatchSize {
				// Document does not fit; it opens the next batch.
				break
			}

			batchBytes += docLen
			batch = append(batch, docs[next])
			next++
		}

		batches = append(batches, batch)
		if next == len(docs) {
			return batches, nil
		}
	}
}
// encodeBatch assembles the command document for one batch of a write command.
//
// The document starts with the command name ("insert", "update" or "delete",
// chosen by cmdKind) mapped to collName, followed by the batch array under
// "documents", "updates" or "deletes" respectively, followed by opts. An
// unrecognized cmdKind leaves both keys empty. The returned error is always
// nil.
func encodeBatch(
	docs []bsonx.Doc,
	opts []bsonx.Elem,
	cmdKind WriteCommandKind,
	collName string,
) (bsonx.Doc, error) {
	var cmdName, arrayKey string
	switch cmdKind {
	case InsertCommand:
		cmdName, arrayKey = "insert", "documents"
	case UpdateCommand:
		cmdName, arrayKey = "update", "updates"
	case DeleteCommand:
		cmdName, arrayKey = "delete", "deletes"
	}

	payload := make(bsonx.Arr, len(docs))
	for i, doc := range docs {
		payload[i] = bsonx.Document(doc)
	}

	cmd := make(bsonx.Doc, 0, 2+len(opts))
	cmd = append(cmd,
		bsonx.Elem{Key: cmdName, Value: bsonx.String(collName)},
		bsonx.Elem{Key: arrayKey, Value: bsonx.Array(payload)},
	)
	cmd = append(cmd, opts...)

	return cmd, nil
}
// batchesToWireMessage converts batches of write commands to wire messages,
// one message per batch and in the same order. It stops at the first batch
// that fails to encode and returns that error with a nil slice.
func batchesToWireMessage(batches []*WriteBatch, desc description.SelectedServer) ([]wiremessage.WireMessage, error) {
	// Length 0 with capacity len(batches): the slice is filled with append, so
	// a non-zero initial length would leave nil messages at the front.
	wms := make([]wiremessage.WireMessage, 0, len(batches))

	for _, cmd := range batches {
		wm, err := cmd.Encode(desc)
		if err != nil {
			return nil, err
		}

		wms = append(wms, wm)
	}

	return wms, nil
}
// roundTripBatches round-trips the write batches in order over rw and folds the
// decoded replies into a single result.
//
// Returns:
//   - the accumulated result as interface{}: a result.Insert, result.Update or
//     result.Delete depending on cmdKind (nil when batches is empty),
//   - the batches slice as it stands on return; completed batches are trimmed
//     from its front only when sess is non-nil and sess.RetryWrite is set,
//     otherwise the slice is returned as passed in,
//   - the first round-trip or decode error encountered, if any.
//
// Write errors from all batches are accumulated. When continueOnError is false
// the function returns as soon as the accumulated result contains a write
// error. For retryable writes (sess.RetryWrite) a write concern error causes an
// early return with a nil error so the caller can retry.
func roundTripBatches(
ctx context.Context,
desc description.SelectedServer,
rw wiremessage.ReadWriter,
batches []*WriteBatch,
continueOnError bool,
sess *session.Client,
cmdKind WriteCommandKind,
) (interface{}, []*WriteBatch, error) {
var res interface{}
var upsertIndex int64 // the operation index for the upserted IDs map
// hold onto txnNumber, reset it when loop exits to ensure reuse of same
// transaction number if retry is needed
var txnNumber int64
if sess != nil && sess.RetryWrite {
txnNumber = sess.TxnNumber
}
// The range expression is evaluated once, so re-slicing batches inside the
// loop does not affect which batches are visited.
for j, cmd := range batches {
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
if sess != nil && sess.RetryWrite {
// Set the transaction number to the starting number plus the index of
// the batch that failed.
sess.TxnNumber = txnNumber + int64(j)
}
return res, batches, err
}
// TODO can probably DRY up this code
switch cmdKind {
case InsertCommand:
if res == nil {
res = result.Insert{}
}
// conv is a copy of the accumulated result; it is stored back into res
// only at the end of the case.
conv, _ := res.(result.Insert)
insertCmd := &Insert{}
r, err := insertCmd.decode(desc, rdr).Result()
if err != nil {
// NOTE(review): this returns res while the update and delete cases
// return conv on a decode error — confirm whether that is intended.
return res, batches, err
}
conv.WriteErrors = append(conv.WriteErrors, r.WriteErrors...)
if r.WriteConcernError != nil {
conv.WriteConcernError = r.WriteConcernError
if sess != nil && sess.RetryWrite {
// Restore the transaction number captured before the loop.
sess.TxnNumber = txnNumber
return conv, batches, nil // report writeconcernerror for retry
}
}
conv.N += r.N
if !continueOnError && len(conv.WriteErrors) > 0 {
return conv, batches, nil
}
res = conv
case UpdateCommand:
if res == nil {
res = result.Update{}
}
conv, _ := res.(result.Update)
updateCmd := &Update{}
r, err := updateCmd.decode(desc, rdr).Result()
if err != nil {
return conv, batches, err
}
conv.WriteErrors = append(conv.WriteErrors, r.WriteErrors...)
if r.WriteConcernError != nil {
conv.WriteConcernError = r.WriteConcernError
if sess != nil && sess.RetryWrite {
// Restore the transaction number captured before the loop.
sess.TxnNumber = txnNumber
return conv, batches, nil // report writeconcernerror for retry
}
}
conv.MatchedCount += r.MatchedCount
conv.ModifiedCount += r.ModifiedCount
// Upsert indexes in the reply are relative to this batch; shift them by
// the number of documents in the batches already processed.
for _, upsert := range r.Upserted {
conv.Upserted = append(conv.Upserted, result.Upsert{
Index: upsert.Index + upsertIndex,
ID: upsert.ID,
})
}
if !continueOnError && len(conv.WriteErrors) > 0 {
return conv, batches, nil
}
res = conv
upsertIndex += int64(cmd.numDocs)
case DeleteCommand:
if res == nil {
res = result.Delete{}
}
conv, _ := res.(result.Delete)
deleteCmd := &Delete{}
r, err := deleteCmd.decode(desc, rdr).Result()
if err != nil {
return conv, batches, err
}
conv.WriteErrors = append(conv.WriteErrors, r.WriteErrors...)
if r.WriteConcernError != nil {
conv.WriteConcernError = r.WriteConcernError
if sess != nil && sess.RetryWrite {
// Restore the transaction number captured before the loop.
sess.TxnNumber = txnNumber
return conv, batches, nil // report writeconcernerror for retry
}
}
conv.N += r.N
if !continueOnError && len(conv.WriteErrors) > 0 {
return conv, batches, nil
}
res = conv
}
// Increment txnNumber for each batch
if sess != nil && sess.RetryWrite {
sess.IncrementTxnNumber()
batches = batches[1:] // this batch round-tripped and decoded, so drop it from the remaining slice
}
}
if sess != nil && sess.RetryWrite {
// if retryable write succeeded, transaction number will be incremented one extra time,
// so we decrement it here
sess.TxnNumber--
}
return res, batches, nil
}
// getCursorValues extracts the firstBatch, namespace, and cursor ID from the
// "cursor" sub-document of a command response.
//
// It returns an error when "cursor" is missing or is not an embedded document,
// when the cursor document cannot be iterated, when "firstBatch" is not an
// array, when "ns" is not a string or fails Namespace validation, when "id" is
// not an int64, or when the batch values cannot be read. On error the other
// return values are zero. Unknown cursor fields are ignored.
func getCursorValues(result bson.Raw) ([]bson.RawValue, Namespace, int64, error) {
	cur, err := result.LookupErr("cursor")
	if err != nil {
		return nil, Namespace{}, 0, err
	}
	if cur.Type != bson.TypeEmbeddedDocument {
		return nil, Namespace{}, 0, fmt.Errorf("cursor should be an embedded document but it is a BSON %s", cur.Type)
	}

	elems, err := cur.Document().Elements()
	if err != nil {
		return nil, Namespace{}, 0, err
	}

	var ok bool
	var arr bson.Raw
	var namespace Namespace
	var cursorID int64

	for _, elem := range elems {
		switch elem.Key() {
		case "firstBatch":
			// ArrayOK reports failure only through ok; there is no error value
			// to check here.
			arr, ok = elem.Value().ArrayOK()
			if !ok {
				return nil, Namespace{}, 0, fmt.Errorf("firstBatch should be an array but it is a BSON %s", elem.Value().Type)
			}
		case "ns":
			if elem.Value().Type != bson.TypeString {
				return nil, Namespace{}, 0, fmt.Errorf("namespace should be a string but it is a BSON %s", elem.Value().Type)
			}
			namespace = ParseNamespace(elem.Value().StringValue())
			err = namespace.Validate()
			if err != nil {
				return nil, Namespace{}, 0, err
			}
		case "id":
			cursorID, ok = elem.Value().Int64OK()
			if !ok {
				return nil, Namespace{}, 0, fmt.Errorf("id should be an int64 but it is a BSON %s", elem.Value().Type)
			}
		}
	}

	// arr is still nil if no "firstBatch" field was seen; the outcome then
	// depends on how Values treats an empty Raw.
	// NOTE(review): confirm it reports an error in that case.
	vals, err := arr.Values()
	if err != nil {
		return nil, Namespace{}, 0, err
	}

	return vals, namespace, cursorID, nil
}
// getBatchSize returns the int32 value of the first option whose key is
// "batchSize", or 0 when opts contains no such option.
func getBatchSize(opts []bsonx.Elem) int32 {
	for i := range opts {
		if opts[i].Key != "batchSize" {
			continue
		}
		return opts[i].Value.Int32()
	}
	return 0
}
// ErrUnacknowledgedWrite is the sentinel error returned from functions that
// have an unacknowledged write concern.
var ErrUnacknowledgedWrite = errors.New("unacknowledged write")
// WriteCommandKind is the type of command represented by a Write. It selects
// the command name and payload key in encodeBatch and the result type in
// roundTripBatches.
type WriteCommandKind int8
// These constants represent the valid types of write commands.
const (
// InsertCommand identifies an "insert" command, whose payload key is "documents".
InsertCommand WriteCommandKind = iota
// UpdateCommand identifies an "update" command, whose payload key is "updates".
UpdateCommand
// DeleteCommand identifies a "delete" command, whose payload key is "deletes".
DeleteCommand
)