blob: b34f07d0aa1854d481fbc5701aedc6322a8d4e41 [file] [log] [blame]
Don Newton379ae252019-04-01 12:17:06 -04001// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package bsonx
8
9import (
10 "errors"
11 "fmt"
12 "reflect"
13
14 "github.com/mongodb/mongo-go-driver/bson/bsoncodec"
15 "github.com/mongodb/mongo-go-driver/bson/bsonrw"
16 "github.com/mongodb/mongo-go-driver/bson/bsontype"
17)
18
19var primitiveCodecs PrimitiveCodecs
20
21var tDocument = reflect.TypeOf((Doc)(nil))
22var tMDoc = reflect.TypeOf((MDoc)(nil))
23var tArray = reflect.TypeOf((Arr)(nil))
24var tValue = reflect.TypeOf(Val{})
25var tElementSlice = reflect.TypeOf(([]Elem)(nil))
26
27// PrimitiveCodecs is a namespace for all of the default bsoncodec.Codecs for the primitive types
28// defined in this package.
29type PrimitiveCodecs struct{}
30
31// RegisterPrimitiveCodecs will register the encode and decode methods attached to PrimitiveCodecs
32// with the provided RegistryBuilder. if rb is nil, a new empty RegistryBuilder will be created.
33func (pc PrimitiveCodecs) RegisterPrimitiveCodecs(rb *bsoncodec.RegistryBuilder) {
34 if rb == nil {
35 panic(errors.New("argument to RegisterPrimitiveCodecs must not be nil"))
36 }
37
38 rb.
39 RegisterEncoder(tDocument, bsoncodec.ValueEncoderFunc(pc.DocumentEncodeValue)).
40 RegisterEncoder(tArray, bsoncodec.ValueEncoderFunc(pc.ArrayEncodeValue)).
41 RegisterEncoder(tValue, bsoncodec.ValueEncoderFunc(pc.ValueEncodeValue)).
42 RegisterEncoder(tElementSlice, bsoncodec.ValueEncoderFunc(pc.ElementSliceEncodeValue)).
43 RegisterDecoder(tDocument, bsoncodec.ValueDecoderFunc(pc.DocumentDecodeValue)).
44 RegisterDecoder(tArray, bsoncodec.ValueDecoderFunc(pc.ArrayDecodeValue)).
45 RegisterDecoder(tValue, bsoncodec.ValueDecoderFunc(pc.ValueDecodeValue)).
46 RegisterDecoder(tElementSlice, bsoncodec.ValueDecoderFunc(pc.ElementSliceDecodeValue))
47}
48
49// DocumentEncodeValue is the ValueEncoderFunc for *Document.
50func (pc PrimitiveCodecs) DocumentEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
51 if !val.IsValid() || val.Type() != tDocument {
52 return bsoncodec.ValueEncoderError{Name: "DocumentEncodeValue", Types: []reflect.Type{tDocument}, Received: val}
53 }
54
55 if val.IsNil() {
56 return vw.WriteNull()
57 }
58
59 doc := val.Interface().(Doc)
60
61 dw, err := vw.WriteDocument()
62 if err != nil {
63 return err
64 }
65
66 return pc.encodeDocument(ec, dw, doc)
67}
68
69// DocumentDecodeValue is the ValueDecoderFunc for *Document.
70func (pc PrimitiveCodecs) DocumentDecodeValue(dctx bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
71 if !val.CanSet() || val.Type() != tDocument {
72 return bsoncodec.ValueDecoderError{Name: "DocumentDecodeValue", Types: []reflect.Type{tDocument}, Received: val}
73 }
74
75 return pc.documentDecodeValue(dctx, vr, val.Addr().Interface().(*Doc))
76}
77
78func (pc PrimitiveCodecs) documentDecodeValue(dctx bsoncodec.DecodeContext, vr bsonrw.ValueReader, doc *Doc) error {
79
80 dr, err := vr.ReadDocument()
81 if err != nil {
82 return err
83 }
84
85 return pc.decodeDocument(dctx, dr, doc)
86}
87
88// ArrayEncodeValue is the ValueEncoderFunc for *Array.
89func (pc PrimitiveCodecs) ArrayEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
90 if !val.IsValid() || val.Type() != tArray {
91 return bsoncodec.ValueEncoderError{Name: "ArrayEncodeValue", Types: []reflect.Type{tArray}, Received: val}
92 }
93
94 if val.IsNil() {
95 return vw.WriteNull()
96 }
97
98 arr := val.Interface().(Arr)
99
100 aw, err := vw.WriteArray()
101 if err != nil {
102 return err
103 }
104
105 for _, val := range arr {
106 dvw, err := aw.WriteArrayElement()
107 if err != nil {
108 return err
109 }
110
111 err = pc.encodeValue(ec, dvw, val)
112
113 if err != nil {
114 return err
115 }
116 }
117
118 return aw.WriteArrayEnd()
119}
120
121// ArrayDecodeValue is the ValueDecoderFunc for *Array.
122func (pc PrimitiveCodecs) ArrayDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
123 if !val.CanSet() || val.Type() != tArray {
124 return bsoncodec.ValueDecoderError{Name: "ArrayDecodeValue", Types: []reflect.Type{tArray}, Received: val}
125 }
126
127 ar, err := vr.ReadArray()
128 if err != nil {
129 return err
130 }
131
132 if val.IsNil() {
133 val.Set(reflect.MakeSlice(tArray, 0, 0))
134 }
135 val.SetLen(0)
136
137 for {
138 vr, err := ar.ReadValue()
139 if err == bsonrw.ErrEOA {
140 break
141 }
142 if err != nil {
143 return err
144 }
145
146 var elem Val
147 err = pc.valueDecodeValue(dc, vr, &elem)
148 if err != nil {
149 return err
150 }
151
152 val.Set(reflect.Append(val, reflect.ValueOf(elem)))
153 }
154
155 return nil
156}
157
158// ElementSliceEncodeValue is the ValueEncoderFunc for []*Element.
159func (pc PrimitiveCodecs) ElementSliceEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
160 if !val.IsValid() || val.Type() != tElementSlice {
161 return bsoncodec.ValueEncoderError{Name: "ElementSliceEncodeValue", Types: []reflect.Type{tElementSlice}, Received: val}
162 }
163
164 if val.IsNil() {
165 return vw.WriteNull()
166 }
167
168 return pc.DocumentEncodeValue(ec, vw, val.Convert(tDocument))
169}
170
171// ElementSliceDecodeValue is the ValueDecoderFunc for []*Element.
172func (pc PrimitiveCodecs) ElementSliceDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
173 if !val.CanSet() || val.Type() != tElementSlice {
174 return bsoncodec.ValueDecoderError{Name: "ElementSliceDecodeValue", Types: []reflect.Type{tElementSlice}, Received: val}
175 }
176
177 if val.IsNil() {
178 val.Set(reflect.MakeSlice(val.Type(), 0, 0))
179 }
180
181 val.SetLen(0)
182
183 dr, err := vr.ReadDocument()
184 if err != nil {
185 return err
186 }
187 elems := make([]reflect.Value, 0)
188 for {
189 key, vr, err := dr.ReadElement()
190 if err == bsonrw.ErrEOD {
191 break
192 }
193 if err != nil {
194 return err
195 }
196
197 var elem Elem
198 err = pc.elementDecodeValue(dc, vr, key, &elem)
199 if err != nil {
200 return err
201 }
202
203 elems = append(elems, reflect.ValueOf(elem))
204 }
205
206 val.Set(reflect.Append(val, elems...))
207 return nil
208}
209
210// ValueEncodeValue is the ValueEncoderFunc for *Value.
211func (pc PrimitiveCodecs) ValueEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
212 if !val.IsValid() || val.Type() != tValue {
213 return bsoncodec.ValueEncoderError{Name: "ValueEncodeValue", Types: []reflect.Type{tValue}, Received: val}
214 }
215
216 v := val.Interface().(Val)
217
218 return pc.encodeValue(ec, vw, v)
219}
220
221// ValueDecodeValue is the ValueDecoderFunc for *Value.
222func (pc PrimitiveCodecs) ValueDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
223 if !val.CanSet() || val.Type() != tValue {
224 return bsoncodec.ValueDecoderError{Name: "ValueDecodeValue", Types: []reflect.Type{tValue}, Received: val}
225 }
226
227 return pc.valueDecodeValue(dc, vr, val.Addr().Interface().(*Val))
228}
229
230// encodeDocument is a separate function that we use because CodeWithScope
231// returns us a DocumentWriter and we need to do the same logic that we would do
232// for a document but cannot use a Codec.
233func (pc PrimitiveCodecs) encodeDocument(ec bsoncodec.EncodeContext, dw bsonrw.DocumentWriter, doc Doc) error {
234 for _, elem := range doc {
235 dvw, err := dw.WriteDocumentElement(elem.Key)
236 if err != nil {
237 return err
238 }
239
240 err = pc.encodeValue(ec, dvw, elem.Value)
241
242 if err != nil {
243 return err
244 }
245 }
246
247 return dw.WriteDocumentEnd()
248}
249
250// DecodeDocument haves decoding into a Doc from a bsonrw.DocumentReader.
251func (pc PrimitiveCodecs) DecodeDocument(dctx bsoncodec.DecodeContext, dr bsonrw.DocumentReader, pdoc *Doc) error {
252 return pc.decodeDocument(dctx, dr, pdoc)
253}
254
255func (pc PrimitiveCodecs) decodeDocument(dctx bsoncodec.DecodeContext, dr bsonrw.DocumentReader, pdoc *Doc) error {
256 if *pdoc == nil {
257 *pdoc = make(Doc, 0)
258 }
259 *pdoc = (*pdoc)[:0]
260 for {
261 key, vr, err := dr.ReadElement()
262 if err == bsonrw.ErrEOD {
263 break
264 }
265 if err != nil {
266 return err
267 }
268
269 var elem Elem
270 err = pc.elementDecodeValue(dctx, vr, key, &elem)
271 if err != nil {
272 return err
273 }
274
275 *pdoc = append(*pdoc, elem)
276 }
277 return nil
278}
279
280func (pc PrimitiveCodecs) elementDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, key string, elem *Elem) error {
281 var val Val
282 switch vr.Type() {
283 case bsontype.Double:
284 f64, err := vr.ReadDouble()
285 if err != nil {
286 return err
287 }
288 val = Double(f64)
289 case bsontype.String:
290 str, err := vr.ReadString()
291 if err != nil {
292 return err
293 }
294 val = String(str)
295 case bsontype.EmbeddedDocument:
296 var embeddedDoc Doc
297 err := pc.documentDecodeValue(dc, vr, &embeddedDoc)
298 if err != nil {
299 return err
300 }
301 val = Document(embeddedDoc)
302 case bsontype.Array:
303 arr := reflect.New(tArray).Elem()
304 err := pc.ArrayDecodeValue(dc, vr, arr)
305 if err != nil {
306 return err
307 }
308 val = Array(arr.Interface().(Arr))
309 case bsontype.Binary:
310 data, subtype, err := vr.ReadBinary()
311 if err != nil {
312 return err
313 }
314 val = Binary(subtype, data)
315 case bsontype.Undefined:
316 err := vr.ReadUndefined()
317 if err != nil {
318 return err
319 }
320 val = Undefined()
321 case bsontype.ObjectID:
322 oid, err := vr.ReadObjectID()
323 if err != nil {
324 return err
325 }
326 val = ObjectID(oid)
327 case bsontype.Boolean:
328 b, err := vr.ReadBoolean()
329 if err != nil {
330 return err
331 }
332 val = Boolean(b)
333 case bsontype.DateTime:
334 dt, err := vr.ReadDateTime()
335 if err != nil {
336 return err
337 }
338 val = DateTime(dt)
339 case bsontype.Null:
340 err := vr.ReadNull()
341 if err != nil {
342 return err
343 }
344 val = Null()
345 case bsontype.Regex:
346 pattern, options, err := vr.ReadRegex()
347 if err != nil {
348 return err
349 }
350 val = Regex(pattern, options)
351 case bsontype.DBPointer:
352 ns, pointer, err := vr.ReadDBPointer()
353 if err != nil {
354 return err
355 }
356 val = DBPointer(ns, pointer)
357 case bsontype.JavaScript:
358 js, err := vr.ReadJavascript()
359 if err != nil {
360 return err
361 }
362 val = JavaScript(js)
363 case bsontype.Symbol:
364 symbol, err := vr.ReadSymbol()
365 if err != nil {
366 return err
367 }
368 val = Symbol(symbol)
369 case bsontype.CodeWithScope:
370 code, scope, err := vr.ReadCodeWithScope()
371 if err != nil {
372 return err
373 }
374 var doc Doc
375 err = pc.decodeDocument(dc, scope, &doc)
376 if err != nil {
377 return err
378 }
379 val = CodeWithScope(code, doc)
380 case bsontype.Int32:
381 i32, err := vr.ReadInt32()
382 if err != nil {
383 return err
384 }
385 val = Int32(i32)
386 case bsontype.Timestamp:
387 t, i, err := vr.ReadTimestamp()
388 if err != nil {
389 return err
390 }
391 val = Timestamp(t, i)
392 case bsontype.Int64:
393 i64, err := vr.ReadInt64()
394 if err != nil {
395 return err
396 }
397 val = Int64(i64)
398 case bsontype.Decimal128:
399 d128, err := vr.ReadDecimal128()
400 if err != nil {
401 return err
402 }
403 val = Decimal128(d128)
404 case bsontype.MinKey:
405 err := vr.ReadMinKey()
406 if err != nil {
407 return err
408 }
409 val = MinKey()
410 case bsontype.MaxKey:
411 err := vr.ReadMaxKey()
412 if err != nil {
413 return err
414 }
415 val = MaxKey()
416 default:
417 return fmt.Errorf("Cannot read unknown BSON type %s", vr.Type())
418 }
419
420 *elem = Elem{Key: key, Value: val}
421 return nil
422}
423
424// encodeValue does not validation, and the callers must perform validation on val before calling
425// this method.
426func (pc PrimitiveCodecs) encodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val Val) error {
427 var err error
428 switch val.Type() {
429 case bsontype.Double:
430 err = vw.WriteDouble(val.Double())
431 case bsontype.String:
432 err = vw.WriteString(val.StringValue())
433 case bsontype.EmbeddedDocument:
434 var encoder bsoncodec.ValueEncoder
435 encoder, err = ec.LookupEncoder(tDocument)
436 if err != nil {
437 break
438 }
439 err = encoder.EncodeValue(ec, vw, reflect.ValueOf(val.Document()))
440 case bsontype.Array:
441 var encoder bsoncodec.ValueEncoder
442 encoder, err = ec.LookupEncoder(tArray)
443 if err != nil {
444 break
445 }
446 err = encoder.EncodeValue(ec, vw, reflect.ValueOf(val.Array()))
447 case bsontype.Binary:
448 // TODO: FIX THIS (╯°□°)╯︵ ┻━┻
449 subtype, data := val.Binary()
450 err = vw.WriteBinaryWithSubtype(data, subtype)
451 case bsontype.Undefined:
452 err = vw.WriteUndefined()
453 case bsontype.ObjectID:
454 err = vw.WriteObjectID(val.ObjectID())
455 case bsontype.Boolean:
456 err = vw.WriteBoolean(val.Boolean())
457 case bsontype.DateTime:
458 err = vw.WriteDateTime(val.DateTime())
459 case bsontype.Null:
460 err = vw.WriteNull()
461 case bsontype.Regex:
462 err = vw.WriteRegex(val.Regex())
463 case bsontype.DBPointer:
464 err = vw.WriteDBPointer(val.DBPointer())
465 case bsontype.JavaScript:
466 err = vw.WriteJavascript(val.JavaScript())
467 case bsontype.Symbol:
468 err = vw.WriteSymbol(val.Symbol())
469 case bsontype.CodeWithScope:
470 code, scope := val.CodeWithScope()
471
472 var cwsw bsonrw.DocumentWriter
473 cwsw, err = vw.WriteCodeWithScope(code)
474 if err != nil {
475 break
476 }
477
478 err = pc.encodeDocument(ec, cwsw, scope)
479 case bsontype.Int32:
480 err = vw.WriteInt32(val.Int32())
481 case bsontype.Timestamp:
482 err = vw.WriteTimestamp(val.Timestamp())
483 case bsontype.Int64:
484 err = vw.WriteInt64(val.Int64())
485 case bsontype.Decimal128:
486 err = vw.WriteDecimal128(val.Decimal128())
487 case bsontype.MinKey:
488 err = vw.WriteMinKey()
489 case bsontype.MaxKey:
490 err = vw.WriteMaxKey()
491 default:
492 err = fmt.Errorf("%T is not a valid BSON type to encode", val.Type())
493 }
494
495 return err
496}
497
498func (pc PrimitiveCodecs) valueDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val *Val) error {
499 switch vr.Type() {
500 case bsontype.Double:
501 f64, err := vr.ReadDouble()
502 if err != nil {
503 return err
504 }
505 *val = Double(f64)
506 case bsontype.String:
507 str, err := vr.ReadString()
508 if err != nil {
509 return err
510 }
511 *val = String(str)
512 case bsontype.EmbeddedDocument:
513 var embeddedDoc Doc
514 err := pc.documentDecodeValue(dc, vr, &embeddedDoc)
515 if err != nil {
516 return err
517 }
518 *val = Document(embeddedDoc)
519 case bsontype.Array:
520 arr := reflect.New(tArray).Elem()
521 err := pc.ArrayDecodeValue(dc, vr, arr)
522 if err != nil {
523 return err
524 }
525 *val = Array(arr.Interface().(Arr))
526 case bsontype.Binary:
527 data, subtype, err := vr.ReadBinary()
528 if err != nil {
529 return err
530 }
531 *val = Binary(subtype, data)
532 case bsontype.Undefined:
533 err := vr.ReadUndefined()
534 if err != nil {
535 return err
536 }
537 *val = Undefined()
538 case bsontype.ObjectID:
539 oid, err := vr.ReadObjectID()
540 if err != nil {
541 return err
542 }
543 *val = ObjectID(oid)
544 case bsontype.Boolean:
545 b, err := vr.ReadBoolean()
546 if err != nil {
547 return err
548 }
549 *val = Boolean(b)
550 case bsontype.DateTime:
551 dt, err := vr.ReadDateTime()
552 if err != nil {
553 return err
554 }
555 *val = DateTime(dt)
556 case bsontype.Null:
557 err := vr.ReadNull()
558 if err != nil {
559 return err
560 }
561 *val = Null()
562 case bsontype.Regex:
563 pattern, options, err := vr.ReadRegex()
564 if err != nil {
565 return err
566 }
567 *val = Regex(pattern, options)
568 case bsontype.DBPointer:
569 ns, pointer, err := vr.ReadDBPointer()
570 if err != nil {
571 return err
572 }
573 *val = DBPointer(ns, pointer)
574 case bsontype.JavaScript:
575 js, err := vr.ReadJavascript()
576 if err != nil {
577 return err
578 }
579 *val = JavaScript(js)
580 case bsontype.Symbol:
581 symbol, err := vr.ReadSymbol()
582 if err != nil {
583 return err
584 }
585 *val = Symbol(symbol)
586 case bsontype.CodeWithScope:
587 code, scope, err := vr.ReadCodeWithScope()
588 if err != nil {
589 return err
590 }
591 var scopeDoc Doc
592 err = pc.decodeDocument(dc, scope, &scopeDoc)
593 if err != nil {
594 return err
595 }
596 *val = CodeWithScope(code, scopeDoc)
597 case bsontype.Int32:
598 i32, err := vr.ReadInt32()
599 if err != nil {
600 return err
601 }
602 *val = Int32(i32)
603 case bsontype.Timestamp:
604 t, i, err := vr.ReadTimestamp()
605 if err != nil {
606 return err
607 }
608 *val = Timestamp(t, i)
609 case bsontype.Int64:
610 i64, err := vr.ReadInt64()
611 if err != nil {
612 return err
613 }
614 *val = Int64(i64)
615 case bsontype.Decimal128:
616 d128, err := vr.ReadDecimal128()
617 if err != nil {
618 return err
619 }
620 *val = Decimal128(d128)
621 case bsontype.MinKey:
622 err := vr.ReadMinKey()
623 if err != nil {
624 return err
625 }
626 *val = MinKey()
627 case bsontype.MaxKey:
628 err := vr.ReadMaxKey()
629 if err != nil {
630 return err
631 }
632 *val = MaxKey()
633 default:
634 return fmt.Errorf("Cannot read unknown BSON type %s", vr.Type())
635 }
636
637 return nil
638}