blob: da946c3867826a2571f72c7e945bcc1d8d4994b8 [file] [log] [blame]
package driver
import (
"context"
"errors"
"fmt"
"github.com/mongodb/mongo-go-driver/bson"
"github.com/mongodb/mongo-go-driver/bson/bsontype"
"github.com/mongodb/mongo-go-driver/x/bsonx"
"github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore"
"github.com/mongodb/mongo-go-driver/x/mongo/driver/session"
"github.com/mongodb/mongo-go-driver/x/mongo/driver/topology"
"github.com/mongodb/mongo-go-driver/x/network/command"
"github.com/mongodb/mongo-go-driver/x/network/wiremessage"
)
// BatchCursor is a batch implementation of a cursor. It returns documents in entire batches instead
// of one at a time. An individual document cursor can be built on top of this batch cursor.
//
// The current batch is held as the raw bytes of its documents appended one after another; Batch
// and RequiredBytes expose those bytes and their length.
type BatchCursor struct {
// clientSession is sent with every getMore command. When it is an implicit session it is ended
// once the cursor ID becomes 0 or the cursor is closed.
clientSession *session.Client
// clock is passed to the getMore and killCursors commands.
clock *session.ClusterClock
// namespace identifies the database and collection the cursor iterates over.
namespace command.Namespace
// id is the server-side cursor ID. Next stops requesting batches once it is 0.
id int64
// err is the most recent error recorded while fetching a batch; it is returned by Err.
err error
// server supplies the connections used for getMore and kill operations. When it is nil, Next
// returns false after the first batch and Close does nothing.
server *topology.Server
// opts are forwarded unchanged as the options of each getMore command.
opts []bsonx.Elem
// currentBatch holds the concatenated bytes of the documents in the current batch.
currentBatch []byte
// firstBatch is true until the first call to Next, which reports the batch supplied at
// construction without contacting the server.
firstBatch bool
// batchNumber is not read or written by any code visible in this part of the file.
batchNumber int
// legacy server (< 3.2) fields
// batchSize is the number of documents requested per legacy getMore, before limit is applied.
batchSize int32
// limit caps the total number of documents requested from the server; 0 means no limit.
limit int32
numReturned int32 // number of docs returned by server
}
// NewBatchCursor creates a new BatchCursor from a server reply that carries a "cursor"
// sub-document.
//
// The "cursor" value must be an embedded document. Three of its keys are recognized; all others
// are ignored:
//
//   - "firstBatch": an array of documents whose bytes become the cursor's first batch.
//   - "ns": a string that is parsed and validated as the cursor's namespace.
//   - "id": the int64 cursor ID.
//
// An error is returned when "cursor" is missing or is not a document, when a recognized key holds
// a value of an unexpected type, or when the namespace does not validate. If the resulting cursor
// ID is 0 the whole result fit in the first batch, so an implicit session is ended right away.
func NewBatchCursor(result bsoncore.Document, clientSession *session.Client, clock *session.ClusterClock, server *topology.Server, opts ...bsonx.Elem) (*BatchCursor, error) {
	cursorVal, err := result.LookupErr("cursor")
	if err != nil {
		return nil, err
	}
	if cursorVal.Type != bson.TypeEmbeddedDocument {
		return nil, fmt.Errorf("cursor should be an embedded document but it is a BSON %s", cursorVal.Type)
	}

	fields, err := cursorVal.Document().Elements()
	if err != nil {
		return nil, err
	}

	bc := &BatchCursor{
		firstBatch:    true,
		opts:          opts,
		server:        server,
		clock:         clock,
		clientSession: clientSession,
	}

	for _, field := range fields {
		switch field.Key() {
		case "id":
			id, isInt64 := field.Value().Int64OK()
			if !isInt64 {
				return nil, fmt.Errorf("id should be an int64 but it is a BSON %s", field.Value().Type)
			}
			bc.id = id

		case "ns":
			nsVal := field.Value()
			if nsVal.Type != bson.TypeString {
				return nil, fmt.Errorf("namespace should be a string but it is a BSON %s", nsVal.Type)
			}
			ns := command.ParseNamespace(nsVal.StringValue())
			if err := ns.Validate(); err != nil {
				return nil, err
			}
			bc.namespace = ns

		case "firstBatch":
			batchVal := field.Value()
			arr, isArray := batchVal.ArrayOK()
			if !isArray {
				return nil, fmt.Errorf("firstBatch should be an array but it is a BSON %s", batchVal.Type)
			}
			docs, err := arr.Values()
			if err != nil {
				return nil, err
			}
			// Store the batch as the documents' bytes laid out back to back.
			for _, doc := range docs {
				if doc.Type != bsontype.EmbeddedDocument {
					return nil, fmt.Errorf("element of cursor batch is not a document, but at %s", doc.Type)
				}
				bc.currentBatch = append(bc.currentBatch, doc.Data...)
			}
		}
	}

	// close session if everything fits in first batch
	if bc.id == 0 {
		bc.closeImplicitSession()
	}
	return bc, nil
}
// NewEmptyBatchCursor returns a zero-valued batch cursor: it has no server, no cursor ID and no
// documents.
func NewEmptyBatchCursor() *BatchCursor {
	return new(BatchCursor)
}
// NewLegacyBatchCursor creates a new BatchCursor for server versions 3.0 and below from the
// provided parameters.
//
// batch holds the documents of the initial reply. All of them count towards the number of
// documents returned by the server, but when limit is non-zero and smaller than len(batch) only
// the first limit documents are kept as the first batch. The returned error is always nil.
//
// TODO(GODRIVER-617): The batch parameter here should be []bsoncore.Document. Change it to this
// once we have the new wiremessage package that uses bsoncore instead of bson.
func NewLegacyBatchCursor(ns command.Namespace, cursorID int64, batch []bson.Raw, limit int32, batchSize int32, server *topology.Server) (*BatchCursor, error) {
	received := int32(len(batch))

	// Keep the whole reply unless a non-zero limit asks for fewer documents.
	keep := received
	if limit != 0 && keep > limit {
		keep = limit
	}

	bc := &BatchCursor{
		firstBatch:  true,
		numReturned: received,
		batchSize:   batchSize,
		limit:       limit,
		namespace:   ns,
		server:      server,
		id:          cursorID,
	}

	kept := batch[:keep]
	for i := range kept {
		bc.currentBatch = append(bc.currentBatch, kept[i]...)
	}
	return bc, nil
}
// ID returns the server-side cursor ID currently held by this batch cursor.
func (bc *BatchCursor) ID() int64 { return bc.id }
// Next indicates if there is another batch available. Returning false does not necessarily indicate
// that the cursor is closed. This method will return false when an empty batch is returned.
//
// The first call reports the batch the cursor was constructed with and does not contact the
// server. Later calls fetch a new batch, unless the cursor ID is 0 or the cursor has no server, in
// which case they return false. A nil ctx is replaced with context.Background().
//
// If Next returns true, there is a valid batch of documents available. If Next returns false, there
// is not a valid batch of documents available.
func (bc *BatchCursor) Next(ctx context.Context) bool {
	if ctx == nil {
		ctx = context.Background()
	}

	switch {
	case bc.firstBatch:
		bc.firstBatch = false
		return true
	case bc.id == 0 || bc.server == nil:
		return false
	}

	if !bc.legacy() {
		bc.getMore(ctx)
	} else {
		bc.legacyGetMore(ctx)
	}

	// Both fetchers leave the batch empty on failure, so emptiness doubles as the result.
	return len(bc.currentBatch) != 0
}
// Batch will append the current batch of documents to dst and return the extended slice.
// RequiredBytes can be called to determine the length of the current batch of documents.
//
// If there is no batch available, dst is returned with nothing appended.
func (bc *BatchCursor) Batch(dst []byte) []byte {
	return append(dst, bc.currentBatch...)
}
// RequiredBytes returns the number of bytes required to hold the current batch.
func (bc *BatchCursor) RequiredBytes() int {
	return len(bc.currentBatch)
}
// Err returns the latest error encountered while fetching a batch, or nil if there was none.
func (bc *BatchCursor) Err() error {
	return bc.err
}
// Close closes this batch cursor.
//
// A nil ctx is replaced with context.Background(). Close does nothing when the cursor has no
// server. On legacy servers the cursor is killed through legacyKillCursor. Otherwise a killCursors
// command is sent for the cursor ID and, on success, the ID is reset to 0. An implicit session is
// ended in every non-legacy case, whether or not the command succeeds.
//
// When the cursor ID is already 0 there is no server-side cursor left to kill, so no connection is
// checked out and no command is sent; Close then only ends the implicit session and returns nil.
//
// The returned error is the failure to obtain a connection, the command error, or the error from
// returning the connection, in that order of precedence.
func (bc *BatchCursor) Close(ctx context.Context) error {
	if ctx == nil {
		ctx = context.Background()
	}

	if bc.server == nil {
		return nil
	}

	if bc.legacy() {
		return bc.legacyKillCursor(ctx)
	}

	// Whatever happens below, an implicit session must not outlive the cursor.
	defer bc.closeImplicitSession()

	// An exhausted cursor needs no round trip; sending killCursors for ID 0 would only waste a
	// connection checkout and could report an unrelated connection error to the caller.
	if bc.id == 0 {
		return nil
	}

	conn, err := bc.server.Connection(ctx)
	if err != nil {
		return err
	}

	_, err = (&command.KillCursors{
		Clock: bc.clock,
		NS:    bc.namespace,
		IDs:   []int64{bc.id},
	}).RoundTrip(ctx, bc.server.SelectedDescription(), conn)
	if err != nil {
		_ = conn.Close() // The command response error is more important here
		return err
	}

	bc.id = 0
	return conn.Close()
}
// closeImplicitSession ends the cursor's client session, but only if one is attached and it is an
// implicit session.
func (bc *BatchCursor) closeImplicitSession() {
	sess := bc.clientSession
	if sess == nil || sess.SessionType != session.Implicit {
		return
	}
	sess.EndSession()
}
// clearBatch truncates the current batch to zero length, keeping its backing array so the next
// batch can reuse the capacity.
func (bc *BatchCursor) clearBatch() { bc.currentBatch = bc.currentBatch[:0] }
// getMore replaces the current batch with the next one obtained through a getMore command.
//
// The current batch is cleared first. Nothing else happens when the cursor ID is 0. Any failure is
// recorded in bc.err and leaves the batch empty. The cursor ID is updated from the reply, and an
// implicit session is ended as soon as that ID is 0.
func (bc *BatchCursor) getMore(ctx context.Context) {
	bc.clearBatch()
	if bc.id == 0 {
		return
	}

	conn, err := bc.server.Connection(ctx)
	if err != nil {
		bc.err = err
		return
	}

	cmd := &command.GetMore{
		Clock:   bc.clock,
		ID:      bc.id,
		NS:      bc.namespace,
		Opts:    bc.opts,
		Session: bc.clientSession,
	}
	reply, err := cmd.RoundTrip(ctx, bc.server.SelectedDescription(), conn)
	if err != nil {
		_ = conn.Close() // The command response error is more important here
		bc.err = err
		return
	}
	if err = conn.Close(); err != nil {
		bc.err = err
		return
	}

	// Pick up the cursor ID reported by the server.
	idVal, err := reply.LookupErr("cursor", "id")
	if err != nil {
		bc.err = err
		return
	}
	nextID, isInt64 := idVal.Int64OK()
	bc.id = nextID
	if !isInt64 {
		bc.err = fmt.Errorf("BSON Type %s is not %s", idVal.Type, bson.TypeInt64)
		return
	}

	// if this is the last getMore, close the session
	if bc.id == 0 {
		bc.closeImplicitSession()
	}

	batchVal, err := reply.LookupErr("cursor", "nextBatch")
	if err != nil {
		bc.err = err
		return
	}
	var (
		arr     bson.Raw
		isArray bool
	)
	arr, isArray = batchVal.ArrayOK()
	if !isArray {
		bc.err = fmt.Errorf("BSON Type %s is not %s", batchVal.Type, bson.TypeArray)
		return
	}
	docs, err := arr.Values()
	if err != nil {
		bc.err = err
		return
	}

	for _, doc := range docs {
		if doc.Type != bsontype.EmbeddedDocument {
			bc.err = fmt.Errorf("element of cursor batch is not a document, but at %s", doc.Type)
			bc.clearBatch() // don't return a batch on error
			return
		}
		bc.currentBatch = append(bc.currentBatch, doc.Value...)
	}
}
// legacy reports whether the cursor's server has to be driven through the legacy code paths
// (legacyGetMore and legacyKillCursor): that is the case when the server description carries no
// wire version or its maximum wire version is below 4.
//
// The description is read exactly once so that the nil check and the comparison look at the same
// value; reading it twice could dereference a nil wire version if the description changed between
// the two reads.
func (bc *BatchCursor) legacy() bool {
	wireVersion := bc.server.Description().WireVersion
	return wireVersion == nil || wireVersion.Max < 4
}
// legacyKillCursor kills the cursor on a legacy server by writing an OP_KILL_CURSORS message, to
// which the server sends no reply. On success the cursor ID is reset to 0 and the current batch is
// cleared.
//
// When the cursor ID is already 0 there is nothing to kill: no connection is checked out and no
// message is written, but the batch is still cleared so the outcome matches a successful kill.
// This matters for legacyGetMore, which closes the cursor once its limit is reached even if the
// server has already reported ID 0; without the shortcut a connection failure at that point would
// discard a perfectly valid final batch.
//
// The returned error is the failure to obtain a connection, to write the message, or to return
// the connection.
func (bc *BatchCursor) legacyKillCursor(ctx context.Context) error {
	if bc.id == 0 {
		bc.clearBatch()
		return nil
	}

	conn, err := bc.server.Connection(ctx)
	if err != nil {
		return err
	}

	kc := wiremessage.KillCursors{
		NumberOfCursorIDs: 1,
		CursorIDs:         []int64{bc.id},
		CollectionName:    bc.namespace.Collection,
		DatabaseName:      bc.namespace.DB,
	}

	err = conn.WriteWireMessage(ctx, kc)
	if err != nil {
		_ = conn.Close() // the write error is the one worth reporting
		return err
	}

	err = conn.Close() // no reply from OP_KILL_CURSORS
	if err != nil {
		return err
	}

	bc.id = 0
	bc.clearBatch()
	return nil
}
// legacyGetMore replaces the current batch with the next one obtained from a legacy server by
// exchanging an OP_GET_MORE message for an OP_REPLY.
//
// The current batch is cleared first. Nothing else happens when the cursor ID is 0. The number of
// documents requested is the batch size, reduced when a non-zero limit would otherwise be
// exceeded. Any failure is recorded in bc.err and leaves the batch empty. Once the limit has been
// reached the cursor is closed before the reply's documents are stored.
func (bc *BatchCursor) legacyGetMore(ctx context.Context) {
	bc.clearBatch()
	if bc.id == 0 {
		return
	}

	conn, err := bc.server.Connection(ctx)
	if err != nil {
		bc.err = err
		return
	}

	// Ask for a full batch unless that would overshoot the limit.
	want := bc.batchSize
	if bc.limit != 0 && bc.numReturned+bc.batchSize > bc.limit {
		want = bc.limit - bc.numReturned
	}

	request := wiremessage.GetMore{
		FullCollectionName: bc.namespace.DB + "." + bc.namespace.Collection,
		CursorID:           bc.id,
		NumberToReturn:     want,
	}
	if err = conn.WriteWireMessage(ctx, request); err != nil {
		_ = conn.Close() // the write error is the one worth reporting
		bc.err = err
		return
	}

	message, err := conn.ReadWireMessage(ctx)
	if err != nil {
		_ = conn.Close() // the read error is the one worth reporting
		bc.err = err
		return
	}
	if err = conn.Close(); err != nil {
		bc.err = err
		return
	}

	reply, isReply := message.(wiremessage.Reply)
	if !isReply {
		bc.err = errors.New("did not receive OP_REPLY response")
		return
	}
	if err = validateGetMoreReply(reply); err != nil {
		bc.err = err
		return
	}

	bc.id = reply.CursorID
	bc.numReturned += reply.NumberReturned

	// Close must run before the documents are stored: closing a legacy cursor clears the batch.
	if bc.limit != 0 && bc.numReturned >= bc.limit {
		if err = bc.Close(ctx); err != nil {
			bc.err = err
			return
		}
	}

	for i := range reply.Documents {
		bc.currentBatch = append(bc.currentBatch, reply.Documents[i]...)
	}
}
// validateGetMoreReply checks an OP_REPLY received in response to a legacy getMore and returns
// nil when it is usable.
//
// It returns a command response error when NumberReturned disagrees with the number of documents
// actually present, and a command.QueryFailureError when the CursorNotFound or QueryFailure flag
// is set. For QueryFailure the first document of the reply is attached as the error's Response
// when there is one; a reply that sets the flag but carries no documents yields the error without
// a Response instead of indexing into an empty slice.
func validateGetMoreReply(reply wiremessage.Reply) error {
	if int(reply.NumberReturned) != len(reply.Documents) {
		return command.NewCommandResponseError("malformed OP_REPLY: NumberReturned does not match number of returned documents", nil)
	}

	if reply.ResponseFlags&wiremessage.CursorNotFound == wiremessage.CursorNotFound {
		return command.QueryFailureError{
			Message: "query failure - cursor not found",
		}
	}

	if reply.ResponseFlags&wiremessage.QueryFailure == wiremessage.QueryFailure {
		failure := command.QueryFailureError{
			Message: "query failure",
		}
		// The count check above still passes when both sides are zero, so guard the index.
		if len(reply.Documents) > 0 {
			failure.Response = reply.Documents[0]
		}
		return failure
	}

	return nil
}