blob: 9757d9bcb1f6fba2c244cafe3c22b03014d71cd9 [file] [log] [blame]
Don Newton379ae252019-04-01 12:17:06 -04001// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package mongo
8
9import (
10 "context"
11 "errors"
12 "fmt"
13 "net"
14 "reflect"
15 "strings"
16
17 "github.com/mongodb/mongo-go-driver/mongo/options"
18 "github.com/mongodb/mongo-go-driver/x/bsonx"
19
20 "github.com/mongodb/mongo-go-driver/bson"
21 "github.com/mongodb/mongo-go-driver/bson/bsoncodec"
22 "github.com/mongodb/mongo-go-driver/bson/bsontype"
23 "github.com/mongodb/mongo-go-driver/bson/primitive"
24)
25
// Dialer abstracts how network connections are established. Any type that
// provides a DialContext method with this signature (for example *net.Dialer)
// can be used wherever a Dialer is expected.
type Dialer interface {
	DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
30
// BSONAppender is implemented by types that can marshal a value to BSON and
// append the encoded bytes to a caller-supplied []byte, returning the extended
// slice. An implementation may return a non-nil error together with a non-nil
// []byte, and may leave incomplete BSON in that slice, so the returned bytes
// should not be trusted when the error is non-nil.
type BSONAppender interface {
	AppendBSON([]byte, interface{}) ([]byte, error)
}
38
// BSONAppenderFunc adapts an ordinary function with the AppendBSON signature
// so that it can be used anywhere a BSONAppender is expected.
type BSONAppenderFunc func([]byte, interface{}) ([]byte, error)

// AppendBSON implements the BSONAppender interface by invoking the underlying
// function with the same arguments and returning its results unchanged.
func (fn BSONAppenderFunc) AppendBSON(dst []byte, val interface{}) ([]byte, error) {
	return fn(dst, val)
}
48
// MarshalError is returned when attempting to transform a value into a BSON
// document fails. Value holds the value that could not be transformed and Err
// holds the underlying marshaling error.
type MarshalError struct {
	Value interface{}
	Err   error
}

// Error implements the error interface. The message names the Go type of
// Value and includes the underlying error so the root cause is not lost.
func (me MarshalError) Error() string {
	return fmt.Sprintf("cannot transform type %s to a BSON Document: %v", reflect.TypeOf(me.Value), me.Err)
}

// Unwrap returns the underlying marshaling error so callers can inspect it
// with errors.Is and errors.As.
func (me MarshalError) Unwrap() error {
	return me.Err
}
60
61// Pipeline is a type that makes creating aggregation pipelines easier. It is a
62// helper and is intended for serializing to BSON.
63//
64// Example usage:
65//
66// mongo.Pipeline{
67// {{"$group", bson.D{{"_id", "$state"}, {"totalPop", bson.D{{"$sum", "$pop"}}}}}},
68// {{"$match", bson.D{{"totalPop", bson.D{{"$gte", 10*1000*1000}}}}}},
69// }
70//
71type Pipeline []bson.D
72
73// transformAndEnsureID is a hack that makes it easy to get a RawValue as the _id value. This will
74// be removed when we switch from using bsonx to bsoncore for the driver package.
75func transformAndEnsureID(registry *bsoncodec.Registry, val interface{}) (bsonx.Doc, interface{}, error) {
76 // TODO: performance is going to be pretty bad for bsonx.Doc here since we turn it into a []byte
77 // only to turn it back into a bsonx.Doc. We can fix this post beta1 when we refactor the driver
78 // package to use bsoncore.Document instead of bsonx.Doc.
79 if registry == nil {
80 registry = bson.NewRegistryBuilder().Build()
81 }
82 switch tt := val.(type) {
83 case nil:
84 return nil, nil, ErrNilDocument
85 case bsonx.Doc:
86 val = tt.Copy()
87 case []byte:
88 // Slight optimization so we'll just use MarshalBSON and not go through the codec machinery.
89 val = bson.Raw(tt)
90 }
91
92 // TODO(skriptble): Use a pool of these instead.
93 buf := make([]byte, 0, 256)
94 b, err := bson.MarshalAppendWithRegistry(registry, buf, val)
95 if err != nil {
96 return nil, nil, MarshalError{Value: val, Err: err}
97 }
98
99 d, err := bsonx.ReadDoc(b)
100 if err != nil {
101 return nil, nil, err
102 }
103
104 var id interface{}
105
106 idx := d.IndexOf("_id")
107 var idElem bsonx.Elem
108 switch idx {
109 case -1:
110 idElem = bsonx.Elem{"_id", bsonx.ObjectID(primitive.NewObjectID())}
111 d = append(d, bsonx.Elem{})
112 copy(d[1:], d)
113 d[0] = idElem
114 default:
115 idElem = d[idx]
116 copy(d[1:idx+1], d[0:idx])
117 d[0] = idElem
118 }
119
120 t, data, err := idElem.Value.MarshalAppendBSONValue(buf[:0])
121 if err != nil {
122 return nil, nil, err
123 }
124
125 err = bson.RawValue{Type: t, Value: data}.UnmarshalWithRegistry(registry, &id)
126 if err != nil {
127 return nil, nil, err
128 }
129
130 return d, id, nil
131}
132
133func transformDocument(registry *bsoncodec.Registry, val interface{}) (bsonx.Doc, error) {
134 if registry == nil {
135 registry = bson.NewRegistryBuilder().Build()
136 }
137 if val == nil {
138 return nil, ErrNilDocument
139 }
140 if doc, ok := val.(bsonx.Doc); ok {
141 return doc.Copy(), nil
142 }
143 if bs, ok := val.([]byte); ok {
144 // Slight optimization so we'll just use MarshalBSON and not go through the codec machinery.
145 val = bson.Raw(bs)
146 }
147
148 // TODO(skriptble): Use a pool of these instead.
149 buf := make([]byte, 0, 256)
150 b, err := bson.MarshalAppendWithRegistry(registry, buf[:0], val)
151 if err != nil {
152 return nil, MarshalError{Value: val, Err: err}
153 }
154 return bsonx.ReadDoc(b)
155}
156
157func ensureID(d bsonx.Doc) (bsonx.Doc, interface{}) {
158 var id interface{}
159
160 elem, err := d.LookupElementErr("_id")
161 switch err.(type) {
162 case nil:
163 id = elem
164 default:
165 oid := primitive.NewObjectID()
166 d = append(d, bsonx.Elem{"_id", bsonx.ObjectID(oid)})
167 id = oid
168 }
169 return d, id
170}
171
172func ensureDollarKey(doc bsonx.Doc) error {
173 if len(doc) == 0 {
174 return errors.New("update document must have at least one element")
175 }
176 if !strings.HasPrefix(doc[0].Key, "$") {
177 return errors.New("update document must contain key beginning with '$'")
178 }
179 return nil
180}
181
182func transformAggregatePipeline(registry *bsoncodec.Registry, pipeline interface{}) (bsonx.Arr, error) {
183 pipelineArr := bsonx.Arr{}
184 switch t := pipeline.(type) {
185 case bsoncodec.ValueMarshaler:
186 btype, val, err := t.MarshalBSONValue()
187 if err != nil {
188 return nil, err
189 }
190 if btype != bsontype.Array {
191 return nil, fmt.Errorf("ValueMarshaler returned a %v, but was expecting %v", btype, bsontype.Array)
192 }
193 err = pipelineArr.UnmarshalBSONValue(btype, val)
194 if err != nil {
195 return nil, err
196 }
197 default:
198 val := reflect.ValueOf(t)
199 if !val.IsValid() || (val.Kind() != reflect.Slice && val.Kind() != reflect.Array) {
200 return nil, fmt.Errorf("can only transform slices and arrays into aggregation pipelines, but got %v", val.Kind())
201 }
202 for idx := 0; idx < val.Len(); idx++ {
203 elem, err := transformDocument(registry, val.Index(idx).Interface())
204 if err != nil {
205 return nil, err
206 }
207 pipelineArr = append(pipelineArr, bsonx.Document(elem))
208 }
209 }
210
211 return pipelineArr, nil
212}
213
214// Build the aggregation pipeline for the CountDocument command.
215func countDocumentsAggregatePipeline(registry *bsoncodec.Registry, filter interface{}, opts *options.CountOptions) (bsonx.Arr, error) {
216 pipeline := bsonx.Arr{}
217 filterDoc, err := transformDocument(registry, filter)
218
219 if err != nil {
220 return nil, err
221 }
222 pipeline = append(pipeline, bsonx.Document(bsonx.Doc{{"$match", bsonx.Document(filterDoc)}}))
223
224 if opts != nil {
225 if opts.Skip != nil {
226 pipeline = append(pipeline, bsonx.Document(bsonx.Doc{{"$skip", bsonx.Int64(*opts.Skip)}}))
227 }
228 if opts.Limit != nil {
229 pipeline = append(pipeline, bsonx.Document(bsonx.Doc{{"$limit", bsonx.Int64(*opts.Limit)}}))
230 }
231 }
232
233 pipeline = append(pipeline, bsonx.Document(bsonx.Doc{
234 {"$group", bsonx.Document(bsonx.Doc{
235 {"_id", bsonx.Null()},
236 {"n", bsonx.Document(bsonx.Doc{{"$sum", bsonx.Int32(1)}})},
237 })},
238 },
239 ))
240
241 return pipeline, nil
242}