blob: 4d461d58ea100084e3f59708fbbbe2ccc661ac4c [file] [log] [blame]
Don Newton379ae252019-04-01 12:17:06 -04001// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package driver
8
import (
	"context"
	"fmt"

	"github.com/mongodb/mongo-go-driver/bson/bsoncodec"
	"github.com/mongodb/mongo-go-driver/mongo/options"
	"github.com/mongodb/mongo-go-driver/mongo/writeconcern"
	"github.com/mongodb/mongo-go-driver/x/bsonx"
	"github.com/mongodb/mongo-go-driver/x/mongo/driver/session"
	"github.com/mongodb/mongo-go-driver/x/mongo/driver/topology"
	"github.com/mongodb/mongo-go-driver/x/mongo/driver/uuid"
	"github.com/mongodb/mongo-go-driver/x/network/command"
	"github.com/mongodb/mongo-go-driver/x/network/description"
	"github.com/mongodb/mongo-go-driver/x/network/result"
)
23
// BulkWriteError is an error from one operation in a bulk write.
//
// The embedded result.WriteError carries the server-reported error; BulkWrite
// offsets its Index by the number of operations in the batches that ran before
// the one that produced it.
type BulkWriteError struct {
	result.WriteError
	// Model is the WriteModel that runBatch associates with this error.
	Model WriteModel
}
29
// BulkWriteException is a collection of errors returned by a bulk write operation.
type BulkWriteException struct {
	// WriteConcernError is the write concern error reported by the server, or nil.
	WriteConcernError *result.WriteConcernError
	// WriteErrors holds the individual write errors collected from every batch
	// that was executed.
	WriteErrors []BulkWriteError
}
35
36func (BulkWriteException) Error() string {
37 return ""
38}
39
// bulkWriteBatch is a group of write models that are sent to the server as a
// single insert, update, or delete command (see runInsert, runUpdate and
// runDelete).
type bulkWriteBatch struct {
	// models all belong to the same command kind (insert, update, or delete).
	models []WriteModel
	// canRetry is cleared by createBatches/createOrderedBatches when the batch
	// contains a DeleteManyModel or UpdateManyModel; such batches are never
	// executed as retryable writes.
	canRetry bool
}
44
45// BulkWrite handles the full dispatch cycle for a bulk write operation.
46func BulkWrite(
47 ctx context.Context,
48 ns command.Namespace,
49 models []WriteModel,
50 topo *topology.Topology,
51 selector description.ServerSelector,
52 clientID uuid.UUID,
53 pool *session.Pool,
54 retryWrite bool,
55 sess *session.Client,
56 writeConcern *writeconcern.WriteConcern,
57 clock *session.ClusterClock,
58 registry *bsoncodec.Registry,
59 opts ...*options.BulkWriteOptions,
60) (result.BulkWrite, error) {
61 ss, err := topo.SelectServer(ctx, selector)
62 if err != nil {
63 return result.BulkWrite{}, err
64 }
65
66 err = verifyOptions(models, ss)
67 if err != nil {
68 return result.BulkWrite{}, err
69 }
70
71 // If no explicit session and deployment supports sessions, start implicit session.
72 if sess == nil && topo.SupportsSessions() {
73 sess, err = session.NewClientSession(pool, clientID, session.Implicit)
74 if err != nil {
75 return result.BulkWrite{}, err
76 }
77
78 defer sess.EndSession()
79 }
80
81 bwOpts := options.MergeBulkWriteOptions(opts...)
82
83 ordered := *bwOpts.Ordered
84
85 batches := createBatches(models, ordered)
86 bwRes := result.BulkWrite{
87 UpsertedIDs: make(map[int64]interface{}),
88 }
89 bwErr := BulkWriteException{
90 WriteErrors: make([]BulkWriteError, 0),
91 }
92
93 var opIndex int64 // the operation index for the upsertedIDs map
94 continueOnError := !ordered
95 for _, batch := range batches {
96 if len(batch.models) == 0 {
97 continue
98 }
99
100 batchRes, batchErr, err := runBatch(ctx, ns, topo, selector, ss, sess, clock, writeConcern, retryWrite,
101 bwOpts.BypassDocumentValidation, continueOnError, batch, registry)
102
103 mergeResults(&bwRes, batchRes, opIndex)
104 bwErr.WriteConcernError = batchErr.WriteConcernError
105 for i := range batchErr.WriteErrors {
106 batchErr.WriteErrors[i].Index = batchErr.WriteErrors[i].Index + int(opIndex)
107 }
108 bwErr.WriteErrors = append(bwErr.WriteErrors, batchErr.WriteErrors...)
109
110 if !continueOnError && (err != nil || len(batchErr.WriteErrors) > 0 || batchErr.WriteConcernError != nil) {
111 if err != nil {
112 return result.BulkWrite{}, err
113 }
114
115 return result.BulkWrite{}, bwErr
116 }
117
118 opIndex += int64(len(batch.models))
119 }
120
121 bwRes.MatchedCount -= bwRes.UpsertedCount
122 return bwRes, nil
123}
124
125func runBatch(
126 ctx context.Context,
127 ns command.Namespace,
128 topo *topology.Topology,
129 selector description.ServerSelector,
130 ss *topology.SelectedServer,
131 sess *session.Client,
132 clock *session.ClusterClock,
133 wc *writeconcern.WriteConcern,
134 retryWrite bool,
135 bypassDocValidation *bool,
136 continueOnError bool,
137 batch bulkWriteBatch,
138 registry *bsoncodec.Registry,
139) (result.BulkWrite, BulkWriteException, error) {
140 batchRes := result.BulkWrite{
141 UpsertedIDs: make(map[int64]interface{}),
142 }
143 batchErr := BulkWriteException{}
144
145 var writeErrors []result.WriteError
146 switch batch.models[0].(type) {
147 case InsertOneModel:
148 res, err := runInsert(ctx, ns, topo, selector, ss, sess, clock, wc, retryWrite, batch, bypassDocValidation,
149 continueOnError, registry)
150 if err != nil {
151 return result.BulkWrite{}, BulkWriteException{}, err
152 }
153
154 batchRes.InsertedCount = int64(res.N)
155 writeErrors = res.WriteErrors
156 case DeleteOneModel, DeleteManyModel:
157 res, err := runDelete(ctx, ns, topo, selector, ss, sess, clock, wc, retryWrite, batch, continueOnError, registry)
158 if err != nil {
159 return result.BulkWrite{}, BulkWriteException{}, err
160 }
161
162 batchRes.DeletedCount = int64(res.N)
163 writeErrors = res.WriteErrors
164 case ReplaceOneModel, UpdateOneModel, UpdateManyModel:
165 res, err := runUpdate(ctx, ns, topo, selector, ss, sess, clock, wc, retryWrite, batch, bypassDocValidation,
166 continueOnError, registry)
167 if err != nil {
168 return result.BulkWrite{}, BulkWriteException{}, err
169 }
170
171 batchRes.MatchedCount = res.MatchedCount
172 batchRes.ModifiedCount = res.ModifiedCount
173 batchRes.UpsertedCount = int64(len(res.Upserted))
174 writeErrors = res.WriteErrors
175 for _, upsert := range res.Upserted {
176 batchRes.UpsertedIDs[upsert.Index] = upsert.ID
177 }
178 }
179
180 batchErr.WriteErrors = make([]BulkWriteError, 0, len(writeErrors))
181 for _, we := range writeErrors {
182 batchErr.WriteErrors = append(batchErr.WriteErrors, BulkWriteError{
183 WriteError: we,
184 Model: batch.models[0],
185 })
186 }
187
188 return batchRes, batchErr, nil
189}
190
191func runInsert(
192 ctx context.Context,
193 ns command.Namespace,
194 topo *topology.Topology,
195 selector description.ServerSelector,
196 ss *topology.SelectedServer,
197 sess *session.Client,
198 clock *session.ClusterClock,
199 wc *writeconcern.WriteConcern,
200 retryWrite bool,
201 batch bulkWriteBatch,
202 bypassDocValidation *bool,
203 continueOnError bool,
204 registry *bsoncodec.Registry,
205) (result.Insert, error) {
206 docs := make([]bsonx.Doc, len(batch.models))
207 var i int
208 for _, model := range batch.models {
209 converted := model.(InsertOneModel)
210 doc, err := interfaceToDocument(converted.Document, registry)
211 if err != nil {
212 return result.Insert{}, err
213 }
214
215 docs[i] = doc
216 i++
217 }
218
219 cmd := command.Insert{
220 ContinueOnError: continueOnError,
221 NS: ns,
222 Docs: docs,
223 Session: sess,
224 Clock: clock,
225 WriteConcern: wc,
226 }
227
228 if bypassDocValidation != nil {
229 cmd.Opts = []bsonx.Elem{{"bypassDocumentValidation", bsonx.Boolean(*bypassDocValidation)}}
230 }
231
232 if !retrySupported(topo, ss.Description(), cmd.Session, cmd.WriteConcern) || !retryWrite || !batch.canRetry {
233 if cmd.Session != nil {
234 cmd.Session.RetryWrite = false
235 }
236 return insert(ctx, cmd, ss, nil)
237 }
238
239 cmd.Session.RetryWrite = retryWrite
240 cmd.Session.IncrementTxnNumber()
241
242 res, origErr := insert(ctx, cmd, ss, nil)
243 if shouldRetry(origErr, res.WriteConcernError) {
244 newServer, err := topo.SelectServer(ctx, selector)
245 if err != nil || !retrySupported(topo, ss.Description(), cmd.Session, cmd.WriteConcern) {
246 return res, origErr
247 }
248
249 return insert(ctx, cmd, newServer, origErr)
250 }
251
252 return res, origErr
253}
254
255func runDelete(
256 ctx context.Context,
257 ns command.Namespace,
258 topo *topology.Topology,
259 selector description.ServerSelector,
260 ss *topology.SelectedServer,
261 sess *session.Client,
262 clock *session.ClusterClock,
263 wc *writeconcern.WriteConcern,
264 retryWrite bool,
265 batch bulkWriteBatch,
266 continueOnError bool,
267 registry *bsoncodec.Registry,
268) (result.Delete, error) {
269 docs := make([]bsonx.Doc, len(batch.models))
270 var i int
271
272 for _, model := range batch.models {
273 var doc bsonx.Doc
274 var err error
275
276 if dom, ok := model.(DeleteOneModel); ok {
277 doc, err = createDeleteDoc(dom.Filter, dom.Collation, false, registry)
278 } else if dmm, ok := model.(DeleteManyModel); ok {
279 doc, err = createDeleteDoc(dmm.Filter, dmm.Collation, true, registry)
280 }
281
282 if err != nil {
283 return result.Delete{}, err
284 }
285
286 docs[i] = doc
287 i++
288 }
289
290 cmd := command.Delete{
291 ContinueOnError: continueOnError,
292 NS: ns,
293 Deletes: docs,
294 Session: sess,
295 Clock: clock,
296 WriteConcern: wc,
297 }
298
299 if !retrySupported(topo, ss.Description(), cmd.Session, cmd.WriteConcern) || !retryWrite || !batch.canRetry {
300 if cmd.Session != nil {
301 cmd.Session.RetryWrite = false
302 }
303 return delete(ctx, cmd, ss, nil)
304 }
305
306 cmd.Session.RetryWrite = retryWrite
307 cmd.Session.IncrementTxnNumber()
308
309 res, origErr := delete(ctx, cmd, ss, nil)
310 if shouldRetry(origErr, res.WriteConcernError) {
311 newServer, err := topo.SelectServer(ctx, selector)
312 if err != nil || !retrySupported(topo, ss.Description(), cmd.Session, cmd.WriteConcern) {
313 return res, origErr
314 }
315
316 return delete(ctx, cmd, newServer, origErr)
317 }
318
319 return res, origErr
320}
321
322func runUpdate(
323 ctx context.Context,
324 ns command.Namespace,
325 topo *topology.Topology,
326 selector description.ServerSelector,
327 ss *topology.SelectedServer,
328 sess *session.Client,
329 clock *session.ClusterClock,
330 wc *writeconcern.WriteConcern,
331 retryWrite bool,
332 batch bulkWriteBatch,
333 bypassDocValidation *bool,
334 continueOnError bool,
335 registry *bsoncodec.Registry,
336) (result.Update, error) {
337 docs := make([]bsonx.Doc, len(batch.models))
338
339 for i, model := range batch.models {
340 var doc bsonx.Doc
341 var err error
342
343 if rom, ok := model.(ReplaceOneModel); ok {
344 doc, err = createUpdateDoc(rom.Filter, rom.Replacement, options.ArrayFilters{}, false, rom.UpdateModel, false,
345 registry)
346 } else if uom, ok := model.(UpdateOneModel); ok {
347 doc, err = createUpdateDoc(uom.Filter, uom.Update, uom.ArrayFilters, uom.ArrayFiltersSet, uom.UpdateModel, false,
348 registry)
349 } else if umm, ok := model.(UpdateManyModel); ok {
350 doc, err = createUpdateDoc(umm.Filter, umm.Update, umm.ArrayFilters, umm.ArrayFiltersSet, umm.UpdateModel, true,
351 registry)
352 }
353
354 if err != nil {
355 return result.Update{}, err
356 }
357
358 docs[i] = doc
359 }
360
361 cmd := command.Update{
362 ContinueOnError: continueOnError,
363 NS: ns,
364 Docs: docs,
365 Session: sess,
366 Clock: clock,
367 WriteConcern: wc,
368 }
369 if bypassDocValidation != nil {
370 // TODO this is temporary!
371 cmd.Opts = []bsonx.Elem{{"bypassDocumentValidation", bsonx.Boolean(*bypassDocValidation)}}
372 //cmd.Opts = []option.UpdateOptioner{option.OptBypassDocumentValidation(bypassDocValidation)}
373 }
374
375 if !retrySupported(topo, ss.Description(), cmd.Session, cmd.WriteConcern) || !retryWrite || !batch.canRetry {
376 if cmd.Session != nil {
377 cmd.Session.RetryWrite = false
378 }
379 return update(ctx, cmd, ss, nil)
380 }
381
382 cmd.Session.RetryWrite = retryWrite
383 cmd.Session.IncrementTxnNumber()
384
385 res, origErr := update(ctx, cmd, ss, nil)
386 if shouldRetry(origErr, res.WriteConcernError) {
387 newServer, err := topo.SelectServer(ctx, selector)
388 if err != nil || !retrySupported(topo, ss.Description(), cmd.Session, cmd.WriteConcern) {
389 return res, origErr
390 }
391
392 return update(ctx, cmd, newServer, origErr)
393 }
394
395 return res, origErr
396}
397
398func verifyOptions(models []WriteModel, ss *topology.SelectedServer) error {
399 maxVersion := ss.Description().WireVersion.Max
400 // 3.4 is wire version 5
401 // 3.6 is wire version 6
402
403 for _, model := range models {
404 var collationSet bool
405 var afSet bool // arrayFilters
406
407 switch converted := model.(type) {
408 case DeleteOneModel:
409 collationSet = converted.Collation != nil
410 case DeleteManyModel:
411 collationSet = converted.Collation != nil
412 case ReplaceOneModel:
413 collationSet = converted.Collation != nil
414 case UpdateOneModel:
415 afSet = converted.ArrayFiltersSet
416 collationSet = converted.Collation != nil
417 case UpdateManyModel:
418 afSet = converted.ArrayFiltersSet
419 collationSet = converted.Collation != nil
420 }
421
422 if afSet && maxVersion < 6 {
423 return ErrArrayFilters
424 }
425
426 if collationSet && maxVersion < 5 {
427 return ErrCollation
428 }
429 }
430
431 return nil
432}
433
434func createBatches(models []WriteModel, ordered bool) []bulkWriteBatch {
435 if ordered {
436 return createOrderedBatches(models)
437 }
438
439 batches := make([]bulkWriteBatch, 3)
440 var i int
441 for i = 0; i < 3; i++ {
442 batches[i].canRetry = true
443 }
444
445 var numBatches int // number of batches used. can't use len(batches) because it's set to 3
446 insertInd := -1
447 updateInd := -1
448 deleteInd := -1
449
450 for _, model := range models {
451 switch converted := model.(type) {
452 case InsertOneModel:
453 if insertInd == -1 {
454 // this is the first InsertOneModel
455 insertInd = numBatches
456 numBatches++
457 }
458
459 batches[insertInd].models = append(batches[insertInd].models, model)
460 case DeleteOneModel, DeleteManyModel:
461 if deleteInd == -1 {
462 deleteInd = numBatches
463 numBatches++
464 }
465
466 batches[deleteInd].models = append(batches[deleteInd].models, model)
467 if _, ok := converted.(DeleteManyModel); ok {
468 batches[deleteInd].canRetry = false
469 }
470 case ReplaceOneModel, UpdateOneModel, UpdateManyModel:
471 if updateInd == -1 {
472 updateInd = numBatches
473 numBatches++
474 }
475
476 batches[updateInd].models = append(batches[updateInd].models, model)
477 if _, ok := converted.(UpdateManyModel); ok {
478 batches[updateInd].canRetry = false
479 }
480 }
481 }
482
483 return batches
484}
485
486func createOrderedBatches(models []WriteModel) []bulkWriteBatch {
487 var batches []bulkWriteBatch
488 var prevKind command.WriteCommandKind = -1
489 i := -1 // batch index
490
491 for _, model := range models {
492 var createNewBatch bool
493 var canRetry bool
494 var newKind command.WriteCommandKind
495
496 switch model.(type) {
497 case InsertOneModel:
498 createNewBatch = prevKind != command.InsertCommand
499 canRetry = true
500 newKind = command.InsertCommand
501 case DeleteOneModel:
502 createNewBatch = prevKind != command.DeleteCommand
503 canRetry = true
504 newKind = command.DeleteCommand
505 case DeleteManyModel:
506 createNewBatch = prevKind != command.DeleteCommand
507 newKind = command.DeleteCommand
508 case ReplaceOneModel, UpdateOneModel:
509 createNewBatch = prevKind != command.UpdateCommand
510 canRetry = true
511 newKind = command.UpdateCommand
512 case UpdateManyModel:
513 createNewBatch = prevKind != command.UpdateCommand
514 newKind = command.UpdateCommand
515 }
516
517 if createNewBatch {
518 batches = append(batches, bulkWriteBatch{
519 models: []WriteModel{model},
520 canRetry: canRetry,
521 })
522 i++
523 } else {
524 batches[i].models = append(batches[i].models, model)
525 if !canRetry {
526 batches[i].canRetry = false // don't make it true if it was already false
527 }
528 }
529
530 prevKind = newKind
531 }
532
533 return batches
534}
535
536func shouldRetry(cmdErr error, wcErr *result.WriteConcernError) bool {
537 if cerr, ok := cmdErr.(command.Error); ok && cerr.Retryable() ||
538 wcErr != nil && command.IsWriteConcernErrorRetryable(wcErr) {
539 return true
540 }
541
542 return false
543}
544
545func createUpdateDoc(
546 filter interface{},
547 update interface{},
548 arrayFilters options.ArrayFilters,
549 arrayFiltersSet bool,
550 updateModel UpdateModel,
551 multi bool,
552 registry *bsoncodec.Registry,
553) (bsonx.Doc, error) {
554 f, err := interfaceToDocument(filter, registry)
555 if err != nil {
556 return nil, err
557 }
558
559 u, err := interfaceToDocument(update, registry)
560 if err != nil {
561 return nil, err
562 }
563
564 doc := bsonx.Doc{
565 {"q", bsonx.Document(f)},
566 {"u", bsonx.Document(u)},
567 {"multi", bsonx.Boolean(multi)},
568 }
569
570 if arrayFiltersSet {
571 arr, err := arrayFilters.ToArray()
572 if err != nil {
573 return nil, err
574 }
575 doc = append(doc, bsonx.Elem{"arrayFilters", bsonx.Array(arr)})
576 }
577
578 if updateModel.Collation != nil {
579 doc = append(doc, bsonx.Elem{"collation", bsonx.Document(updateModel.Collation.ToDocument())})
580 }
581
582 if updateModel.UpsertSet {
583 doc = append(doc, bsonx.Elem{"upsert", bsonx.Boolean(updateModel.Upsert)})
584 }
585
586 return doc, nil
587}
588
589func createDeleteDoc(
590 filter interface{},
591 collation *options.Collation,
592 many bool,
593 registry *bsoncodec.Registry,
594) (bsonx.Doc, error) {
595 f, err := interfaceToDocument(filter, registry)
596 if err != nil {
597 return nil, err
598 }
599
600 var limit int32 = 1
601 if many {
602 limit = 0
603 }
604
605 doc := bsonx.Doc{
606 {"q", bsonx.Document(f)},
607 {"limit", bsonx.Int32(limit)},
608 }
609
610 if collation != nil {
611 doc = append(doc, bsonx.Elem{"collation", bsonx.Document(collation.ToDocument())})
612 }
613
614 return doc, nil
615}
616
617func mergeResults(aggResult *result.BulkWrite, newResult result.BulkWrite, opIndex int64) {
618 aggResult.InsertedCount += newResult.InsertedCount
619 aggResult.MatchedCount += newResult.MatchedCount
620 aggResult.ModifiedCount += newResult.ModifiedCount
621 aggResult.DeletedCount += newResult.DeletedCount
622 aggResult.UpsertedCount += newResult.UpsertedCount
623
624 for index, upsertID := range newResult.UpsertedIDs {
625 aggResult.UpsertedIDs[index+opIndex] = upsertID
626 }
627}