blob: 419a78b35753723dd7584d44a03585cca2096414 [file] [log] [blame]
Don Newton379ae252019-04-01 12:17:06 -04001// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package command
8
9import (
10 "context"
11 "errors"
12
13 "github.com/mongodb/mongo-go-driver/bson"
14 "github.com/mongodb/mongo-go-driver/mongo/readconcern"
15 "github.com/mongodb/mongo-go-driver/mongo/readpref"
16 "github.com/mongodb/mongo-go-driver/x/bsonx"
17 "github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore"
18 "github.com/mongodb/mongo-go-driver/x/mongo/driver/session"
19 "github.com/mongodb/mongo-go-driver/x/network/description"
20 "github.com/mongodb/mongo-go-driver/x/network/wiremessage"
21)
22
23// Count represents the count command.
24//
25// The count command counts how many documents in a collection match the given query.
26type Count struct {
27 NS Namespace
28 Query bsonx.Doc
29 Opts []bsonx.Elem
30 ReadPref *readpref.ReadPref
31 ReadConcern *readconcern.ReadConcern
32 Clock *session.ClusterClock
33 Session *session.Client
34
35 result int64
36 err error
37}
38
39// Encode will encode this command into a wire message for the given server description.
40func (c *Count) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
41 cmd, err := c.encode(desc)
42 if err != nil {
43 return nil, err
44 }
45
46 return cmd.Encode(desc)
47}
48
49func (c *Count) encode(desc description.SelectedServer) (*Read, error) {
50 if err := c.NS.Validate(); err != nil {
51 return nil, err
52 }
53
54 command := bsonx.Doc{{"count", bsonx.String(c.NS.Collection)}, {"query", bsonx.Document(c.Query)}}
55 command = append(command, c.Opts...)
56
57 return &Read{
58 Clock: c.Clock,
59 DB: c.NS.DB,
60 ReadPref: c.ReadPref,
61 Command: command,
62 ReadConcern: c.ReadConcern,
63 Session: c.Session,
64 }, nil
65}
66
67// Decode will decode the wire message using the provided server description. Errors during decoding
68// are deferred until either the Result or Err methods are called.
69func (c *Count) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Count {
70 rdr, err := (&Read{}).Decode(desc, wm).Result()
71 if err != nil {
72 c.err = err
73 return c
74 }
75
76 return c.decode(desc, rdr)
77}
78
79func (c *Count) decode(desc description.SelectedServer, rdr bson.Raw) *Count {
80 val, err := rdr.LookupErr("n")
81 switch {
82 case err == bsoncore.ErrElementNotFound:
83 c.err = errors.New("invalid response from server, no 'n' field")
84 return c
85 case err != nil:
86 c.err = err
87 return c
88 }
89
90 switch val.Type {
91 case bson.TypeInt32:
92 c.result = int64(val.Int32())
93 case bson.TypeInt64:
94 c.result = val.Int64()
95 case bson.TypeDouble:
96 c.result = int64(val.Double())
97 default:
98 c.err = errors.New("invalid response from server, value field is not a number")
99 }
100
101 return c
102}
103
104// Result returns the result of a decoded wire message and server description.
105func (c *Count) Result() (int64, error) {
106 if c.err != nil {
107 return 0, c.err
108 }
109 return c.result, nil
110}
111
112// Err returns the error set on this command.
113func (c *Count) Err() error { return c.err }
114
115// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
116func (c *Count) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (int64, error) {
117 cmd, err := c.encode(desc)
118 if err != nil {
119 return 0, err
120 }
121
122 rdr, err := cmd.RoundTrip(ctx, desc, rw)
123 if err != nil {
124 return 0, err
125 }
126
127 return c.decode(desc, rdr).Result()
128}