blob: 2400baceac6d5d086441a2f46a2c3c60598c6ba1 [file] [log] [blame]
Don Newton379ae252019-04-01 12:17:06 -04001// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package bsonrw
8
9import (
10 "errors"
11 "fmt"
12 "io"
13 "math"
14 "strconv"
15 "sync"
16
17 "github.com/mongodb/mongo-go-driver/bson/bsontype"
18 "github.com/mongodb/mongo-go-driver/bson/primitive"
19 "github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore"
20)
21
22var _ ValueWriter = (*valueWriter)(nil)
23
24var vwPool = sync.Pool{
25 New: func() interface{} {
26 return new(valueWriter)
27 },
28}
29
30// BSONValueWriterPool is a pool for BSON ValueWriters.
31type BSONValueWriterPool struct {
32 pool sync.Pool
33}
34
35// NewBSONValueWriterPool creates a new pool for ValueWriter instances that write to BSON.
36func NewBSONValueWriterPool() *BSONValueWriterPool {
37 return &BSONValueWriterPool{
38 pool: sync.Pool{
39 New: func() interface{} {
40 return new(valueWriter)
41 },
42 },
43 }
44}
45
46// Get retrieves a BSON ValueWriter from the pool and resets it to use w as the destination.
47func (bvwp *BSONValueWriterPool) Get(w io.Writer) ValueWriter {
48 vw := bvwp.pool.Get().(*valueWriter)
49 if writer, ok := w.(*SliceWriter); ok {
50 vw.reset(*writer)
51 vw.w = writer
52 return vw
53 }
54 vw.buf = vw.buf[:0]
55 vw.w = w
56 return vw
57}
58
59// Put inserts a ValueWriter into the pool. If the ValueWriter is not a BSON ValueWriter, nothing
60// happens and ok will be false.
61func (bvwp *BSONValueWriterPool) Put(vw ValueWriter) (ok bool) {
62 bvw, ok := vw.(*valueWriter)
63 if !ok {
64 return false
65 }
66
67 if _, ok := bvw.w.(*SliceWriter); ok {
68 bvw.buf = nil
69 }
70 bvw.w = nil
71
72 bvwp.pool.Put(bvw)
73 return true
74}
75
var (
	// maxSize is the largest buffer length writeLength accepts. It is a
	// variable rather than a constant so that tests can lower it instead of
	// having to allocate a slice of roughly 2 GiB.
	maxSize = math.MaxInt32

	// errNilWriter is returned by NewBSONValueWriter when it is handed a nil io.Writer.
	errNilWriter = errors.New("cannot create a ValueWriter from a nil io.Writer")
)
81
// errMaxDocumentSizeExceeded is the error produced when a buffered document
// grows past the allowed maximum; size records the offending length in bytes.
type errMaxDocumentSizeExceeded struct {
	size int64
}

// Error implements the error interface and includes the recorded size.
func (mdse errMaxDocumentSizeExceeded) Error() string {
	const format = "document size (%d) is larger than the max int32"
	return fmt.Sprintf(format, mdse.size)
}
89
// vwMode enumerates the states a value writer can be in.
type vwMode int

// The zero value is intentionally not a valid mode; the named modes start at 1.
const (
	vwTopLevel vwMode = iota + 1
	vwDocument
	vwArray
	vwValue
	vwElement
	vwCodeWithScope
)

// String returns the human-readable name of the mode, or "UnknownMode" for
// any value that is not one of the declared constants.
func (vm vwMode) String() string {
	switch vm {
	case vwTopLevel:
		return "TopLevel"
	case vwDocument:
		return "DocumentMode"
	case vwArray:
		return "ArrayMode"
	case vwValue:
		return "ValueMode"
	case vwElement:
		return "ElementMode"
	case vwCodeWithScope:
		return "CodeWithScopeMode"
	}
	return "UnknownMode"
}
124
// vwState is a single frame on the valueWriter's mode stack.
type vwState struct {
	// mode is the writer mode this frame represents.
	mode mode
	// key is the element key written into the header when the frame is in mElement mode.
	key string
	// arrkey is an array index. In an mValue frame it is converted to a string
	// and used as the element key; in an mArray frame it is the next index to
	// hand out.
	arrkey int
	// start is the offset into the writer's buffer of the 4-byte length
	// placeholder reserved for this frame.
	start int32
}

