blob: 04fbb56ebbab8d2a9089f8b867b3791e4c84adde [file] [log] [blame]
Don Newton379ae252019-04-01 12:17:06 -04001// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package bsonrw
8
9import (
10 "bytes"
11 "encoding/base64"
12 "fmt"
13 "github.com/mongodb/mongo-go-driver/bson/primitive"
14 "io"
15 "math"
16 "sort"
17 "strconv"
18 "strings"
19 "sync"
20 "time"
21 "unicode/utf8"
22)
23
24var ejvwPool = sync.Pool{
25 New: func() interface{} {
26 return new(extJSONValueWriter)
27 },
28}
29
// ExtJSONValueWriterPool is a pool for ExtJSON ValueWriters.
//
// Obtain instances through NewExtJSONValueWriterPool: the zero value has no
// allocation function installed in its pool, so Get would fail its type
// assertion on an empty pool.
type ExtJSONValueWriterPool struct {
	pool sync.Pool // holds *extJSONValueWriter values
}
34
35// NewExtJSONValueWriterPool creates a new pool for ValueWriter instances that write to ExtJSON.
36func NewExtJSONValueWriterPool() *ExtJSONValueWriterPool {
37 return &ExtJSONValueWriterPool{
38 pool: sync.Pool{
39 New: func() interface{} {
40 return new(extJSONValueWriter)
41 },
42 },
43 }
44}
45
46// Get retrieves a ExtJSON ValueWriter from the pool and resets it to use w as the destination.
47func (bvwp *ExtJSONValueWriterPool) Get(w io.Writer, canonical, escapeHTML bool) ValueWriter {
48 vw := bvwp.pool.Get().(*extJSONValueWriter)
49 if writer, ok := w.(*SliceWriter); ok {
50 vw.reset(*writer, canonical, escapeHTML)
51 vw.w = writer
52 return vw
53 }
54 vw.buf = vw.buf[:0]
55 vw.w = w
56 return vw
57}
58
59// Put inserts a ValueWriter into the pool. If the ValueWriter is not a ExtJSON ValueWriter, nothing
60// happens and ok will be false.
61func (bvwp *ExtJSONValueWriterPool) Put(vw ValueWriter) (ok bool) {
62 bvw, ok := vw.(*extJSONValueWriter)
63 if !ok {
64 return false
65 }
66
67 if _, ok := bvw.w.(*SliceWriter); ok {
68 bvw.buf = nil
69 }
70 bvw.w = nil
71
72 bvwp.pool.Put(bvw)
73 return true
74}
75
76type ejvwState struct {
77 mode mode
78}
79
80type extJSONValueWriter struct {
81 w io.Writer
82 buf []byte
83
84 stack []ejvwState
85 frame int64
86 canonical bool
87 escapeHTML bool
88}
89
90// NewExtJSONValueWriter creates a ValueWriter that writes Extended JSON to w.
91func NewExtJSONValueWriter(w io.Writer, canonical, escapeHTML bool) (ValueWriter, error) {
92 if w == nil {
93 return nil, errNilWriter
94 }
95
96 return newExtJSONWriter(w, canonical, escapeHTML), nil
97}
98
99func newExtJSONWriter(w io.Writer, canonical, escapeHTML bool) *extJSONValueWriter {
100 stack := make([]ejvwState, 1, 5)
101 stack[0] = ejvwState{mode: mTopLevel}
102
103 return &extJSONValueWriter{
104 w: w,
105 buf: []byte{},
106 stack: stack,
107 canonical: canonical,
108 escapeHTML: escapeHTML,
109 }
110}
111
112func newExtJSONWriterFromSlice(buf []byte, canonical, escapeHTML bool) *extJSONValueWriter {
113 stack := make([]ejvwState, 1, 5)
114 stack[0] = ejvwState{mode: mTopLevel}
115
116 return &extJSONValueWriter{
117 buf: buf,
118 stack: stack,
119 canonical: canonical,
120 escapeHTML: escapeHTML,
121 }
122}
123
124func (ejvw *extJSONValueWriter) reset(buf []byte, canonical, escapeHTML bool) {
125 if ejvw.stack == nil {
126 ejvw.stack = make([]ejvwState, 1, 5)
127 }
128
129 ejvw.stack = ejvw.stack[:1]
130 ejvw.stack[0] = ejvwState{mode: mTopLevel}
131 ejvw.canonical = canonical
132 ejvw.escapeHTML = escapeHTML
133 ejvw.frame = 0
134 ejvw.buf = buf
135 ejvw.w = nil
136}
137
138func (ejvw *extJSONValueWriter) advanceFrame() {
139 if ejvw.frame+1 >= int64(len(ejvw.stack)) { // We need to grow the stack
140 length := len(ejvw.stack)
141 if length+1 >= cap(ejvw.stack) {
142 // double it
143 buf := make([]ejvwState, 2*cap(ejvw.stack)+1)
144 copy(buf, ejvw.stack)
145 ejvw.stack = buf
146 }
147 ejvw.stack = ejvw.stack[:length+1]
148 }
149 ejvw.frame++
150}
151
152func (ejvw *extJSONValueWriter) push(m mode) {
153 ejvw.advanceFrame()
154
155 ejvw.stack[ejvw.frame].mode = m
156}
157
158func (ejvw *extJSONValueWriter) pop() {
159 switch ejvw.stack[ejvw.frame].mode {
160 case mElement, mValue:
161 ejvw.frame--
162 case mDocument, mArray, mCodeWithScope:
163 ejvw.frame -= 2 // we pop twice to jump over the mElement: mDocument -> mElement -> mDocument/mTopLevel/etc...
164 }
165}
166
167func (ejvw *extJSONValueWriter) invalidTransitionErr(destination mode, name string, modes []mode) error {
168 te := TransitionError{
169 name: name,
170 current: ejvw.stack[ejvw.frame].mode,
171 destination: destination,
172 modes: modes,
173 action: "write",
174 }
175 if ejvw.frame != 0 {
176 te.parent = ejvw.stack[ejvw.frame-1].mode
177 }
178 return te
179}
180
181func (ejvw *extJSONValueWriter) ensureElementValue(destination mode, callerName string, addmodes ...mode) error {
182 switch ejvw.stack[ejvw.frame].mode {
183 case mElement, mValue:
184 default:
185 modes := []mode{mElement, mValue}
186 if addmodes != nil {
187 modes = append(modes, addmodes...)
188 }
189 return ejvw.invalidTransitionErr(destination, callerName, modes)
190 }
191
192 return nil
193}
194
195func (ejvw *extJSONValueWriter) writeExtendedSingleValue(key string, value string, quotes bool) {
196 var s string
197 if quotes {
198 s = fmt.Sprintf(`{"$%s":"%s"}`, key, value)
199 } else {
200 s = fmt.Sprintf(`{"$%s":%s}`, key, value)
201 }
202
203 ejvw.buf = append(ejvw.buf, []byte(s)...)
204}
205
206func (ejvw *extJSONValueWriter) WriteArray() (ArrayWriter, error) {
207 if err := ejvw.ensureElementValue(mArray, "WriteArray"); err != nil {
208 return nil, err
209 }
210
211 ejvw.buf = append(ejvw.buf, '[')
212
213 ejvw.push(mArray)
214 return ejvw, nil
215}
216
217func (ejvw *extJSONValueWriter) WriteBinary(b []byte) error {
218 return ejvw.WriteBinaryWithSubtype(b, 0x00)
219}
220
221func (ejvw *extJSONValueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error {
222 if err := ejvw.ensureElementValue(mode(0), "WriteBinaryWithSubtype"); err != nil {
223 return err
224 }
225
226 var buf bytes.Buffer
227 buf.WriteString(`{"$binary":{"base64":"`)
228 buf.WriteString(base64.StdEncoding.EncodeToString(b))
229 buf.WriteString(fmt.Sprintf(`","subType":"%02x"}},`, btype))
230
231 ejvw.buf = append(ejvw.buf, buf.Bytes()...)
232
233 ejvw.pop()
234 return nil
235}
236
237func (ejvw *extJSONValueWriter) WriteBoolean(b bool) error {
238 if err := ejvw.ensureElementValue(mode(0), "WriteBoolean"); err != nil {
239 return err
240 }
241
242 ejvw.buf = append(ejvw.buf, []byte(strconv.FormatBool(b))...)
243 ejvw.buf = append(ejvw.buf, ',')
244
245 ejvw.pop()
246 return nil
247}
248
249func (ejvw *extJSONValueWriter) WriteCodeWithScope(code string) (DocumentWriter, error) {
250 if err := ejvw.ensureElementValue(mCodeWithScope, "WriteCodeWithScope"); err != nil {
251 return nil, err
252 }
253
254 var buf bytes.Buffer
255 buf.WriteString(`{"$code":`)
256 writeStringWithEscapes(code, &buf, ejvw.escapeHTML)
257 buf.WriteString(`,"$scope":{`)
258
259 ejvw.buf = append(ejvw.buf, buf.Bytes()...)
260
261 ejvw.push(mCodeWithScope)
262 return ejvw, nil
263}
264
265func (ejvw *extJSONValueWriter) WriteDBPointer(ns string, oid primitive.ObjectID) error {
266 if err := ejvw.ensureElementValue(mode(0), "WriteDBPointer"); err != nil {
267 return err
268 }
269
270 var buf bytes.Buffer
271 buf.WriteString(`{"$dbPointer":{"$ref":"`)
272 buf.WriteString(ns)
273 buf.WriteString(`","$id":{"$oid":"`)
274 buf.WriteString(oid.Hex())
275 buf.WriteString(`"}}},`)
276
277 ejvw.buf = append(ejvw.buf, buf.Bytes()...)
278
279 ejvw.pop()
280 return nil
281}
282
283func (ejvw *extJSONValueWriter) WriteDateTime(dt int64) error {
284 if err := ejvw.ensureElementValue(mode(0), "WriteDateTime"); err != nil {
285 return err
286 }
287
288 t := time.Unix(dt/1e3, dt%1e3*1e6).UTC()
289
290 if ejvw.canonical || t.Year() < 1970 || t.Year() > 9999 {
291 s := fmt.Sprintf(`{"$numberLong":"%d"}`, dt)
292 ejvw.writeExtendedSingleValue("date", s, false)
293 } else {
294 ejvw.writeExtendedSingleValue("date", t.Format(rfc3339Milli), true)
295 }
296
297 ejvw.buf = append(ejvw.buf, ',')
298
299 ejvw.pop()
300 return nil
301}
302
303func (ejvw *extJSONValueWriter) WriteDecimal128(d primitive.Decimal128) error {
304 if err := ejvw.ensureElementValue(mode(0), "WriteDecimal128"); err != nil {
305 return err
306 }
307
308 ejvw.writeExtendedSingleValue("numberDecimal", d.String(), true)
309 ejvw.buf = append(ejvw.buf, ',')
310
311 ejvw.pop()
312 return nil
313}
314
315func (ejvw *extJSONValueWriter) WriteDocument() (DocumentWriter, error) {
316 if ejvw.stack[ejvw.frame].mode == mTopLevel {
317 ejvw.buf = append(ejvw.buf, '{')
318 return ejvw, nil
319 }
320
321 if err := ejvw.ensureElementValue(mDocument, "WriteDocument", mTopLevel); err != nil {
322 return nil, err
323 }
324
325 ejvw.buf = append(ejvw.buf, '{')
326 ejvw.push(mDocument)
327 return ejvw, nil
328}
329
330func (ejvw *extJSONValueWriter) WriteDouble(f float64) error {
331 if err := ejvw.ensureElementValue(mode(0), "WriteDouble"); err != nil {
332 return err
333 }
334
335 s := formatDouble(f)
336
337 if ejvw.canonical {
338 ejvw.writeExtendedSingleValue("numberDouble", s, true)
339 } else {
340 switch s {
341 case "Infinity":
342 fallthrough
343 case "-Infinity":
344 fallthrough
345 case "NaN":
346 s = fmt.Sprintf(`{"$numberDouble":"%s"}`, s)
347 }
348 ejvw.buf = append(ejvw.buf, []byte(s)...)
349 }
350
351 ejvw.buf = append(ejvw.buf, ',')
352
353 ejvw.pop()
354 return nil
355}
356
357func (ejvw *extJSONValueWriter) WriteInt32(i int32) error {
358 if err := ejvw.ensureElementValue(mode(0), "WriteInt32"); err != nil {
359 return err
360 }
361
362 s := strconv.FormatInt(int64(i), 10)
363
364 if ejvw.canonical {
365 ejvw.writeExtendedSingleValue("numberInt", s, true)
366 } else {
367 ejvw.buf = append(ejvw.buf, []byte(s)...)
368 }
369
370 ejvw.buf = append(ejvw.buf, ',')
371
372 ejvw.pop()
373 return nil
374}
375
376func (ejvw *extJSONValueWriter) WriteInt64(i int64) error {
377 if err := ejvw.ensureElementValue(mode(0), "WriteInt64"); err != nil {
378 return err
379 }
380
381 s := strconv.FormatInt(i, 10)
382
383 if ejvw.canonical {
384 ejvw.writeExtendedSingleValue("numberLong", s, true)
385 } else {
386 ejvw.buf = append(ejvw.buf, []byte(s)...)
387 }
388
389 ejvw.buf = append(ejvw.buf, ',')
390
391 ejvw.pop()
392 return nil
393}
394
395func (ejvw *extJSONValueWriter) WriteJavascript(code string) error {
396 if err := ejvw.ensureElementValue(mode(0), "WriteJavascript"); err != nil {
397 return err
398 }
399
400 var buf bytes.Buffer
401 writeStringWithEscapes(code, &buf, ejvw.escapeHTML)
402
403 ejvw.writeExtendedSingleValue("code", buf.String(), false)
404 ejvw.buf = append(ejvw.buf, ',')
405
406 ejvw.pop()
407 return nil
408}
409
410func (ejvw *extJSONValueWriter) WriteMaxKey() error {
411 if err := ejvw.ensureElementValue(mode(0), "WriteMaxKey"); err != nil {
412 return err
413 }
414
415 ejvw.writeExtendedSingleValue("maxKey", "1", false)
416 ejvw.buf = append(ejvw.buf, ',')
417
418 ejvw.pop()
419 return nil
420}
421
422func (ejvw *extJSONValueWriter) WriteMinKey() error {
423 if err := ejvw.ensureElementValue(mode(0), "WriteMinKey"); err != nil {
424 return err
425 }
426
427 ejvw.writeExtendedSingleValue("minKey", "1", false)
428 ejvw.buf = append(ejvw.buf, ',')
429
430 ejvw.pop()
431 return nil
432}
433
434func (ejvw *extJSONValueWriter) WriteNull() error {
435 if err := ejvw.ensureElementValue(mode(0), "WriteNull"); err != nil {
436 return err
437 }
438
439 ejvw.buf = append(ejvw.buf, []byte("null")...)
440 ejvw.buf = append(ejvw.buf, ',')
441
442 ejvw.pop()
443 return nil
444}
445
446func (ejvw *extJSONValueWriter) WriteObjectID(oid primitive.ObjectID) error {
447 if err := ejvw.ensureElementValue(mode(0), "WriteObjectID"); err != nil {
448 return err
449 }
450
451 ejvw.writeExtendedSingleValue("oid", oid.Hex(), true)
452 ejvw.buf = append(ejvw.buf, ',')
453
454 ejvw.pop()
455 return nil
456}
457
458func (ejvw *extJSONValueWriter) WriteRegex(pattern string, options string) error {
459 if err := ejvw.ensureElementValue(mode(0), "WriteRegex"); err != nil {
460 return err
461 }
462
463 var buf bytes.Buffer
464 buf.WriteString(`{"$regularExpression":{"pattern":`)
465 writeStringWithEscapes(pattern, &buf, ejvw.escapeHTML)
466 buf.WriteString(`,"options":"`)
467 buf.WriteString(sortStringAlphebeticAscending(options))
468 buf.WriteString(`"}},`)
469
470 ejvw.buf = append(ejvw.buf, buf.Bytes()...)
471
472 ejvw.pop()
473 return nil
474}
475
476func (ejvw *extJSONValueWriter) WriteString(s string) error {
477 if err := ejvw.ensureElementValue(mode(0), "WriteString"); err != nil {
478 return err
479 }
480
481 var buf bytes.Buffer
482 writeStringWithEscapes(s, &buf, ejvw.escapeHTML)
483
484 ejvw.buf = append(ejvw.buf, buf.Bytes()...)
485 ejvw.buf = append(ejvw.buf, ',')
486
487 ejvw.pop()
488 return nil
489}
490
491func (ejvw *extJSONValueWriter) WriteSymbol(symbol string) error {
492 if err := ejvw.ensureElementValue(mode(0), "WriteSymbol"); err != nil {
493 return err
494 }
495
496 var buf bytes.Buffer
497 writeStringWithEscapes(symbol, &buf, ejvw.escapeHTML)
498
499 ejvw.writeExtendedSingleValue("symbol", buf.String(), false)
500 ejvw.buf = append(ejvw.buf, ',')
501
502 ejvw.pop()
503 return nil
504}
505
506func (ejvw *extJSONValueWriter) WriteTimestamp(t uint32, i uint32) error {
507 if err := ejvw.ensureElementValue(mode(0), "WriteTimestamp"); err != nil {
508 return err
509 }
510
511 var buf bytes.Buffer
512 buf.WriteString(`{"$timestamp":{"t":`)
513 buf.WriteString(strconv.FormatUint(uint64(t), 10))
514 buf.WriteString(`,"i":`)
515 buf.WriteString(strconv.FormatUint(uint64(i), 10))
516 buf.WriteString(`}},`)
517
518 ejvw.buf = append(ejvw.buf, buf.Bytes()...)
519
520 ejvw.pop()
521 return nil
522}
523
524func (ejvw *extJSONValueWriter) WriteUndefined() error {
525 if err := ejvw.ensureElementValue(mode(0), "WriteUndefined"); err != nil {
526 return err
527 }
528
529 ejvw.writeExtendedSingleValue("undefined", "true", false)
530 ejvw.buf = append(ejvw.buf, ',')
531
532 ejvw.pop()
533 return nil
534}
535
536func (ejvw *extJSONValueWriter) WriteDocumentElement(key string) (ValueWriter, error) {
537 switch ejvw.stack[ejvw.frame].mode {
538 case mDocument, mTopLevel, mCodeWithScope:
539 ejvw.buf = append(ejvw.buf, []byte(fmt.Sprintf(`"%s":`, key))...)
540 ejvw.push(mElement)
541 default:
542 return nil, ejvw.invalidTransitionErr(mElement, "WriteDocumentElement", []mode{mDocument, mTopLevel, mCodeWithScope})
543 }
544
545 return ejvw, nil
546}
547
548func (ejvw *extJSONValueWriter) WriteDocumentEnd() error {
549 switch ejvw.stack[ejvw.frame].mode {
550 case mDocument, mTopLevel, mCodeWithScope:
551 default:
552 return fmt.Errorf("incorrect mode to end document: %s", ejvw.stack[ejvw.frame].mode)
553 }
554
555 // close the document
556 if ejvw.buf[len(ejvw.buf)-1] == ',' {
557 ejvw.buf[len(ejvw.buf)-1] = '}'
558 } else {
559 ejvw.buf = append(ejvw.buf, '}')
560 }
561
562 switch ejvw.stack[ejvw.frame].mode {
563 case mCodeWithScope:
564 ejvw.buf = append(ejvw.buf, '}')
565 fallthrough
566 case mDocument:
567 ejvw.buf = append(ejvw.buf, ',')
568 case mTopLevel:
569 if ejvw.w != nil {
570 if _, err := ejvw.w.Write(ejvw.buf); err != nil {
571 return err
572 }
573 ejvw.buf = ejvw.buf[:0]
574 }
575 }
576
577 ejvw.pop()
578 return nil
579}
580
581func (ejvw *extJSONValueWriter) WriteArrayElement() (ValueWriter, error) {
582 switch ejvw.stack[ejvw.frame].mode {
583 case mArray:
584 ejvw.push(mValue)
585 default:
586 return nil, ejvw.invalidTransitionErr(mValue, "WriteArrayElement", []mode{mArray})
587 }
588
589 return ejvw, nil
590}
591
592func (ejvw *extJSONValueWriter) WriteArrayEnd() error {
593 switch ejvw.stack[ejvw.frame].mode {
594 case mArray:
595 // close the array
596 if ejvw.buf[len(ejvw.buf)-1] == ',' {
597 ejvw.buf[len(ejvw.buf)-1] = ']'
598 } else {
599 ejvw.buf = append(ejvw.buf, ']')
600 }
601
602 ejvw.buf = append(ejvw.buf, ',')
603
604 ejvw.pop()
605 default:
606 return fmt.Errorf("incorrect mode to end array: %s", ejvw.stack[ejvw.frame].mode)
607 }
608
609 return nil
610}
611
// formatDouble renders f for Extended JSON. Non-finite values become
// "Infinity", "-Infinity" or "NaN". Finite values use the shortest 'G'
// representation that round-trips. When that text has neither an exponent
// nor a decimal point, ".0" is appended so the value reads as a double.
func formatDouble(f float64) string {
	switch {
	case math.IsInf(f, 1):
		return "Infinity"
	case math.IsInf(f, -1):
		return "-Infinity"
	case math.IsNaN(f):
		return "NaN"
	}

	s := strconv.FormatFloat(f, 'G', -1, 64)
	if !strings.ContainsAny(s, "E.") {
		s += ".0"
	}
	return s
}
631
// hexChars maps a nibble value (0-15) to its lowercase hexadecimal digit. It
// is used to emit the trailing digits of \u00XX and \u202X escapes.
var hexChars = "0123456789abcdef"

// writeStringWithEscapes writes s to buf as a double-quoted JSON string.
//
// An ASCII byte b is copied through unchanged when htmlSafeSet[b] is true, or
// when escapeHTML is false and safeSet[b] is true. Both lookup tables are
// defined elsewhere in the package. Other ASCII bytes are escaped:
//   - backslash and double quote get a backslash prefix;
//   - \n, \r, \t, \b and \f use their short escapes;
//   - anything else is written as \u00XX.
//
// An invalid UTF-8 byte is replaced with \ufffd. U+2028 and U+2029 are always
// escaped. All other multi-byte runes are copied as-is. Runs of bytes that
// need no escaping are flushed with a single WriteString call.
func writeStringWithEscapes(s string, buf *bytes.Buffer, escapeHTML bool) {
	buf.WriteByte('"')
	// start marks the beginning of the pending run of bytes that can be
	// copied verbatim.
	start := 0
	for i := 0; i < len(s); {
		if b := s[i]; b < utf8.RuneSelf {
			if htmlSafeSet[b] || (!escapeHTML && safeSet[b]) {
				i++
				continue
			}
			// Flush the verbatim run before emitting this byte's escape.
			if start < i {
				buf.WriteString(s[start:i])
			}
			switch b {
			case '\\', '"':
				buf.WriteByte('\\')
				buf.WriteByte(b)
			case '\n':
				buf.WriteByte('\\')
				buf.WriteByte('n')
			case '\r':
				buf.WriteByte('\\')
				buf.WriteByte('r')
			case '\t':
				buf.WriteByte('\\')
				buf.WriteByte('t')
			case '\b':
				buf.WriteByte('\\')
				buf.WriteByte('b')
			case '\f':
				buf.WriteByte('\\')
				buf.WriteByte('f')
			default:
				// Any other unsafe ASCII byte is written as \u00XX. The intent is
				// to cover control bytes < 0x20 not handled above (i.e. other
				// than \t, \n, \r, \b and \f). When escapeHTML is set, it also
				// covers <, > and &, because they can lead to security holes when
				// user-controlled strings are rendered into JSON and served to
				// some browsers.
				// NOTE(review): the exact byte set depends on the safeSet and
				// htmlSafeSet tables, which are not visible here; confirm.
				buf.WriteString(`\u00`)
				buf.WriteByte(hexChars[b>>4])
				buf.WriteByte(hexChars[b&0xF])
			}
			i++
			start = i
			continue
		}
		c, size := utf8.DecodeRuneInString(s[i:])
		// A RuneError of width 1 means s[i] does not start a valid UTF-8
		// sequence; substitute the replacement character.
		if c == utf8.RuneError && size == 1 {
			if start < i {
				buf.WriteString(s[start:i])
			}
			buf.WriteString(`\ufffd`)
			i += size
			start = i
			continue
		}
		// U+2028 is LINE SEPARATOR.
		// U+2029 is PARAGRAPH SEPARATOR.
		// They are both technically valid characters in JSON strings,
		// but don't work in JSONP, which has to be evaluated as JavaScript,
		// and can lead to security holes there. It is valid JSON to
		// escape them, so we do so unconditionally.
		// See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion.
		if c == '\u2028' || c == '\u2029' {
			if start < i {
				buf.WriteString(s[start:i])
			}
			buf.WriteString(`\u202`)
			buf.WriteByte(hexChars[c&0xF])
			i += size
			start = i
			continue
		}
		i += size
	}
	// Flush whatever verbatim run remains after the last escape.
	if start < len(s) {
		buf.WriteString(s[start:])
	}
	buf.WriteByte('"')
}
713
// sortableString is a slice of runes ordered by code point. It implements
// sort.Interface.
type sortableString []rune

// Len reports the number of runes.
func (ss sortableString) Len() int { return len(ss) }

// Less orders runes by ascending code point.
func (ss sortableString) Less(i, j int) bool { return ss[i] < ss[j] }

// Swap exchanges the runes at positions i and j.
func (ss sortableString) Swap(i, j int) { ss[i], ss[j] = ss[j], ss[i] }

// sortStringAlphebeticAscending returns s with its runes sorted by ascending
// code point; for example, "xmi" becomes "imx".
func sortStringAlphebeticAscending(s string) string {
	runes := sortableString(s)
	sort.Sort(runes)
	return string(runes)
}