blob: 3787faefb9106a6961f9434d3d1723f6ac29f99e [file] [log] [blame]
Don Newton379ae252019-04-01 12:17:06 -04001// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package command
8
9import (
10 "context"
11 "fmt"
12
13 "errors"
14
15 "github.com/mongodb/mongo-go-driver/bson"
16 "github.com/mongodb/mongo-go-driver/mongo/writeconcern"
17 "github.com/mongodb/mongo-go-driver/x/bsonx"
18 "github.com/mongodb/mongo-go-driver/x/mongo/driver/session"
19 "github.com/mongodb/mongo-go-driver/x/network/description"
20 "github.com/mongodb/mongo-go-driver/x/network/wiremessage"
21)
22
// Write represents a generic write database command.
// This can be used to send arbitrary write commands to the database.
//
// Encode builds a wire message from these fields; Decode (or RoundTrip) stores the
// server reply in the unexported result/err fields, which Result and Err expose.
type Write struct {
	// DB is the target database. It forms the OP_QUERY namespace "<DB>.$cmd" and is
	// passed to opmsgAddGlobals when encoding OP_MSG.
	DB string
	// Command is the command document. Encode appends fields to a copy of it
	// (Command.Copy()), not to Command itself.
	Command bsonx.Doc
	// WriteConcern is added to the command. When it is not acknowledged, session
	// fields are not sent, explicit sessions are rejected, and OP_MSG messages get the
	// MoreToCome flag.
	WriteConcern *writeconcern.WriteConcern
	// Clock is passed to addClusterTime when encoding and to updateClusterTimes when
	// decoding.
	Clock *session.ClusterClock
	// Session is optional (nil-checked before use). It drives transaction read
	// concern, session fields, txnNumber for retryable writes, and use-time updates.
	Session *session.Client

	// result holds the decoded reply; set by Decode.
	result bson.Raw
	// err holds the decoding/command error; set by Decode.
	err error
}
35
36// Encode c as OP_MSG
37func (w *Write) encodeOpMsg(desc description.SelectedServer, cmd bsonx.Doc) (wiremessage.WireMessage, error) {
38 var arr bsonx.Arr
39 var identifier string
40
41 cmd, arr, identifier = opmsgRemoveArray(cmd)
42
43 msg := wiremessage.Msg{
44 MsgHeader: wiremessage.Header{RequestID: wiremessage.NextRequestID()},
45 Sections: make([]wiremessage.Section, 0),
46 }
47
48 fullDocRdr, err := opmsgAddGlobals(cmd, w.DB, nil)
49 if err != nil {
50 return nil, err
51 }
52
53 // type 0 doc
54 msg.Sections = append(msg.Sections, wiremessage.SectionBody{
55 PayloadType: wiremessage.SingleDocument,
56 Document: fullDocRdr,
57 })
58
59 // type 1 doc
60 if identifier != "" {
61 docSequence, err := opmsgCreateDocSequence(arr, identifier)
62 if err != nil {
63 return nil, err
64 }
65
66 msg.Sections = append(msg.Sections, docSequence)
67 }
68
69 // flags
70 if !writeconcern.AckWrite(w.WriteConcern) {
71 msg.FlagBits |= wiremessage.MoreToCome
72 }
73
74 return msg, nil
75}
76
77// Encode w as OP_QUERY
78func (w *Write) encodeOpQuery(desc description.SelectedServer, cmd bsonx.Doc) (wiremessage.WireMessage, error) {
79 rdr, err := marshalCommand(cmd)
80 if err != nil {
81 return nil, err
82 }
83
84 query := wiremessage.Query{
85 MsgHeader: wiremessage.Header{RequestID: wiremessage.NextRequestID()},
86 FullCollectionName: w.DB + ".$cmd",
87 Flags: w.slaveOK(desc),
88 NumberToReturn: -1,
89 Query: rdr,
90 }
91
92 return query, nil
93}
94
95func (w *Write) slaveOK(desc description.SelectedServer) wiremessage.QueryFlag {
96 if desc.Kind == description.Single && desc.Server.Kind != description.Mongos {
97 return wiremessage.SlaveOK
98 }
99
100 return 0
101}
102
103func (w *Write) decodeOpReply(wm wiremessage.WireMessage) {
104 reply, ok := wm.(wiremessage.Reply)
105 if !ok {
106 w.err = fmt.Errorf("unsupported response wiremessage type %T", wm)
107 return
108 }
109 w.result, w.err = decodeCommandOpReply(reply)
110}
111
112func (w *Write) decodeOpMsg(wm wiremessage.WireMessage) {
113 msg, ok := wm.(wiremessage.Msg)
114 if !ok {
115 w.err = fmt.Errorf("unsupported response wiremessage type %T", wm)
116 return
117 }
118
119 w.result, w.err = decodeCommandOpMsg(msg)
120}
121
122// Encode will encode this command into a wire message for the given server description.
123func (w *Write) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
124 cmd := w.Command.Copy()
125 var err error
126 if w.Session != nil && w.Session.TransactionStarting() {
127 // Starting transactions have a read concern, even in writes.
128 cmd, err = addReadConcern(cmd, desc, nil, w.Session)
129 if err != nil {
130 return nil, err
131 }
132 }
133 cmd, err = addWriteConcern(cmd, w.WriteConcern)
134 if err != nil {
135 return nil, err
136 }
137
138 if !writeconcern.AckWrite(w.WriteConcern) {
139 // unack write with explicit session --> raise an error
140 // unack write with implicit session --> do not send session ID (implicit session shouldn't have been created
141 // in the first place)
142
143 if w.Session != nil && w.Session.SessionType == session.Explicit {
144 return nil, errors.New("explicit sessions cannot be used with unacknowledged writes")
145 }
146 } else {
147 // only encode session ID for acknowledged writes
148 cmd, err = addSessionFields(cmd, desc, w.Session)
149 if err != nil {
150 return nil, err
151 }
152 }
153
154 if w.Session != nil && w.Session.RetryWrite {
155 cmd = append(cmd, bsonx.Elem{"txnNumber", bsonx.Int64(w.Session.TxnNumber)})
156 }
157
158 cmd = addClusterTime(cmd, desc, w.Session, w.Clock)
159
160 if desc.WireVersion == nil || desc.WireVersion.Max < wiremessage.OpmsgWireVersion {
161 return w.encodeOpQuery(desc, cmd)
162 }
163
164 return w.encodeOpMsg(desc, cmd)
165}
166
167// Decode will decode the wire message using the provided server description. Errors during decoding
168// are deferred until either the Result or Err methods are called.
169func (w *Write) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Write {
170 switch wm.(type) {
171 case wiremessage.Reply:
172 w.decodeOpReply(wm)
173 default:
174 w.decodeOpMsg(wm)
175 }
176
177 if w.err != nil {
178 if _, ok := w.err.(Error); !ok {
179 return w
180 }
181 }
182
183 _ = updateClusterTimes(w.Session, w.Clock, w.result)
184
185 if writeconcern.AckWrite(w.WriteConcern) {
186 // don't update session operation time for unacknowledged write
187 _ = updateOperationTime(w.Session, w.result)
188 }
189 return w
190}
191
192// Result returns the result of a decoded wire message and server description.
193func (w *Write) Result() (bson.Raw, error) {
194 if w.err != nil {
195 return nil, w.err
196 }
197
198 return w.result, nil
199}
200
201// Err returns the error set on this command.
202func (w *Write) Err() error {
203 return w.err
204}
205
206// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriteCloser.
207func (w *Write) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) {
208 wm, err := w.Encode(desc)
209 if err != nil {
210 return nil, err
211 }
212
213 err = rw.WriteWireMessage(ctx, wm)
214 if err != nil {
215 if _, ok := err.(Error); ok {
216 return nil, err
217 }
218 // Connection errors are transient
219 return nil, Error{Message: err.Error(), Labels: []string{TransientTransactionError, NetworkError}}
220 }
221
222 if msg, ok := wm.(wiremessage.Msg); ok {
223 // don't expect response if using OP_MSG for an unacknowledged write
224 if msg.FlagBits&wiremessage.MoreToCome > 0 {
225 return nil, ErrUnacknowledgedWrite
226 }
227 }
228
229 wm, err = rw.ReadWireMessage(ctx)
230 if err != nil {
231 if _, ok := err.(Error); ok {
232 return nil, err
233 }
234 // Connection errors are transient
235 return nil, Error{Message: err.Error(), Labels: []string{TransientTransactionError, NetworkError}}
236 }
237
238 if w.Session != nil {
239 err = w.Session.UpdateUseTime()
240 if err != nil {
241 return nil, err
242 }
243 }
244 return w.Decode(desc, wm).Result()
245}