// valueWriter builds BSON in an in-memory buffer while tracking its position
// in the document structure with a stack of frames.
type valueWriter struct {
	// w is the destination for completed top-level documents; it may be nil.
	w io.Writer
	// buf holds the bytes of the document being built.
	buf []byte

	// stack holds the frames; the entry at index frame is the current one.
	stack []vwState
	// frame is the index of the current frame within stack.
	frame int64
}
139
140func (vw *valueWriter) advanceFrame() {
141 if vw.frame+1 >= int64(len(vw.stack)) { // We need to grow the stack
142 length := len(vw.stack)
143 if length+1 >= cap(vw.stack) {
144 // double it
145 buf := make([]vwState, 2*cap(vw.stack)+1)
146 copy(buf, vw.stack)
147 vw.stack = buf
148 }
149 vw.stack = vw.stack[:length+1]
150 }
151 vw.frame++
152}
153
154func (vw *valueWriter) push(m mode) {
155 vw.advanceFrame()
156
157 // Clean the stack
158 vw.stack[vw.frame].mode = m
159 vw.stack[vw.frame].key = ""
160 vw.stack[vw.frame].arrkey = 0
161 vw.stack[vw.frame].start = 0
162
163 vw.stack[vw.frame].mode = m
164 switch m {
165 case mDocument, mArray, mCodeWithScope:
166 vw.reserveLength()
167 }
168}
169
170func (vw *valueWriter) reserveLength() {
171 vw.stack[vw.frame].start = int32(len(vw.buf))
172 vw.buf = append(vw.buf, 0x00, 0x00, 0x00, 0x00)
173}
174
175func (vw *valueWriter) pop() {
176 switch vw.stack[vw.frame].mode {
177 case mElement, mValue:
178 vw.frame--
179 case mDocument, mArray, mCodeWithScope:
180 vw.frame -= 2 // we pop twice to jump over the mElement: mDocument -> mElement -> mDocument/mTopLevel/etc...
181 }
182}
183
184// NewBSONValueWriter creates a ValueWriter that writes BSON to w.
185//
186// This ValueWriter will only write entire documents to the io.Writer and it
187// will buffer the document as it is built.
188func NewBSONValueWriter(w io.Writer) (ValueWriter, error) {
189 if w == nil {
190 return nil, errNilWriter
191 }
192 return newValueWriter(w), nil
193}
194
195func newValueWriter(w io.Writer) *valueWriter {
196 vw := new(valueWriter)
197 stack := make([]vwState, 1, 5)
198 stack[0] = vwState{mode: mTopLevel}
199 vw.w = w
200 vw.stack = stack
201
202 return vw
203}
204
205func newValueWriterFromSlice(buf []byte) *valueWriter {
206 vw := new(valueWriter)
207 stack := make([]vwState, 1, 5)
208 stack[0] = vwState{mode: mTopLevel}
209 vw.stack = stack
210 vw.buf = buf
211
212 return vw
213}
214
215func (vw *valueWriter) reset(buf []byte) {
216 if vw.stack == nil {
217 vw.stack = make([]vwState, 1, 5)
218 }
219 vw.stack = vw.stack[:1]
220 vw.stack[0] = vwState{mode: mTopLevel}
221 vw.buf = buf
222 vw.frame = 0
223 vw.w = nil
224}
225
226func (vw *valueWriter) invalidTransitionError(destination mode, name string, modes []mode) error {
227 te := TransitionError{
228 name: name,
229 current: vw.stack[vw.frame].mode,
230 destination: destination,
231 modes: modes,
232 action: "write",
233 }
234 if vw.frame != 0 {
235 te.parent = vw.stack[vw.frame-1].mode
236 }
237 return te
238}
239
240func (vw *valueWriter) writeElementHeader(t bsontype.Type, destination mode, callerName string, addmodes ...mode) error {
241 switch vw.stack[vw.frame].mode {
242 case mElement:
243 vw.buf = bsoncore.AppendHeader(vw.buf, t, vw.stack[vw.frame].key)
244 case mValue:
245 // TODO: Do this with a cache of the first 1000 or so array keys.
246 vw.buf = bsoncore.AppendHeader(vw.buf, t, strconv.Itoa(vw.stack[vw.frame].arrkey))
247 default:
248 modes := []mode{mElement, mValue}
249 if addmodes != nil {
250 modes = append(modes, addmodes...)
251 }
252 return vw.invalidTransitionError(destination, callerName, modes)
253 }
254
255 return nil
256}
257
258func (vw *valueWriter) WriteValueBytes(t bsontype.Type, b []byte) error {
259 if err := vw.writeElementHeader(t, mode(0), "WriteValueBytes"); err != nil {
260 return err
261 }
262 vw.buf = append(vw.buf, b...)
263 vw.pop()
264 return nil
265}
266
267func (vw *valueWriter) WriteArray() (ArrayWriter, error) {
268 if err := vw.writeElementHeader(bsontype.Array, mArray, "WriteArray"); err != nil {
269 return nil, err
270 }
271
272 vw.push(mArray)
273
274 return vw, nil
275}
276
277func (vw *valueWriter) WriteBinary(b []byte) error {
278 return vw.WriteBinaryWithSubtype(b, 0x00)
279}
280
281func (vw *valueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error {
282 if err := vw.writeElementHeader(bsontype.Binary, mode(0), "WriteBinaryWithSubtype"); err != nil {
283 return err
284 }
285
286 vw.buf = bsoncore.AppendBinary(vw.buf, btype, b)
287 vw.pop()
288 return nil
289}
290
291func (vw *valueWriter) WriteBoolean(b bool) error {
292 if err := vw.writeElementHeader(bsontype.Boolean, mode(0), "WriteBoolean"); err != nil {
293 return err
294 }
295
296 vw.buf = bsoncore.AppendBoolean(vw.buf, b)
297 vw.pop()
298 return nil
299}
300
301func (vw *valueWriter) WriteCodeWithScope(code string) (DocumentWriter, error) {
302 if err := vw.writeElementHeader(bsontype.CodeWithScope, mCodeWithScope, "WriteCodeWithScope"); err != nil {
303 return nil, err
304 }
305
306 // CodeWithScope is a different than other types because we need an extra
307 // frame on the stack. In the EndDocument code, we write the document
308 // length, pop, write the code with scope length, and pop. To simplify the
309 // pop code, we push a spacer frame that we'll always jump over.
310 vw.push(mCodeWithScope)
311 vw.buf = bsoncore.AppendString(vw.buf, code)
312 vw.push(mSpacer)
313 vw.push(mDocument)
314
315 return vw, nil
316}
317
318func (vw *valueWriter) WriteDBPointer(ns string, oid primitive.ObjectID) error {
319 if err := vw.writeElementHeader(bsontype.DBPointer, mode(0), "WriteDBPointer"); err != nil {
320 return err
321 }
322
323 vw.buf = bsoncore.AppendDBPointer(vw.buf, ns, oid)
324 vw.pop()
325 return nil
326}
327
328func (vw *valueWriter) WriteDateTime(dt int64) error {
329 if err := vw.writeElementHeader(bsontype.DateTime, mode(0), "WriteDateTime"); err != nil {
330 return err
331 }
332
333 vw.buf = bsoncore.AppendDateTime(vw.buf, dt)
334 vw.pop()
335 return nil
336}
337
338func (vw *valueWriter) WriteDecimal128(d128 primitive.Decimal128) error {
339 if err := vw.writeElementHeader(bsontype.Decimal128, mode(0), "WriteDecimal128"); err != nil {
340 return err
341 }
342
343 vw.buf = bsoncore.AppendDecimal128(vw.buf, d128)
344 vw.pop()
345 return nil
346}
347
348func (vw *valueWriter) WriteDouble(f float64) error {
349 if err := vw.writeElementHeader(bsontype.Double, mode(0), "WriteDouble"); err != nil {
350 return err
351 }
352
353 vw.buf = bsoncore.AppendDouble(vw.buf, f)
354 vw.pop()
355 return nil
356}
357
358func (vw *valueWriter) WriteInt32(i32 int32) error {
359 if err := vw.writeElementHeader(bsontype.Int32, mode(0), "WriteInt32"); err != nil {
360 return err
361 }
362
363 vw.buf = bsoncore.AppendInt32(vw.buf, i32)
364 vw.pop()
365 return nil
366}
367
368func (vw *valueWriter) WriteInt64(i64 int64) error {
369 if err := vw.writeElementHeader(bsontype.Int64, mode(0), "WriteInt64"); err != nil {
370 return err
371 }
372
373 vw.buf = bsoncore.AppendInt64(vw.buf, i64)
374 vw.pop()
375 return nil
376}
377
378func (vw *valueWriter) WriteJavascript(code string) error {
379 if err := vw.writeElementHeader(bsontype.JavaScript, mode(0), "WriteJavascript"); err != nil {
380 return err
381 }
382
383 vw.buf = bsoncore.AppendJavaScript(vw.buf, code)
384 vw.pop()
385 return nil
386}
387
388func (vw *valueWriter) WriteMaxKey() error {
389 if err := vw.writeElementHeader(bsontype.MaxKey, mode(0), "WriteMaxKey"); err != nil {
390 return err
391 }
392
393 vw.pop()
394 return nil
395}
396
397func (vw *valueWriter) WriteMinKey() error {
398 if err := vw.writeElementHeader(bsontype.MinKey, mode(0), "WriteMinKey"); err != nil {
399 return err
400 }
401
402 vw.pop()
403 return nil
404}
405
406func (vw *valueWriter) WriteNull() error {
407 if err := vw.writeElementHeader(bsontype.Null, mode(0), "WriteNull"); err != nil {
408 return err
409 }
410
411 vw.pop()
412 return nil
413}
414
415func (vw *valueWriter) WriteObjectID(oid primitive.ObjectID) error {
416 if err := vw.writeElementHeader(bsontype.ObjectID, mode(0), "WriteObjectID"); err != nil {
417 return err
418 }
419
420 vw.buf = bsoncore.AppendObjectID(vw.buf, oid)
421 vw.pop()
422 return nil
423}
424
425func (vw *valueWriter) WriteRegex(pattern string, options string) error {
426 if err := vw.writeElementHeader(bsontype.Regex, mode(0), "WriteRegex"); err != nil {
427 return err
428 }
429
430 vw.buf = bsoncore.AppendRegex(vw.buf, pattern, sortStringAlphebeticAscending(options))
431 vw.pop()
432 return nil
433}
434
435func (vw *valueWriter) WriteString(s string) error {
436 if err := vw.writeElementHeader(bsontype.String, mode(0), "WriteString"); err != nil {
437 return err
438 }
439
440 vw.buf = bsoncore.AppendString(vw.buf, s)
441 vw.pop()
442 return nil
443}
444
445func (vw *valueWriter) WriteDocument() (DocumentWriter, error) {
446 if vw.stack[vw.frame].mode == mTopLevel {
447 vw.reserveLength()
448 return vw, nil
449 }
450 if err := vw.writeElementHeader(bsontype.EmbeddedDocument, mDocument, "WriteDocument", mTopLevel); err != nil {
451 return nil, err
452 }
453
454 vw.push(mDocument)
455 return vw, nil
456}
457
458func (vw *valueWriter) WriteSymbol(symbol string) error {
459 if err := vw.writeElementHeader(bsontype.Symbol, mode(0), "WriteSymbol"); err != nil {
460 return err
461 }
462
463 vw.buf = bsoncore.AppendSymbol(vw.buf, symbol)
464 vw.pop()
465 return nil
466}
467
468func (vw *valueWriter) WriteTimestamp(t uint32, i uint32) error {
469 if err := vw.writeElementHeader(bsontype.Timestamp, mode(0), "WriteTimestamp"); err != nil {
470 return err
471 }
472
473 vw.buf = bsoncore.AppendTimestamp(vw.buf, t, i)
474 vw.pop()
475 return nil
476}
477
478func (vw *valueWriter) WriteUndefined() error {
479 if err := vw.writeElementHeader(bsontype.Undefined, mode(0), "WriteUndefined"); err != nil {
480 return err
481 }
482
483 vw.pop()
484 return nil
485}
486
487func (vw *valueWriter) WriteDocumentElement(key string) (ValueWriter, error) {
488 switch vw.stack[vw.frame].mode {
489 case mTopLevel, mDocument:
490 default:
491 return nil, vw.invalidTransitionError(mElement, "WriteDocumentElement", []mode{mTopLevel, mDocument})
492 }
493
494 vw.push(mElement)
495 vw.stack[vw.frame].key = key
496
497 return vw, nil
498}
499
500func (vw *valueWriter) WriteDocumentEnd() error {
501 switch vw.stack[vw.frame].mode {
502 case mTopLevel, mDocument:
503 default:
504 return fmt.Errorf("incorrect mode to end document: %s", vw.stack[vw.frame].mode)
505 }
506
507 vw.buf = append(vw.buf, 0x00)
508
509 err := vw.writeLength()
510 if err != nil {
511 return err
512 }
513
514 if vw.stack[vw.frame].mode == mTopLevel {
515 if vw.w != nil {
516 if sw, ok := vw.w.(*SliceWriter); ok {
517 *sw = vw.buf
518 } else {
519 _, err = vw.w.Write(vw.buf)
520 if err != nil {
521 return err
522 }
523 // reset buffer
524 vw.buf = vw.buf[:0]
525 }
526 }
527 }
528
529 vw.pop()
530
531 if vw.stack[vw.frame].mode == mCodeWithScope {
532 // We ignore the error here because of the gaurantee of writeLength.
533 // See the docs for writeLength for more info.
534 _ = vw.writeLength()
535 vw.pop()
536 }
537 return nil
538}
539
540func (vw *valueWriter) WriteArrayElement() (ValueWriter, error) {
541 if vw.stack[vw.frame].mode != mArray {
542 return nil, vw.invalidTransitionError(mValue, "WriteArrayElement", []mode{mArray})
543 }
544
545 arrkey := vw.stack[vw.frame].arrkey
546 vw.stack[vw.frame].arrkey++
547
548 vw.push(mValue)
549 vw.stack[vw.frame].arrkey = arrkey
550
551 return vw, nil
552}
553
554func (vw *valueWriter) WriteArrayEnd() error {
555 if vw.stack[vw.frame].mode != mArray {
556 return fmt.Errorf("incorrect mode to end array: %s", vw.stack[vw.frame].mode)
557 }
558
559 vw.buf = append(vw.buf, 0x00)
560
561 err := vw.writeLength()
562 if err != nil {
563 return err
564 }
565
566 vw.pop()
567 return nil
568}
569
570// NOTE: We assume that if we call writeLength more than once the same function
571// within the same function without altering the vw.buf that this method will
572// not return an error. If this changes ensure that the following methods are
573// updated:
574//
575// - WriteDocumentEnd
576func (vw *valueWriter) writeLength() error {
577 length := len(vw.buf)
578 if length > maxSize {
579 return errMaxDocumentSizeExceeded{size: int64(len(vw.buf))}
580 }
581 length = length - int(vw.stack[vw.frame].start)
582 start := vw.stack[vw.frame].start
583
584 vw.buf[start+0] = byte(length)
585 vw.buf[start+1] = byte(length >> 8)
586 vw.buf[start+2] = byte(length >> 16)
587 vw.buf[start+3] = byte(length >> 24)
588 return nil
589}