blob: da946c3867826a2571f72c7e945bcc1d8d4994b8 [file] [log] [blame]
Don Newton379ae252019-04-01 12:17:06 -04001package driver
2
3import (
4 "context"
5 "errors"
6 "fmt"
7
8 "github.com/mongodb/mongo-go-driver/bson"
9 "github.com/mongodb/mongo-go-driver/bson/bsontype"
10 "github.com/mongodb/mongo-go-driver/x/bsonx"
11 "github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore"
12 "github.com/mongodb/mongo-go-driver/x/mongo/driver/session"
13 "github.com/mongodb/mongo-go-driver/x/mongo/driver/topology"
14 "github.com/mongodb/mongo-go-driver/x/network/command"
15 "github.com/mongodb/mongo-go-driver/x/network/wiremessage"
16)
17
18// BatchCursor is a batch implementation of a cursor. It returns documents in entire batches instead
19// of one at a time. An individual document cursor can be built on top of this batch cursor.
20type BatchCursor struct {
21 clientSession *session.Client
22 clock *session.ClusterClock
23 namespace command.Namespace
24 id int64
25 err error
26 server *topology.Server
27 opts []bsonx.Elem
28 currentBatch []byte
29 firstBatch bool
30 batchNumber int
31
32 // legacy server (< 3.2) fields
33 batchSize int32
34 limit int32
35 numReturned int32 // number of docs returned by server
36}
37
38// NewBatchCursor creates a new BatchCursor from the provided parameters.
39func NewBatchCursor(result bsoncore.Document, clientSession *session.Client, clock *session.ClusterClock, server *topology.Server, opts ...bsonx.Elem) (*BatchCursor, error) {
40 cur, err := result.LookupErr("cursor")
41 if err != nil {
42 return nil, err
43 }
44 if cur.Type != bson.TypeEmbeddedDocument {
45 return nil, fmt.Errorf("cursor should be an embedded document but it is a BSON %s", cur.Type)
46 }
47
48 elems, err := cur.Document().Elements()
49 if err != nil {
50 return nil, err
51 }
52 bc := &BatchCursor{
53 clientSession: clientSession,
54 clock: clock,
55 server: server,
56 opts: opts,
57 firstBatch: true,
58 }
59
60 var ok bool
61 for _, elem := range elems {
62 switch elem.Key() {
63 case "firstBatch":
64 arr, ok := elem.Value().ArrayOK()
65 if !ok {
66 return nil, fmt.Errorf("firstBatch should be an array but it is a BSON %s", elem.Value().Type)
67 }
68 vals, err := arr.Values()
69 if err != nil {
70 return nil, err
71 }
72
73 for _, val := range vals {
74 if val.Type != bsontype.EmbeddedDocument {
75 return nil, fmt.Errorf("element of cursor batch is not a document, but at %s", val.Type)
76 }
77 bc.currentBatch = append(bc.currentBatch, val.Data...)
78 }
79 case "ns":
80 if elem.Value().Type != bson.TypeString {
81 return nil, fmt.Errorf("namespace should be a string but it is a BSON %s", elem.Value().Type)
82 }
83 namespace := command.ParseNamespace(elem.Value().StringValue())
84 err = namespace.Validate()
85 if err != nil {
86 return nil, err
87 }
88 bc.namespace = namespace
89 case "id":
90 bc.id, ok = elem.Value().Int64OK()
91 if !ok {
92 return nil, fmt.Errorf("id should be an int64 but it is a BSON %s", elem.Value().Type)
93 }
94 }
95 }
96
97 // close session if everything fits in first batch
98 if bc.id == 0 {
99 bc.closeImplicitSession()
100 }
101 return bc, nil
102}
103
104// NewEmptyBatchCursor returns a batch cursor that is empty.
105func NewEmptyBatchCursor() *BatchCursor {
106 return &BatchCursor{}
107}
108
109// NewLegacyBatchCursor creates a new BatchCursor for server versions 3.0 and below from the
110// provided parameters.
111//
112// TODO(GODRIVER-617): The batch parameter here should be []bsoncore.Document. Change it to this
113// once we have the new wiremessage package that uses bsoncore instead of bson.
114func NewLegacyBatchCursor(ns command.Namespace, cursorID int64, batch []bson.Raw, limit int32, batchSize int32, server *topology.Server) (*BatchCursor, error) {
115 bc := &BatchCursor{
116 id: cursorID,
117 server: server,
118 namespace: ns,
119 limit: limit,
120 batchSize: batchSize,
121 numReturned: int32(len(batch)),
122 firstBatch: true,
123 }
124
125 // take as many documents from the batch as needed
126 firstBatchSize := int32(len(batch))
127 if limit != 0 && limit < firstBatchSize {
128 firstBatchSize = limit
129 }
130 batch = batch[:firstBatchSize]
131 for _, doc := range batch {
132 bc.currentBatch = append(bc.currentBatch, doc...)
133 }
134
135 return bc, nil
136}
137
138// ID returns the cursor ID for this batch cursor.
139func (bc *BatchCursor) ID() int64 {
140 return bc.id
141}
142
143// Next indicates if there is another batch available. Returning false does not necessarily indicate
144// that the cursor is closed. This method will return false when an empty batch is returned.
145//
146// If Next returns true, there is a valid batch of documents available. If Next returns false, there
147// is not a valid batch of documents available.
148func (bc *BatchCursor) Next(ctx context.Context) bool {
149 if ctx == nil {
150 ctx = context.Background()
151 }
152
153 if bc.firstBatch {
154 bc.firstBatch = false
155 return true
156 }
157
158 if bc.id == 0 || bc.server == nil {
159 return false
160 }
161
162 if bc.legacy() {
163 bc.legacyGetMore(ctx)
164 } else {
165 bc.getMore(ctx)
166 }
167
168 return len(bc.currentBatch) > 0
169}
170
171// Batch will append the current batch of documents to dst. RequiredBytes can be called to determine
172// the length of the current batch of documents.
173//
174// If there is no batch available, this method does nothing.
175func (bc *BatchCursor) Batch(dst []byte) []byte { return append(dst, bc.currentBatch...) }
176
177// RequiredBytes returns the number of bytes required for the current batch.
178func (bc *BatchCursor) RequiredBytes() int { return len(bc.currentBatch) }
179
180// Err returns the latest error encountered.
181func (bc *BatchCursor) Err() error { return bc.err }
182
183// Close closes this batch cursor.
184func (bc *BatchCursor) Close(ctx context.Context) error {
185 if ctx == nil {
186 ctx = context.Background()
187 }
188
189 if bc.server == nil {
190 return nil
191 }
192
193 if bc.legacy() {
194 return bc.legacyKillCursor(ctx)
195 }
196
197 defer bc.closeImplicitSession()
198 conn, err := bc.server.Connection(ctx)
199 if err != nil {
200 return err
201 }
202
203 _, err = (&command.KillCursors{
204 Clock: bc.clock,
205 NS: bc.namespace,
206 IDs: []int64{bc.id},
207 }).RoundTrip(ctx, bc.server.SelectedDescription(), conn)
208 if err != nil {
209 _ = conn.Close() // The command response error is more important here
210 return err
211 }
212
213 bc.id = 0
214 return conn.Close()
215}
216
217func (bc *BatchCursor) closeImplicitSession() {
218 if bc.clientSession != nil && bc.clientSession.SessionType == session.Implicit {
219 bc.clientSession.EndSession()
220 }
221}
222
223func (bc *BatchCursor) clearBatch() {
224 bc.currentBatch = bc.currentBatch[:0]
225}
226
227func (bc *BatchCursor) getMore(ctx context.Context) {
228 bc.clearBatch()
229 if bc.id == 0 {
230 return
231 }
232
233 conn, err := bc.server.Connection(ctx)
234 if err != nil {
235 bc.err = err
236 return
237 }
238
239 response, err := (&command.GetMore{
240 Clock: bc.clock,
241 ID: bc.id,
242 NS: bc.namespace,
243 Opts: bc.opts,
244 Session: bc.clientSession,
245 }).RoundTrip(ctx, bc.server.SelectedDescription(), conn)
246 if err != nil {
247 _ = conn.Close() // The command response error is more important here
248 bc.err = err
249 return
250 }
251
252 err = conn.Close()
253 if err != nil {
254 bc.err = err
255 return
256 }
257
258 id, err := response.LookupErr("cursor", "id")
259 if err != nil {
260 bc.err = err
261 return
262 }
263 var ok bool
264 bc.id, ok = id.Int64OK()
265 if !ok {
266 bc.err = fmt.Errorf("BSON Type %s is not %s", id.Type, bson.TypeInt64)
267 return
268 }
269
270 // if this is the last getMore, close the session
271 if bc.id == 0 {
272 bc.closeImplicitSession()
273 }
274
275 batch, err := response.LookupErr("cursor", "nextBatch")
276 if err != nil {
277 bc.err = err
278 return
279 }
280 var arr bson.Raw
281 arr, ok = batch.ArrayOK()
282 if !ok {
283 bc.err = fmt.Errorf("BSON Type %s is not %s", batch.Type, bson.TypeArray)
284 return
285 }
286 vals, err := arr.Values()
287 if err != nil {
288 bc.err = err
289 return
290 }
291
292 for _, val := range vals {
293 if val.Type != bsontype.EmbeddedDocument {
294 bc.err = fmt.Errorf("element of cursor batch is not a document, but at %s", val.Type)
295 bc.currentBatch = bc.currentBatch[:0] // don't return a batch on error
296 return
297 }
298 bc.currentBatch = append(bc.currentBatch, val.Value...)
299 }
300
301 return
302}
303
304func (bc *BatchCursor) legacy() bool {
305 return bc.server.Description().WireVersion == nil || bc.server.Description().WireVersion.Max < 4
306}
307
308func (bc *BatchCursor) legacyKillCursor(ctx context.Context) error {
309 conn, err := bc.server.Connection(ctx)
310 if err != nil {
311 return err
312 }
313
314 kc := wiremessage.KillCursors{
315 NumberOfCursorIDs: 1,
316 CursorIDs: []int64{bc.id},
317 CollectionName: bc.namespace.Collection,
318 DatabaseName: bc.namespace.DB,
319 }
320
321 err = conn.WriteWireMessage(ctx, kc)
322 if err != nil {
323 _ = conn.Close()
324 return err
325 }
326
327 err = conn.Close() // no reply from OP_KILL_CURSORS
328 if err != nil {
329 return err
330 }
331
332 bc.id = 0
333 bc.clearBatch()
334 return nil
335}
336
337func (bc *BatchCursor) legacyGetMore(ctx context.Context) {
338 bc.clearBatch()
339 if bc.id == 0 {
340 return
341 }
342
343 conn, err := bc.server.Connection(ctx)
344 if err != nil {
345 bc.err = err
346 return
347 }
348
349 numToReturn := bc.batchSize
350 if bc.limit != 0 && bc.numReturned+bc.batchSize > bc.limit {
351 numToReturn = bc.limit - bc.numReturned
352 }
353 gm := wiremessage.GetMore{
354 FullCollectionName: bc.namespace.DB + "." + bc.namespace.Collection,
355 CursorID: bc.id,
356 NumberToReturn: numToReturn,
357 }
358
359 err = conn.WriteWireMessage(ctx, gm)
360 if err != nil {
361 _ = conn.Close()
362 bc.err = err
363 return
364 }
365
366 response, err := conn.ReadWireMessage(ctx)
367 if err != nil {
368 _ = conn.Close()
369 bc.err = err
370 return
371 }
372
373 err = conn.Close()
374 if err != nil {
375 bc.err = err
376 return
377 }
378
379 reply, ok := response.(wiremessage.Reply)
380 if !ok {
381 bc.err = errors.New("did not receive OP_REPLY response")
382 return
383 }
384
385 err = validateGetMoreReply(reply)
386 if err != nil {
387 bc.err = err
388 return
389 }
390
391 bc.id = reply.CursorID
392 bc.numReturned += reply.NumberReturned
393 if bc.limit != 0 && bc.numReturned >= bc.limit {
394 err = bc.Close(ctx)
395 if err != nil {
396 bc.err = err
397 return
398 }
399 }
400
401 for _, doc := range reply.Documents {
402 bc.currentBatch = append(bc.currentBatch, doc...)
403 }
404}
405
406func validateGetMoreReply(reply wiremessage.Reply) error {
407 if int(reply.NumberReturned) != len(reply.Documents) {
408 return command.NewCommandResponseError("malformed OP_REPLY: NumberReturned does not match number of returned documents", nil)
409 }
410
411 if reply.ResponseFlags&wiremessage.CursorNotFound == wiremessage.CursorNotFound {
412 return command.QueryFailureError{
413 Message: "query failure - cursor not found",
414 }
415 }
416 if reply.ResponseFlags&wiremessage.QueryFailure == wiremessage.QueryFailure {
417 return command.QueryFailureError{
418 Message: "query failure",
419 Response: reply.Documents[0],
420 }
421 }
422
423 return nil
424}