blob: 4d461d58ea100084e3f59708fbbbe2ccc661ac4c [file] [log] [blame]
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package driver
import (
"context"
"github.com/mongodb/mongo-go-driver/bson/bsoncodec"
"github.com/mongodb/mongo-go-driver/mongo/options"
"github.com/mongodb/mongo-go-driver/mongo/writeconcern"
"github.com/mongodb/mongo-go-driver/x/bsonx"
"github.com/mongodb/mongo-go-driver/x/mongo/driver/session"
"github.com/mongodb/mongo-go-driver/x/mongo/driver/topology"
"github.com/mongodb/mongo-go-driver/x/mongo/driver/uuid"
"github.com/mongodb/mongo-go-driver/x/network/command"
"github.com/mongodb/mongo-go-driver/x/network/description"
"github.com/mongodb/mongo-go-driver/x/network/result"
)
// BulkWriteError is an error from one operation in a bulk write.
//
// It pairs the server-reported write error with the write model the error is
// attributed to.
type BulkWriteError struct {
// WriteError is the embedded write error; its fields (such as Index) are
// promoted onto BulkWriteError.
result.WriteError
// Model is the write model this error is attributed to.
Model WriteModel
}
// BulkWriteException is a collection of errors returned by a bulk write operation.
type BulkWriteException struct {
// WriteConcernError is the write concern error reported for the bulk
// write, or nil if there was none.
WriteConcernError *result.WriteConcernError
// WriteErrors holds the individual write errors collected across the
// batches of the bulk write.
WriteErrors []BulkWriteError
}
// Error implements the error interface.
//
// It returns a short, static description of which kinds of errors the
// exception carries, so that a BulkWriteException never renders as an empty
// string when it is logged or wrapped. Callers that need the details should
// inspect the WriteErrors and WriteConcernError fields directly.
func (bwe BulkWriteException) Error() string {
	hasWriteErrors := len(bwe.WriteErrors) > 0
	hasWriteConcernError := bwe.WriteConcernError != nil

	switch {
	case hasWriteErrors && hasWriteConcernError:
		return "bulk write exception: write errors and a write concern error occurred"
	case hasWriteErrors:
		return "bulk write exception: write errors occurred"
	case hasWriteConcernError:
		return "bulk write exception: a write concern error occurred"
	default:
		return "bulk write exception"
	}
}
// bulkWriteBatch is a group of write models that is sent to the server as a
// single write command.
type bulkWriteBatch struct {
// models are the write models contained in the batch.
models []WriteModel
// canRetry reports whether the batch is eligible for a retryable write.
// The batching functions clear it when the batch contains a
// DeleteManyModel or an UpdateManyModel.
canRetry bool
}
// BulkWrite handles the full dispatch cycle for a bulk write operation.
//
// It selects a server, validates the models against that server's wire
// version, starts an implicit session when none was supplied and the
// deployment supports sessions, splits the models into batches and runs the
// batches one after another, merging the per-batch results.
//
// For an ordered bulk write, execution stops at the first batch that fails and
// an empty result is returned together with the error. For an unordered bulk
// write every batch is attempted; the merged result is then returned together
// with the first command error encountered or, if there was none, a
// BulkWriteException holding the accumulated write errors and write concern
// error. The returned error is nil only when every batch succeeded.
//
// The Index of each reported write error is shifted by the number of models
// contained in the batches that ran before the failing one.
func BulkWrite(
	ctx context.Context,
	ns command.Namespace,
	models []WriteModel,
	topo *topology.Topology,
	selector description.ServerSelector,
	clientID uuid.UUID,
	pool *session.Pool,
	retryWrite bool,
	sess *session.Client,
	writeConcern *writeconcern.WriteConcern,
	clock *session.ClusterClock,
	registry *bsoncodec.Registry,
	opts ...*options.BulkWriteOptions,
) (result.BulkWrite, error) {
	ss, err := topo.SelectServer(ctx, selector)
	if err != nil {
		return result.BulkWrite{}, err
	}

	err = verifyOptions(models, ss)
	if err != nil {
		return result.BulkWrite{}, err
	}

	// If no explicit session and deployment supports sessions, start implicit session.
	if sess == nil && topo.SupportsSessions() {
		sess, err = session.NewClientSession(pool, clientID, session.Implicit)
		if err != nil {
			return result.BulkWrite{}, err
		}

		defer sess.EndSession()
	}

	bwOpts := options.MergeBulkWriteOptions(opts...)

	// Bulk writes are ordered unless the caller says otherwise; guard against
	// an unset option instead of dereferencing a nil pointer.
	ordered := true
	if bwOpts.Ordered != nil {
		ordered = *bwOpts.Ordered
	}

	batches := createBatches(models, ordered)
	bwRes := result.BulkWrite{
		UpsertedIDs: make(map[int64]interface{}),
	}
	bwErr := BulkWriteException{
		WriteErrors: make([]BulkWriteError, 0),
	}

	var opIndex int64  // the operation index for the upsertedIDs map
	var firstErr error // first command error seen while continuing on error
	continueOnError := !ordered

	for _, batch := range batches {
		if len(batch.models) == 0 {
			continue
		}

		batchRes, batchErr, runErr := runBatch(ctx, ns, topo, selector, ss, sess, clock, writeConcern, retryWrite,
			bwOpts.BypassDocumentValidation, continueOnError, batch, registry)

		mergeResults(&bwRes, batchRes, opIndex)

		// Only record a write concern error when this batch produced one, so
		// that a later clean batch does not erase an earlier failure.
		if batchErr.WriteConcernError != nil {
			bwErr.WriteConcernError = batchErr.WriteConcernError
		}

		// Translate batch-relative indexes into bulk-relative ones.
		for i := range batchErr.WriteErrors {
			batchErr.WriteErrors[i].Index += int(opIndex)
		}
		bwErr.WriteErrors = append(bwErr.WriteErrors, batchErr.WriteErrors...)

		if !continueOnError && (runErr != nil || len(batchErr.WriteErrors) > 0 || batchErr.WriteConcernError != nil) {
			if runErr != nil {
				return result.BulkWrite{}, runErr
			}

			return result.BulkWrite{}, bwErr
		}

		if runErr != nil && firstErr == nil {
			firstErr = runErr
		}

		opIndex += int64(len(batch.models))
	}

	bwRes.MatchedCount -= bwRes.UpsertedCount

	// Surface anything that went wrong while continuing on error; previously
	// these failures were collected and then silently dropped.
	if firstErr != nil {
		return bwRes, firstErr
	}
	if len(bwErr.WriteErrors) > 0 || bwErr.WriteConcernError != nil {
		return bwRes, bwErr
	}

	return bwRes, nil
}
// runBatch executes a single batch and converts the command-specific result
// into a result.BulkWrite plus a BulkWriteException describing any write
// errors and write concern error the command reported.
//
// The kind of command is chosen from the type of the first model in the batch,
// so the batch must be non-empty. A non-nil error is returned when the
// underlying command itself fails; in that case the other return values are
// empty.
func runBatch(
	ctx context.Context,
	ns command.Namespace,
	topo *topology.Topology,
	selector description.ServerSelector,
	ss *topology.SelectedServer,
	sess *session.Client,
	clock *session.ClusterClock,
	wc *writeconcern.WriteConcern,
	retryWrite bool,
	bypassDocValidation *bool,
	continueOnError bool,
	batch bulkWriteBatch,
	registry *bsoncodec.Registry,
) (result.BulkWrite, BulkWriteException, error) {
	batchRes := result.BulkWrite{
		UpsertedIDs: make(map[int64]interface{}),
	}
	batchErr := BulkWriteException{}

	var writeErrors []result.WriteError
	switch batch.models[0].(type) {
	case InsertOneModel:
		res, err := runInsert(ctx, ns, topo, selector, ss, sess, clock, wc, retryWrite, batch, bypassDocValidation,
			continueOnError, registry)
		if err != nil {
			return result.BulkWrite{}, BulkWriteException{}, err
		}

		batchRes.InsertedCount = int64(res.N)
		writeErrors = res.WriteErrors
		batchErr.WriteConcernError = res.WriteConcernError
	case DeleteOneModel, DeleteManyModel:
		res, err := runDelete(ctx, ns, topo, selector, ss, sess, clock, wc, retryWrite, batch, continueOnError, registry)
		if err != nil {
			return result.BulkWrite{}, BulkWriteException{}, err
		}

		batchRes.DeletedCount = int64(res.N)
		writeErrors = res.WriteErrors
		batchErr.WriteConcernError = res.WriteConcernError
	case ReplaceOneModel, UpdateOneModel, UpdateManyModel:
		res, err := runUpdate(ctx, ns, topo, selector, ss, sess, clock, wc, retryWrite, batch, bypassDocValidation,
			continueOnError, registry)
		if err != nil {
			return result.BulkWrite{}, BulkWriteException{}, err
		}

		batchRes.MatchedCount = res.MatchedCount
		batchRes.ModifiedCount = res.ModifiedCount
		batchRes.UpsertedCount = int64(len(res.Upserted))
		writeErrors = res.WriteErrors
		batchErr.WriteConcernError = res.WriteConcernError
		for _, upsert := range res.Upserted {
			batchRes.UpsertedIDs[upsert.Index] = upsert.ID
		}
	}

	batchErr.WriteErrors = make([]BulkWriteError, 0, len(writeErrors))
	for _, we := range writeErrors {
		// Attribute the error to the model at the reported index within the
		// batch. Fall back to the first model if the index is out of range so
		// a malformed reply cannot cause a panic.
		model := batch.models[0]
		if we.Index >= 0 && we.Index < len(batch.models) {
			model = batch.models[we.Index]
		}

		batchErr.WriteErrors = append(batchErr.WriteErrors, BulkWriteError{
			WriteError: we,
			Model:      model,
		})
	}

	return batchRes, batchErr, nil
}
// runInsert converts every model of the batch into a document and sends them
// to the server in a single insert command.
//
// Every model in the batch must be an InsertOneModel. When retryable writes
// are enabled, supported and allowed for the batch, a failed attempt that is
// deemed retryable is repeated once against a freshly selected server;
// otherwise the command is executed exactly once.
func runInsert(
	ctx context.Context,
	ns command.Namespace,
	topo *topology.Topology,
	selector description.ServerSelector,
	ss *topology.SelectedServer,
	sess *session.Client,
	clock *session.ClusterClock,
	wc *writeconcern.WriteConcern,
	retryWrite bool,
	batch bulkWriteBatch,
	bypassDocValidation *bool,
	continueOnError bool,
	registry *bsoncodec.Registry,
) (result.Insert, error) {
	docs := make([]bsonx.Doc, len(batch.models))
	for i, model := range batch.models {
		converted := model.(InsertOneModel)
		doc, err := interfaceToDocument(converted.Document, registry)
		if err != nil {
			return result.Insert{}, err
		}

		docs[i] = doc
	}

	cmd := command.Insert{
		ContinueOnError: continueOnError,
		NS:              ns,
		Docs:            docs,
		Session:         sess,
		Clock:           clock,
		WriteConcern:    wc,
	}

	if bypassDocValidation != nil {
		cmd.Opts = []bsonx.Elem{{"bypassDocumentValidation", bsonx.Boolean(*bypassDocValidation)}}
	}

	// Non-retryable path: run the command once.
	if !retrySupported(topo, ss.Description(), cmd.Session, cmd.WriteConcern) || !retryWrite || !batch.canRetry {
		if cmd.Session != nil {
			cmd.Session.RetryWrite = false
		}

		return insert(ctx, cmd, ss, nil)
	}

	cmd.Session.RetryWrite = retryWrite
	cmd.Session.IncrementTxnNumber()

	res, origErr := insert(ctx, cmd, ss, nil)
	if !shouldRetry(origErr, res.WriteConcernError) {
		return res, origErr
	}

	// The retry is only possible if a server can be selected and that newly
	// selected server (not the one that just failed) supports retryable
	// writes; otherwise report the original outcome.
	newServer, err := topo.SelectServer(ctx, selector)
	if err != nil || !retrySupported(topo, newServer.Description(), cmd.Session, cmd.WriteConcern) {
		return res, origErr
	}

	return insert(ctx, cmd, newServer, origErr)
}
// runDelete converts every model of the batch into a delete statement and
// sends them to the server in a single delete command.
//
// DeleteOneModel entries are limited to one document while DeleteManyModel
// entries remove every match. When retryable writes are enabled, supported and
// allowed for the batch, a failed attempt that is deemed retryable is repeated
// once against a freshly selected server; otherwise the command is executed
// exactly once.
func runDelete(
	ctx context.Context,
	ns command.Namespace,
	topo *topology.Topology,
	selector description.ServerSelector,
	ss *topology.SelectedServer,
	sess *session.Client,
	clock *session.ClusterClock,
	wc *writeconcern.WriteConcern,
	retryWrite bool,
	batch bulkWriteBatch,
	continueOnError bool,
	registry *bsoncodec.Registry,
) (result.Delete, error) {
	docs := make([]bsonx.Doc, len(batch.models))
	for i, model := range batch.models {
		var doc bsonx.Doc
		var err error

		switch m := model.(type) {
		case DeleteOneModel:
			doc, err = createDeleteDoc(m.Filter, m.Collation, false, registry)
		case DeleteManyModel:
			doc, err = createDeleteDoc(m.Filter, m.Collation, true, registry)
		}

		if err != nil {
			return result.Delete{}, err
		}

		docs[i] = doc
	}

	cmd := command.Delete{
		ContinueOnError: continueOnError,
		NS:              ns,
		Deletes:         docs,
		Session:         sess,
		Clock:           clock,
		WriteConcern:    wc,
	}

	// Non-retryable path: run the command once.
	if !retrySupported(topo, ss.Description(), cmd.Session, cmd.WriteConcern) || !retryWrite || !batch.canRetry {
		if cmd.Session != nil {
			cmd.Session.RetryWrite = false
		}

		return delete(ctx, cmd, ss, nil)
	}

	cmd.Session.RetryWrite = retryWrite
	cmd.Session.IncrementTxnNumber()

	res, origErr := delete(ctx, cmd, ss, nil)
	if !shouldRetry(origErr, res.WriteConcernError) {
		return res, origErr
	}

	// The retry is only possible if a server can be selected and that newly
	// selected server (not the one that just failed) supports retryable
	// writes; otherwise report the original outcome.
	newServer, err := topo.SelectServer(ctx, selector)
	if err != nil || !retrySupported(topo, newServer.Description(), cmd.Session, cmd.WriteConcern) {
		return res, origErr
	}

	return delete(ctx, cmd, newServer, origErr)
}
// runUpdate converts every model of the batch into an update statement and
// sends them to the server in a single update command.
//
// ReplaceOneModel and UpdateOneModel entries affect a single document while
// UpdateManyModel entries are sent with multi set to true. When retryable
// writes are enabled, supported and allowed for the batch, a failed attempt
// that is deemed retryable is repeated once against a freshly selected server;
// otherwise the command is executed exactly once.
func runUpdate(
	ctx context.Context,
	ns command.Namespace,
	topo *topology.Topology,
	selector description.ServerSelector,
	ss *topology.SelectedServer,
	sess *session.Client,
	clock *session.ClusterClock,
	wc *writeconcern.WriteConcern,
	retryWrite bool,
	batch bulkWriteBatch,
	bypassDocValidation *bool,
	continueOnError bool,
	registry *bsoncodec.Registry,
) (result.Update, error) {
	docs := make([]bsonx.Doc, len(batch.models))
	for i, model := range batch.models {
		var doc bsonx.Doc
		var err error

		switch m := model.(type) {
		case ReplaceOneModel:
			doc, err = createUpdateDoc(m.Filter, m.Replacement, options.ArrayFilters{}, false, m.UpdateModel, false,
				registry)
		case UpdateOneModel:
			doc, err = createUpdateDoc(m.Filter, m.Update, m.ArrayFilters, m.ArrayFiltersSet, m.UpdateModel, false,
				registry)
		case UpdateManyModel:
			doc, err = createUpdateDoc(m.Filter, m.Update, m.ArrayFilters, m.ArrayFiltersSet, m.UpdateModel, true,
				registry)
		}

		if err != nil {
			return result.Update{}, err
		}

		docs[i] = doc
	}

	cmd := command.Update{
		ContinueOnError: continueOnError,
		NS:              ns,
		Docs:            docs,
		Session:         sess,
		Clock:           clock,
		WriteConcern:    wc,
	}

	if bypassDocValidation != nil {
		// TODO this is temporary!
		cmd.Opts = []bsonx.Elem{{"bypassDocumentValidation", bsonx.Boolean(*bypassDocValidation)}}
	}

	// Non-retryable path: run the command once.
	if !retrySupported(topo, ss.Description(), cmd.Session, cmd.WriteConcern) || !retryWrite || !batch.canRetry {
		if cmd.Session != nil {
			cmd.Session.RetryWrite = false
		}

		return update(ctx, cmd, ss, nil)
	}

	cmd.Session.RetryWrite = retryWrite
	cmd.Session.IncrementTxnNumber()

	res, origErr := update(ctx, cmd, ss, nil)
	if !shouldRetry(origErr, res.WriteConcernError) {
		return res, origErr
	}

	// The retry is only possible if a server can be selected and that newly
	// selected server (not the one that just failed) supports retryable
	// writes; otherwise report the original outcome.
	newServer, err := topo.SelectServer(ctx, selector)
	if err != nil || !retrySupported(topo, newServer.Description(), cmd.Session, cmd.WriteConcern) {
		return res, origErr
	}

	return update(ctx, cmd, newServer, origErr)
}
// verifyOptions checks that the selected server is recent enough for the
// options used by the given models.
//
// It returns ErrArrayFilters when a model sets array filters and the server's
// maximum wire version is below 6, and ErrCollation when a model sets a
// collation and the maximum wire version is below 5. Models are checked in
// order and the first violation is reported.
func verifyOptions(models []WriteModel, ss *topology.SelectedServer) error {
	const (
		minWireCollation    = 5 // 3.4 is wire version 5
		minWireArrayFilters = 6 // 3.6 is wire version 6
	)

	serverMax := ss.Description().WireVersion.Max

	for _, model := range models {
		usesCollation, usesArrayFilters := false, false

		switch m := model.(type) {
		case DeleteOneModel:
			usesCollation = m.Collation != nil
		case DeleteManyModel:
			usesCollation = m.Collation != nil
		case ReplaceOneModel:
			usesCollation = m.Collation != nil
		case UpdateOneModel:
			usesArrayFilters = m.ArrayFiltersSet
			usesCollation = m.Collation != nil
		case UpdateManyModel:
			usesArrayFilters = m.ArrayFiltersSet
			usesCollation = m.Collation != nil
		}

		// Array filters are checked before collation, matching the order in
		// which violations are reported.
		if usesArrayFilters && serverMax < minWireArrayFilters {
			return ErrArrayFilters
		}
		if usesCollation && serverMax < minWireCollation {
			return ErrCollation
		}
	}

	return nil
}
// createBatches groups the models into batches of a single command kind.
//
// Ordered bulk writes are delegated to createOrderedBatches. For unordered
// bulk writes exactly three batches are returned (some of which may be empty):
// one each for inserts, deletes and updates, positioned in the order in which
// each kind first appears in models. A batch stops being retryable as soon as
// it receives a DeleteManyModel or an UpdateManyModel.
func createBatches(models []WriteModel, ordered bool) []bulkWriteBatch {
	if ordered {
		return createOrderedBatches(models)
	}

	const (
		kindCount  = 3
		unassigned = -1
	)

	batches := make([]bulkWriteBatch, kindCount)
	for idx := range batches {
		batches[idx].canRetry = true
	}

	insertSlot, deleteSlot, updateSlot := unassigned, unassigned, unassigned
	nextSlot := 0 // number of slots handed out so far; len(batches) is fixed at 3

	// slotFor returns the batch for a command kind, reserving the next free
	// slot the first time that kind is seen.
	slotFor := func(slot *int) *bulkWriteBatch {
		if *slot == unassigned {
			*slot = nextSlot
			nextSlot++
		}
		return &batches[*slot]
	}

	for _, model := range models {
		switch model.(type) {
		case InsertOneModel:
			b := slotFor(&insertSlot)
			b.models = append(b.models, model)
		case DeleteOneModel:
			b := slotFor(&deleteSlot)
			b.models = append(b.models, model)
		case DeleteManyModel:
			b := slotFor(&deleteSlot)
			b.models = append(b.models, model)
			b.canRetry = false
		case ReplaceOneModel, UpdateOneModel:
			b := slotFor(&updateSlot)
			b.models = append(b.models, model)
		case UpdateManyModel:
			b := slotFor(&updateSlot)
			b.models = append(b.models, model)
			b.canRetry = false
		}
	}

	return batches
}
// createOrderedBatches splits the models into batches while preserving their
// order: consecutive models that map to the same write command kind share a
// batch, and a new batch is started whenever the kind changes.
//
// A batch is retryable only if it contains no DeleteManyModel and no
// UpdateManyModel.
func createOrderedBatches(models []WriteModel) []bulkWriteBatch {
	var (
		batches  []bulkWriteBatch
		lastKind command.WriteCommandKind = -1
	)

	for _, model := range models {
		var (
			kind      command.WriteCommandKind
			retryable bool
		)
		recognized := true

		switch model.(type) {
		case InsertOneModel:
			kind, retryable = command.InsertCommand, true
		case DeleteOneModel:
			kind, retryable = command.DeleteCommand, true
		case DeleteManyModel:
			kind = command.DeleteCommand
		case ReplaceOneModel, UpdateOneModel:
			kind, retryable = command.UpdateCommand, true
		case UpdateManyModel:
			kind = command.UpdateCommand
		default:
			// Unrecognized models never open a batch; they are appended to
			// the current one.
			recognized = false
		}

		if recognized && kind != lastKind {
			batches = append(batches, bulkWriteBatch{
				models:   []WriteModel{model},
				canRetry: retryable,
			})
		} else {
			current := &batches[len(batches)-1]
			current.models = append(current.models, model)
			if !retryable {
				current.canRetry = false // don't make it true if it was already false
			}
		}

		lastKind = kind
	}

	return batches
}
// shouldRetry reports whether a write may be retried: either the command
// error is a command.Error that declares itself retryable, or a write concern
// error is present and is classified as retryable.
func shouldRetry(cmdErr error, wcErr *result.WriteConcernError) bool {
	if cerr, ok := cmdErr.(command.Error); ok && cerr.Retryable() {
		return true
	}

	return wcErr != nil && command.IsWriteConcernErrorRetryable(wcErr)
}
// createUpdateDoc builds one statement for an update command.
//
// The statement always carries the filter ("q"), the update or replacement
// ("u") and the "multi" flag, in that order. "arrayFilters" is appended when
// arrayFiltersSet is true, "collation" when the model has a collation, and
// "upsert" when the model's upsert option was explicitly set. An error is
// returned if the filter, the update or the array filters cannot be converted.
func createUpdateDoc(
	filter interface{},
	update interface{},
	arrayFilters options.ArrayFilters,
	arrayFiltersSet bool,
	updateModel UpdateModel,
	multi bool,
	registry *bsoncodec.Registry,
) (bsonx.Doc, error) {
	filterDoc, err := interfaceToDocument(filter, registry)
	if err != nil {
		return nil, err
	}

	updateDoc, err := interfaceToDocument(update, registry)
	if err != nil {
		return nil, err
	}

	// Three mandatory elements plus up to three optional ones.
	stmt := make(bsonx.Doc, 0, 6)
	stmt = append(stmt,
		bsonx.Elem{"q", bsonx.Document(filterDoc)},
		bsonx.Elem{"u", bsonx.Document(updateDoc)},
		bsonx.Elem{"multi", bsonx.Boolean(multi)},
	)

	if arrayFiltersSet {
		filtersArr, convErr := arrayFilters.ToArray()
		if convErr != nil {
			return nil, convErr
		}
		stmt = append(stmt, bsonx.Elem{"arrayFilters", bsonx.Array(filtersArr)})
	}

	if coll := updateModel.Collation; coll != nil {
		stmt = append(stmt, bsonx.Elem{"collation", bsonx.Document(coll.ToDocument())})
	}

	if updateModel.UpsertSet {
		stmt = append(stmt, bsonx.Elem{"upsert", bsonx.Boolean(updateModel.Upsert)})
	}

	return stmt, nil
}
// createDeleteDoc builds one statement for a delete command.
//
// The statement carries the filter ("q") and a "limit" of 0 when many is true
// or 1 otherwise, followed by "collation" when one is given. An error is
// returned if the filter cannot be converted to a document.
func createDeleteDoc(
	filter interface{},
	collation *options.Collation,
	many bool,
	registry *bsoncodec.Registry,
) (bsonx.Doc, error) {
	filterDoc, err := interfaceToDocument(filter, registry)
	if err != nil {
		return nil, err
	}

	var limit int32 // zero is sent for a delete-many statement
	if !many {
		limit = 1
	}

	stmt := make(bsonx.Doc, 0, 3)
	stmt = append(stmt,
		bsonx.Elem{"q", bsonx.Document(filterDoc)},
		bsonx.Elem{"limit", bsonx.Int32(limit)},
	)

	if collation != nil {
		stmt = append(stmt, bsonx.Elem{"collation", bsonx.Document(collation.ToDocument())})
	}

	return stmt, nil
}
// mergeResults folds the result of one batch into the aggregate result.
//
// All counters are summed. Upserted IDs are copied into the aggregate map with
// their keys shifted by opIndex, so aggResult.UpsertedIDs must already be
// initialized.
func mergeResults(aggResult *result.BulkWrite, newResult result.BulkWrite, opIndex int64) {
	// Re-key the batch-relative upsert indexes before adding them.
	for batchIdx, id := range newResult.UpsertedIDs {
		bulkIdx := batchIdx + opIndex
		aggResult.UpsertedIDs[bulkIdx] = id
	}

	aggResult.UpsertedCount += newResult.UpsertedCount
	aggResult.DeletedCount += newResult.DeletedCount
	aggResult.ModifiedCount += newResult.ModifiedCount
	aggResult.MatchedCount += newResult.MatchedCount
	aggResult.InsertedCount += newResult.InsertedCount
}