blob: 29470ae73585a969989cb2f3d4b59440d88863b8 [file] [log] [blame]
Don Newton379ae252019-04-01 12:17:06 -04001// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package command
8
9import (
10 "context"
11
12 "github.com/mongodb/mongo-go-driver/bson"
13 "github.com/mongodb/mongo-go-driver/mongo/writeconcern"
14 "github.com/mongodb/mongo-go-driver/x/bsonx"
15 "github.com/mongodb/mongo-go-driver/x/mongo/driver/session"
16 "github.com/mongodb/mongo-go-driver/x/network/description"
17 "github.com/mongodb/mongo-go-driver/x/network/result"
18 "github.com/mongodb/mongo-go-driver/x/network/wiremessage"
19)
20
// Update represents the update command.
//
// The update command updates a set of documents in the database.
type Update struct {
	// ContinueOnError is forwarded to roundTripBatches by RoundTrip.
	// NOTE(review): presumably it decides whether remaining batches are still
	// sent after one fails — confirm against roundTripBatches.
	ContinueOnError bool
	// Clock is attached to every encoded write batch.
	Clock *session.ClusterClock
	// NS names the target: NS.DB is the database the command is sent to and
	// NS.Collection is the collection named in the command document.
	NS Namespace
	// Docs are the update statements. encodeBatch appends per-statement
	// options to copies made with doc.Copy(), not to these documents.
	Docs []bsonx.Doc
	// Opts are the command options. "upsert", "collation" and "arrayFilters"
	// are appended to every statement; all other options are passed to
	// encodeBatch as command-level options.
	Opts []bsonx.Elem
	// WriteConcern is attached to every encoded write batch.
	WriteConcern *writeconcern.WriteConcern
	// Session is attached to every encoded write batch and forwarded to
	// roundTripBatches.
	Session *session.Client

	// batches holds the encoded batches. RoundTrip only (re)encodes when it is
	// nil and stores any leftover batches here for a retry.
	batches []*WriteBatch
	// result is the decoded server reply.
	result result.Update
	// err is the deferred decoding error reported by Result and Err.
	err error
}
37
38// Encode will encode this command into a wire message for the given server description.
39func (u *Update) Encode(desc description.SelectedServer) ([]wiremessage.WireMessage, error) {
40 err := u.encode(desc)
41 if err != nil {
42 return nil, err
43 }
44
45 return batchesToWireMessage(u.batches, desc)
46}
47
48func (u *Update) encode(desc description.SelectedServer) error {
49 batches, err := splitBatches(u.Docs, int(desc.MaxBatchCount), int(desc.MaxDocumentSize))
50 if err != nil {
51 return err
52 }
53
54 for _, docs := range batches {
55 cmd, err := u.encodeBatch(docs, desc)
56 if err != nil {
57 return err
58 }
59
60 u.batches = append(u.batches, cmd)
61 }
62
63 return nil
64}
65
66func (u *Update) encodeBatch(docs []bsonx.Doc, desc description.SelectedServer) (*WriteBatch, error) {
67 copyDocs := make([]bsonx.Doc, 0, len(docs)) // copy of all the documents
68 for _, doc := range docs {
69 newDoc := doc.Copy()
70 copyDocs = append(copyDocs, newDoc)
71 }
72
73 var options []bsonx.Elem
74 for _, opt := range u.Opts {
75 switch opt.Key {
76 case "upsert", "collation", "arrayFilters":
77 // options that are encoded on each individual document
78 for idx := range copyDocs {
79 copyDocs[idx] = append(copyDocs[idx], opt)
80 }
81 default:
82 options = append(options, opt)
83 }
84 }
85
86 command, err := encodeBatch(copyDocs, options, UpdateCommand, u.NS.Collection)
87 if err != nil {
88 return nil, err
89 }
90
91 return &WriteBatch{
92 &Write{
93 Clock: u.Clock,
94 DB: u.NS.DB,
95 Command: command,
96 WriteConcern: u.WriteConcern,
97 Session: u.Session,
98 },
99 len(docs),
100 }, nil
101}
102
103// Decode will decode the wire message using the provided server description. Errors during decoding
104// are deferred until either the Result or Err methods are called.
105func (u *Update) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Update {
106 rdr, err := (&Write{}).Decode(desc, wm).Result()
107 if err != nil {
108 u.err = err
109 return u
110 }
111 return u.decode(desc, rdr)
112}
113
114func (u *Update) decode(desc description.SelectedServer, rdr bson.Raw) *Update {
115 u.err = bson.Unmarshal(rdr, &u.result)
116 return u
117}
118
119// Result returns the result of a decoded wire message and server description.
120func (u *Update) Result() (result.Update, error) {
121 if u.err != nil {
122 return result.Update{}, u.err
123 }
124 return u.result, nil
125}
126
127// Err returns the error set on this command.
128func (u *Update) Err() error { return u.err }
129
130// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
131func (u *Update) RoundTrip(
132 ctx context.Context,
133 desc description.SelectedServer,
134 rw wiremessage.ReadWriter,
135) (result.Update, error) {
136 if u.batches == nil {
137 err := u.encode(desc)
138 if err != nil {
139 return result.Update{}, err
140 }
141 }
142
143 r, batches, err := roundTripBatches(
144 ctx, desc, rw,
145 u.batches,
146 u.ContinueOnError,
147 u.Session,
148 UpdateCommand,
149 )
150
151 // if there are leftover batches, save them for retry
152 if batches != nil {
153 u.batches = batches
154 }
155
156 if err != nil {
157 return result.Update{}, err
158 }
159
160 return r.(result.Update), nil
161}