// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package wiremessage
8
9import (
10 "errors"
11 "fmt"
12 "strings"
13
14 "github.com/mongodb/mongo-go-driver/bson"
15 "github.com/mongodb/mongo-go-driver/mongo/writeconcern"
16 "github.com/mongodb/mongo-go-driver/x/bsonx"
17)
18
19// Query represents the OP_QUERY message of the MongoDB wire protocol.
20type Query struct {
21 MsgHeader Header
22 Flags QueryFlag
23 FullCollectionName string
24 NumberToSkip int32
25 NumberToReturn int32
26 Query bson.Raw
27 ReturnFieldsSelector bson.Raw
28
29 SkipSet bool
30 Limit *int32
31 BatchSize *int32
32}
33
// optionsMap translates the "$"-prefixed modifiers found in a legacy OP_QUERY
// document into the option names used when that query is rewritten as a find
// command document. Keys are listed alphabetically.
var optionsMap = map[string]string{
	"$comment":     "comment",
	"$hint":        "hint",
	"$max":         "max",
	"$maxScan":     "maxScan",
	"$maxTimeMS":   "maxTimeMS",
	"$min":         "min",
	"$orderby":     "sort",
	"$returnKey":   "returnKey",
	"$showDiskLoc": "showRecordId",
	"$snapshot":    "snapshot",
}
46
47// MarshalWireMessage implements the Marshaler and WireMessage interfaces.
48//
49// See AppendWireMessage for a description of the rules this method follows.
50func (q Query) MarshalWireMessage() ([]byte, error) {
51 b := make([]byte, 0, q.Len())
52 return q.AppendWireMessage(b)
53}
54
55// ValidateWireMessage implements the Validator and WireMessage interfaces.
56func (q Query) ValidateWireMessage() error {
57 if int(q.MsgHeader.MessageLength) != q.Len() {
58 return errors.New("incorrect header: message length is not correct")
59 }
60 if q.MsgHeader.OpCode != OpQuery {
61 return errors.New("incorrect header: op code is not OpQuery")
62 }
63 if strings.Index(q.FullCollectionName, ".") == -1 {
64 return errors.New("incorrect header: collection name does not contain a dot")
65 }
66 if q.Query != nil && len(q.Query) > 0 {
67 err := q.Query.Validate()
68 if err != nil {
69 return err
70 }
71 }
72
73 if q.ReturnFieldsSelector != nil && len(q.ReturnFieldsSelector) > 0 {
74 err := q.ReturnFieldsSelector.Validate()
75 if err != nil {
76 return err
77 }
78 }
79
80 return nil
81}
82
83// AppendWireMessage implements the Appender and WireMessage interfaces.
84//
85// AppendWireMessage will set the MessageLength property of the MsgHeader
86// if it is zero. It will also set the OpCode to OpQuery if the OpCode is
87// zero. If either of these properties are non-zero and not correct, this
88// method will return both the []byte with the wire message appended to it
89// and an invalid header error.
90func (q Query) AppendWireMessage(b []byte) ([]byte, error) {
91 var err error
92 err = q.MsgHeader.SetDefaults(q.Len(), OpQuery)
93
94 b = q.MsgHeader.AppendHeader(b)
95 b = appendInt32(b, int32(q.Flags))
96 b = appendCString(b, q.FullCollectionName)
97 b = appendInt32(b, q.NumberToSkip)
98 b = appendInt32(b, q.NumberToReturn)
99 b = append(b, q.Query...)
100 b = append(b, q.ReturnFieldsSelector...)
101 return b, err
102}
103
104// String implements the fmt.Stringer interface.
105func (q Query) String() string {
106 return fmt.Sprintf(
107 `OP_QUERY{MsgHeader: %s, Flags: %s, FullCollectionname: %s, NumberToSkip: %d, NumberToReturn: %d, Query: %s, ReturnFieldsSelector: %s}`,
108 q.MsgHeader, q.Flags, q.FullCollectionName, q.NumberToSkip, q.NumberToReturn, q.Query, q.ReturnFieldsSelector,
109 )
110}
111
112// Len implements the WireMessage interface.
113func (q Query) Len() int {
114 // Header + Flags + CollectionName + Null Byte + Skip + Return + Query + ReturnFieldsSelector
115 return 16 + 4 + len(q.FullCollectionName) + 1 + 4 + 4 + len(q.Query) + len(q.ReturnFieldsSelector)
116}
117
118// UnmarshalWireMessage implements the Unmarshaler interface.
119func (q *Query) UnmarshalWireMessage(b []byte) error {
120 var err error
121 q.MsgHeader, err = ReadHeader(b, 0)
122 if err != nil {
123 return err
124 }
125 if len(b) < int(q.MsgHeader.MessageLength) {
126 return Error{Type: ErrOpQuery, Message: "[]byte too small"}
127 }
128
129 q.Flags = QueryFlag(readInt32(b, 16))
130 q.FullCollectionName, err = readCString(b, 20)
131 if err != nil {
132 return err
133 }
134 pos := 20 + len(q.FullCollectionName) + 1
135 q.NumberToSkip = readInt32(b, int32(pos))
136 pos += 4
137 q.NumberToReturn = readInt32(b, int32(pos))
138 pos += 4
139
140 var size int
141 var wmerr Error
142 q.Query, size, wmerr = readDocument(b, int32(pos))
143 if wmerr.Message != "" {
144 wmerr.Type = ErrOpQuery
145 return wmerr
146 }
147 pos += size
148 if pos < len(b) {
149 q.ReturnFieldsSelector, size, wmerr = readDocument(b, int32(pos))
150 if wmerr.Message != "" {
151 wmerr.Type = ErrOpQuery
152 return wmerr
153 }
154 pos += size
155 }
156
157 return nil
158}
159
160// AcknowledgedWrite returns true if this command represents an acknowledged write
161func (q *Query) AcknowledgedWrite() bool {
162 wcElem, err := q.Query.LookupErr("writeConcern")
163 if err != nil {
164 // no wc --> ack
165 return true
166 }
167
168 return writeconcern.AcknowledgedValue(wcElem)
169}
170
171// Legacy returns true if the query represents a legacy find operation.
172func (q Query) Legacy() bool {
173 return !strings.Contains(q.FullCollectionName, "$cmd")
174}
175
176// DatabaseName returns the database name for the query.
177func (q Query) DatabaseName() string {
178 if q.Legacy() {
179 return strings.Split(q.FullCollectionName, ".")[0]
180 }
181
182 return q.FullCollectionName[:len(q.FullCollectionName)-5] // remove .$cmd
183}
184
185// CollectionName returns the collection name for the query.
186func (q Query) CollectionName() string {
187 parts := strings.Split(q.FullCollectionName, ".")
188 return parts[len(parts)-1]
189}
190
191// CommandDocument creates a BSON document representing this command.
192func (q Query) CommandDocument() (bsonx.Doc, error) {
193 if q.Legacy() {
194 return q.legacyCommandDocument()
195 }
196
197 cmd, err := bsonx.ReadDoc([]byte(q.Query))
198 if err != nil {
199 return nil, err
200 }
201
202 cmdElem := cmd[0]
203 if cmdElem.Key == "$query" {
204 cmd = cmdElem.Value.Document()
205 }
206
207 return cmd, nil
208}
209
210func (q Query) legacyCommandDocument() (bsonx.Doc, error) {
211 doc, err := bsonx.ReadDoc(q.Query)
212 if err != nil {
213 return nil, err
214 }
215
216 parts := strings.Split(q.FullCollectionName, ".")
217 collName := parts[len(parts)-1]
218 doc = append(bsonx.Doc{{"find", bsonx.String(collName)}}, doc...)
219
220 var filter bsonx.Doc
221 var queryIndex int
222 for i, elem := range doc {
223 if newKey, ok := optionsMap[elem.Key]; ok {
224 doc[i].Key = newKey
225 continue
226 }
227
228 if elem.Key == "$query" {
229 filter = elem.Value.Document()
230 } else {
231 // the element is the filter
232 filter = filter.Append(elem.Key, elem.Value)
233 }
234
235 queryIndex = i
236 }
237
238 doc = append(doc[:queryIndex], doc[queryIndex+1:]...) // remove $query
239 if len(filter) != 0 {
240 doc = doc.Append("filter", bsonx.Document(filter))
241 }
242
243 doc, err = q.convertLegacyParams(doc)
244 if err != nil {
245 return nil, err
246 }
247
248 return doc, nil
249}
250
251func (q Query) convertLegacyParams(doc bsonx.Doc) (bsonx.Doc, error) {
252 if q.ReturnFieldsSelector != nil {
253 projDoc, err := bsonx.ReadDoc(q.ReturnFieldsSelector)
254 if err != nil {
255 return nil, err
256 }
257 doc = doc.Append("projection", bsonx.Document(projDoc))
258 }
259 if q.Limit != nil {
260 limit := *q.Limit
261 if limit < 0 {
262 limit *= -1
263 doc = doc.Append("singleBatch", bsonx.Boolean(true))
264 }
265
266 doc = doc.Append("limit", bsonx.Int32(*q.Limit))
267 }
268 if q.BatchSize != nil {
269 doc = doc.Append("batchSize", bsonx.Int32(*q.BatchSize))
270 }
271 if q.SkipSet {
272 doc = doc.Append("skip", bsonx.Int32(q.NumberToSkip))
273 }
274 if q.Flags&TailableCursor > 0 {
275 doc = doc.Append("tailable", bsonx.Boolean(true))
276 }
277 if q.Flags&OplogReplay > 0 {
278 doc = doc.Append("oplogReplay", bsonx.Boolean(true))
279 }
280 if q.Flags&NoCursorTimeout > 0 {
281 doc = doc.Append("noCursorTimeout", bsonx.Boolean(true))
282 }
283 if q.Flags&AwaitData > 0 {
284 doc = doc.Append("awaitData", bsonx.Boolean(true))
285 }
286 if q.Flags&Partial > 0 {
287 doc = doc.Append("allowPartialResults", bsonx.Boolean(true))
288 }
289
290 return doc, nil
291}
292
// QueryFlag represents the flags on an OP_QUERY message.
type QueryFlag int32

// These constants represent the individual flags on an OP_QUERY message. The
// lowest bit (value 1) is not assigned a name, so TailableCursor is 2 and each
// following flag is twice the previous one.
const (
	TailableCursor QueryFlag = 1 << (iota + 1)
	SlaveOK
	OplogReplay
	NoCursorTimeout
	AwaitData
	Exhaust
	Partial
)

// String implements the fmt.Stringer interface. It returns the names of the
// flags that are set, in declaration order, separated by ", " and enclosed in
// square brackets; with no named flag set the result is "[]".
func (qf QueryFlag) String() string {
	named := [...]struct {
		flag QueryFlag
		name string
	}{
		{TailableCursor, "TailableCursor"},
		{SlaveOK, "SlaveOK"},
		{OplogReplay, "OplogReplay"},
		{NoCursorTimeout, "NoCursorTimeout"},
		{AwaitData, "AwaitData"},
		{Exhaust, "Exhaust"},
		{Partial, "Partial"},
	}

	set := make([]string, 0, len(named))
	for _, entry := range named {
		if qf&entry.flag == entry.flag {
			set = append(set, entry.name)
		}
	}

	return "[" + strings.Join(set, ", ") + "]"
}