blob: a9a27f1f035d6bad5af70e607d76e1d11334d5c8 [file] [log] [blame]
Don Newton379ae252019-04-01 12:17:06 -04001// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package command
8
9import (
10 "context"
11 "errors"
12
13 "github.com/mongodb/mongo-go-driver/bson/bsontype"
14 "github.com/mongodb/mongo-go-driver/mongo/readconcern"
15 "github.com/mongodb/mongo-go-driver/mongo/readpref"
16 "github.com/mongodb/mongo-go-driver/x/bsonx"
17 "github.com/mongodb/mongo-go-driver/x/mongo/driver/session"
18 "github.com/mongodb/mongo-go-driver/x/network/description"
19 "github.com/mongodb/mongo-go-driver/x/network/wiremessage"
20)
21
// CountDocuments represents the CountDocuments command.
//
// The countDocuments command counts how many documents in a collection match the given query.
// It is sent to the server as an "aggregate" command (see Encode), and the count is read from
// the "n" field of the first document in the reply's cursor.firstBatch (see Decode).
type CountDocuments struct {
	// NS identifies the target database and collection; it is validated before encoding.
	NS Namespace
	// Pipeline is sent as the aggregate "pipeline". NOTE(review): presumably the caller builds
	// the stages that yield a single document with an "n" field — confirm against callers.
	Pipeline bsonx.Arr
	// Opts are extra command elements appended after the aggregate, pipeline and cursor fields.
	Opts []bsonx.Elem
	// ReadPref is the read preference handed to the underlying Read command.
	ReadPref *readpref.ReadPref
	// ReadConcern is the read concern configured for this command.
	ReadConcern *readconcern.ReadConcern
	// Clock is the cluster clock associated with this command.
	Clock *session.ClusterClock
	// Session is the client session associated with this command, if any.
	Session *session.Client

	result int64 // count decoded by Decode
	err    error // deferred error recorded by Decode, surfaced by Result/Err
}
37
38// Encode will encode this command into a wire message for the given server description.
39func (c *CountDocuments) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
40 if err := c.NS.Validate(); err != nil {
41 return nil, err
42 }
43 command := bsonx.Doc{{"aggregate", bsonx.String(c.NS.Collection)}, {"pipeline", bsonx.Array(c.Pipeline)}}
44
45 command = append(command, bsonx.Elem{"cursor", bsonx.Document(bsonx.Doc{})})
46 command = append(command, c.Opts...)
47
48 return (&Read{DB: c.NS.DB, ReadPref: c.ReadPref, Command: command}).Encode(desc)
49}
50
51// Decode will decode the wire message using the provided server description. Errors during decoding
52// are deferred until either the Result or Err methods are called.
53func (c *CountDocuments) Decode(ctx context.Context, desc description.SelectedServer, wm wiremessage.WireMessage) *CountDocuments {
54 rdr, err := (&Read{}).Decode(desc, wm).Result()
55 if err != nil {
56 c.err = err
57 return c
58 }
59
60 cursor, err := rdr.LookupErr("cursor")
61 if err != nil || cursor.Type != bsontype.EmbeddedDocument {
62 c.err = errors.New("Invalid response from server, no 'cursor' field")
63 return c
64 }
65 batch, err := cursor.Document().LookupErr("firstBatch")
66 if err != nil || batch.Type != bsontype.Array {
67 c.err = errors.New("Invalid response from server, no 'firstBatch' field")
68 return c
69 }
70
71 elem, err := batch.Array().IndexErr(0)
72 if err != nil || elem.Value().Type != bsontype.EmbeddedDocument {
73 c.result = 0
74 return c
75 }
76
77 val, err := elem.Value().Document().LookupErr("n")
78 if err != nil {
79 c.err = errors.New("Invalid response from server, no 'n' field")
80 return c
81 }
82
83 switch val.Type {
84 case bsontype.Int32:
85 c.result = int64(val.Int32())
86 case bsontype.Int64:
87 c.result = val.Int64()
88 default:
89 c.err = errors.New("Invalid response from server, value field is not a number")
90 }
91
92 return c
93}
94
95// Result returns the result of a decoded wire message and server description.
96func (c *CountDocuments) Result() (int64, error) {
97 if c.err != nil {
98 return 0, c.err
99 }
100 return c.result, nil
101}
102
103// Err returns the error set on this command.
104func (c *CountDocuments) Err() error { return c.err }
105
106// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
107func (c *CountDocuments) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (int64, error) {
108 wm, err := c.Encode(desc)
109 if err != nil {
110 return 0, err
111 }
112
113 err = rw.WriteWireMessage(ctx, wm)
114 if err != nil {
115 return 0, err
116 }
117 wm, err = rw.ReadWireMessage(ctx)
118 if err != nil {
119 return 0, err
120 }
121 return c.Decode(ctx, desc, wm).Result()
122}