blob: 106f583241f195ce72f567911a758297789d8bda [file] [log] [blame]
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package command
8
9import (
10 "context"
11
12 "github.com/mongodb/mongo-go-driver/bson"
13 "github.com/mongodb/mongo-go-driver/mongo/readconcern"
14 "github.com/mongodb/mongo-go-driver/mongo/readpref"
15 "github.com/mongodb/mongo-go-driver/mongo/writeconcern"
16 "github.com/mongodb/mongo-go-driver/x/bsonx"
17 "github.com/mongodb/mongo-go-driver/x/mongo/driver/session"
18 "github.com/mongodb/mongo-go-driver/x/network/description"
19 "github.com/mongodb/mongo-go-driver/x/network/wiremessage"
20)
21
22// Aggregate represents the aggregate command.
23//
24// The aggregate command performs an aggregation.
25type Aggregate struct {
26 NS Namespace
27 Pipeline bsonx.Arr
28 CursorOpts []bsonx.Elem
29 Opts []bsonx.Elem
30 ReadPref *readpref.ReadPref
31 WriteConcern *writeconcern.WriteConcern
32 ReadConcern *readconcern.ReadConcern
33 Clock *session.ClusterClock
34 Session *session.Client
35
36 result bson.Raw
37 err error
38}
39
40// Encode will encode this command into a wire message for the given server description.
41func (a *Aggregate) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
42 cmd, err := a.encode(desc)
43 if err != nil {
44 return nil, err
45 }
46
47 return cmd.Encode(desc)
48}
49
50func (a *Aggregate) encode(desc description.SelectedServer) (*Read, error) {
51 if err := a.NS.Validate(); err != nil {
52 return nil, err
53 }
54
55 command := bsonx.Doc{
56 {"aggregate", bsonx.String(a.NS.Collection)},
57 {"pipeline", bsonx.Array(a.Pipeline)},
58 }
59
60 cursor := bsonx.Doc{}
61 hasOutStage := a.HasDollarOut()
62
63 for _, opt := range a.Opts {
64 switch opt.Key {
65 case "batchSize":
66 if opt.Value.Int32() == 0 && hasOutStage {
67 continue
68 }
69 cursor = append(cursor, opt)
70 default:
71 command = append(command, opt)
72 }
73 }
74 command = append(command, bsonx.Elem{"cursor", bsonx.Document(cursor)})
75
76 // add write concern because it won't be added by the Read command's Encode()
77 if desc.WireVersion.Max >= 5 && hasOutStage && a.WriteConcern != nil {
78 t, data, err := a.WriteConcern.MarshalBSONValue()
79 if err != nil {
80 return nil, err
81 }
82 var xval bsonx.Val
83 err = xval.UnmarshalBSONValue(t, data)
84 if err != nil {
85 return nil, err
86 }
87 command = append(command, bsonx.Elem{Key: "writeConcern", Value: xval})
88 }
89
90 return &Read{
91 DB: a.NS.DB,
92 Command: command,
93 ReadPref: a.ReadPref,
94 ReadConcern: a.ReadConcern,
95 Clock: a.Clock,
96 Session: a.Session,
97 }, nil
98}
99
100// HasDollarOut returns true if the Pipeline field contains a $out stage.
101func (a *Aggregate) HasDollarOut() bool {
102 if a.Pipeline == nil {
103 return false
104 }
105 if len(a.Pipeline) == 0 {
106 return false
107 }
108
109 val := a.Pipeline[len(a.Pipeline)-1]
110
111 doc, ok := val.DocumentOK()
112 if !ok || len(doc) != 1 {
113 return false
114 }
115 return doc[0].Key == "$out"
116}
117
118// Decode will decode the wire message using the provided server description. Errors during decoding
119// are deferred until either the Result or Err methods are called.
120func (a *Aggregate) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Aggregate {
121 rdr, err := (&Read{}).Decode(desc, wm).Result()
122 if err != nil {
123 a.err = err
124 return a
125 }
126
127 return a.decode(desc, rdr)
128}
129
130func (a *Aggregate) decode(desc description.SelectedServer, rdr bson.Raw) *Aggregate {
131 a.result = rdr
132 return a
133}
134
135// Result returns the result of a decoded wire message and server description.
136func (a *Aggregate) Result() (bson.Raw, error) {
137 if a.err != nil {
138 return nil, a.err
139 }
140 return a.result, nil
141}
142
143// Err returns the error set on this command.
144func (a *Aggregate) Err() error { return a.err }
145
146// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
147func (a *Aggregate) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) {
148 cmd, err := a.encode(desc)
149 if err != nil {
150 return nil, err
151 }
152
153 rdr, err := cmd.RoundTrip(ctx, desc, rw)
154 if err != nil {
155 return nil, err
156 }
157
158 return a.decode(desc, rdr).Result()
159}