blob: 34f3b4cc5c67180e9b4a1fa6d3d702570072a31a [file] [log] [blame]
Don Newton379ae252019-04-01 12:17:06 -04001// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package bsoncodec
8
9import (
10 "errors"
11 "fmt"
12 "reflect"
13 "sync"
14
15 "github.com/mongodb/mongo-go-driver/bson/bsonrw"
16 "github.com/mongodb/mongo-go-driver/bson/bsontype"
17)
18
19var defaultStructCodec = &StructCodec{
20 cache: make(map[reflect.Type]*structDescription),
21 parser: DefaultStructTagParser,
22}
23
// Zeroer allows custom types to report whether they are in a zero state.
//
// When an omitempty field is encoded and the field's encoder does not
// implement CodecZeroer, StructCodec's default zero check calls IsZero on
// values that implement Zeroer. Nil pointers are never asked; they are treated
// as plain nil pointers. Struct values that do not implement Zeroer, or whose
// IsZero returns false, are considered non-zero and are therefore encoded.
type Zeroer interface {
	// IsZero reports whether the receiver is in its zero state.
	IsZero() bool
}
30
// StructCodec is the Codec used for struct values.
//
// It caches one structDescription per struct type so that repeated encodes
// and decodes of the same type reuse the field/tag analysis. A StructCodec
// contains a sync.RWMutex and must not be copied after first use.
type StructCodec struct {
	cache  map[reflect.Type]*structDescription // descriptions built by describeStruct, keyed by struct type
	l      sync.RWMutex                        // guards cache
	parser StructTagParser                     // interprets the struct tags of each field
}
37
38var _ ValueEncoder = &StructCodec{}
39var _ ValueDecoder = &StructCodec{}
40
41// NewStructCodec returns a StructCodec that uses p for struct tag parsing.
42func NewStructCodec(p StructTagParser) (*StructCodec, error) {
43 if p == nil {
44 return nil, errors.New("a StructTagParser must be provided to NewStructCodec")
45 }
46
47 return &StructCodec{
48 cache: make(map[reflect.Type]*structDescription),
49 parser: p,
50 }, nil
51}
52
53// EncodeValue handles encoding generic struct types.
54func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
55 if !val.IsValid() || val.Kind() != reflect.Struct {
56 return ValueEncoderError{Name: "StructCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
57 }
58
59 sd, err := sc.describeStruct(r.Registry, val.Type())
60 if err != nil {
61 return err
62 }
63
64 dw, err := vw.WriteDocument()
65 if err != nil {
66 return err
67 }
68 var rv reflect.Value
69 for _, desc := range sd.fl {
70 if desc.inline == nil {
71 rv = val.Field(desc.idx)
72 } else {
73 rv = val.FieldByIndex(desc.inline)
74 }
75
76 if desc.encoder == nil {
77 return ErrNoEncoder{Type: rv.Type()}
78 }
79
80 encoder := desc.encoder
81
82 iszero := sc.isZero
83 if iz, ok := encoder.(CodecZeroer); ok {
84 iszero = iz.IsTypeZero
85 }
86
87 if desc.omitEmpty && iszero(rv.Interface()) {
88 continue
89 }
90
91 vw2, err := dw.WriteDocumentElement(desc.name)
92 if err != nil {
93 return err
94 }
95
96 ectx := EncodeContext{Registry: r.Registry, MinSize: desc.minSize}
97 err = encoder.EncodeValue(ectx, vw2, rv)
98 if err != nil {
99 return err
100 }
101 }
102
103 if sd.inlineMap >= 0 {
104 rv := val.Field(sd.inlineMap)
105 collisionFn := func(key string) bool {
106 _, exists := sd.fm[key]
107 return exists
108 }
109
110 return defaultValueEncoders.mapEncodeValue(r, dw, rv, collisionFn)
111 }
112
113 return dw.WriteDocumentEnd()
114}
115
116// DecodeValue implements the Codec interface.
117func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
118 if !val.CanSet() || val.Kind() != reflect.Struct {
119 return ValueDecoderError{Name: "StructCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
120 }
121
122 switch vr.Type() {
123 case bsontype.Type(0), bsontype.EmbeddedDocument:
124 default:
125 return fmt.Errorf("cannot decode %v into a %s", vr.Type(), val.Type())
126 }
127
128 sd, err := sc.describeStruct(r.Registry, val.Type())
129 if err != nil {
130 return err
131 }
132
133 var decoder ValueDecoder
134 var inlineMap reflect.Value
135 if sd.inlineMap >= 0 {
136 inlineMap = val.Field(sd.inlineMap)
137 if inlineMap.IsNil() {
138 inlineMap.Set(reflect.MakeMap(inlineMap.Type()))
139 }
140 decoder, err = r.LookupDecoder(inlineMap.Type().Elem())
141 if err != nil {
142 return err
143 }
144 }
145
146 dr, err := vr.ReadDocument()
147 if err != nil {
148 return err
149 }
150
151 for {
152 name, vr, err := dr.ReadElement()
153 if err == bsonrw.ErrEOD {
154 break
155 }
156 if err != nil {
157 return err
158 }
159
160 fd, exists := sd.fm[name]
161 if !exists {
162 if sd.inlineMap < 0 {
163 // The encoding/json package requires a flag to return on error for non-existent fields.
164 // This functionality seems appropriate for the struct codec.
165 err = vr.Skip()
166 if err != nil {
167 return err
168 }
169 continue
170 }
171
172 elem := reflect.New(inlineMap.Type().Elem()).Elem()
173 err = decoder.DecodeValue(r, vr, elem)
174 if err != nil {
175 return err
176 }
177 inlineMap.SetMapIndex(reflect.ValueOf(name), elem)
178 continue
179 }
180
181 var field reflect.Value
182 if fd.inline == nil {
183 field = val.Field(fd.idx)
184 } else {
185 field = val.FieldByIndex(fd.inline)
186 }
187
188 if !field.CanSet() { // Being settable is a super set of being addressable.
189 return fmt.Errorf("cannot decode element '%s' into field %v; it is not settable", name, field)
190 }
191 if field.Kind() == reflect.Ptr && field.IsNil() {
192 field.Set(reflect.New(field.Type().Elem()))
193 }
194 field = field.Addr()
195
196 dctx := DecodeContext{Registry: r.Registry, Truncate: fd.truncate}
197 if fd.decoder == nil {
198 return ErrNoDecoder{Type: field.Elem().Type()}
199 }
200
201 if decoder, ok := fd.decoder.(ValueDecoder); ok {
202 err = decoder.DecodeValue(dctx, vr, field.Elem())
203 if err != nil {
204 return err
205 }
206 continue
207 }
208 err = fd.decoder.DecodeValue(dctx, vr, field)
209 if err != nil {
210 return err
211 }
212 }
213
214 return nil
215}
216
217func (sc *StructCodec) isZero(i interface{}) bool {
218 v := reflect.ValueOf(i)
219
220 // check the value validity
221 if !v.IsValid() {
222 return true
223 }
224
225 if z, ok := v.Interface().(Zeroer); ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
226 return z.IsZero()
227 }
228
229 switch v.Kind() {
230 case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
231 return v.Len() == 0
232 case reflect.Bool:
233 return !v.Bool()
234 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
235 return v.Int() == 0
236 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
237 return v.Uint() == 0
238 case reflect.Float32, reflect.Float64:
239 return v.Float() == 0
240 case reflect.Interface, reflect.Ptr:
241 return v.IsNil()
242 }
243
244 return false
245}
246
// structDescription is the cached per-type result of describeStruct.
type structDescription struct {
	fm        map[string]fieldDescription // field descriptions keyed by BSON element name
	fl        []fieldDescription          // field descriptions in encoding order (declaration order, inlined struct fields expanded in place)
	inlineMap int                         // index of the inline map field within the struct, or -1 if there is none
}
252
// fieldDescription describes how one struct field is encoded and decoded.
type fieldDescription struct {
	name      string       // BSON element name, taken from the parsed struct tags
	idx       int          // index of the field within the struct that declares it
	omitEmpty bool         // skip the field when encoding if its value is zero
	minSize   bool         // passed to the field encoder as EncodeContext.MinSize
	truncate  bool         // passed to the field decoder as DecodeContext.Truncate
	inline    []int        // index path from the outer struct for fields promoted from an inlined struct; nil otherwise
	encoder   ValueEncoder // encoder found in the registry for the field's type; nil if the lookup failed
	decoder   ValueDecoder // decoder found in the registry for the field's type; nil if the lookup failed
}
263
264func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescription, error) {
265 // We need to analyze the struct, including getting the tags, collecting
266 // information about inlining, and create a map of the field name to the field.
267 sc.l.RLock()
268 ds, exists := sc.cache[t]
269 sc.l.RUnlock()
270 if exists {
271 return ds, nil
272 }
273
274 numFields := t.NumField()
275 sd := &structDescription{
276 fm: make(map[string]fieldDescription, numFields),
277 fl: make([]fieldDescription, 0, numFields),
278 inlineMap: -1,
279 }
280
281 for i := 0; i < numFields; i++ {
282 sf := t.Field(i)
283 if sf.PkgPath != "" {
284 // unexported, ignore
285 continue
286 }
287
288 encoder, err := r.LookupEncoder(sf.Type)
289 if err != nil {
290 encoder = nil
291 }
292 decoder, err := r.LookupDecoder(sf.Type)
293 if err != nil {
294 decoder = nil
295 }
296
297 description := fieldDescription{idx: i, encoder: encoder, decoder: decoder}
298
299 stags, err := sc.parser.ParseStructTags(sf)
300 if err != nil {
301 return nil, err
302 }
303 if stags.Skip {
304 continue
305 }
306 description.name = stags.Name
307 description.omitEmpty = stags.OmitEmpty
308 description.minSize = stags.MinSize
309 description.truncate = stags.Truncate
310
311 if stags.Inline {
312 switch sf.Type.Kind() {
313 case reflect.Map:
314 if sd.inlineMap >= 0 {
315 return nil, errors.New("(struct " + t.String() + ") multiple inline maps")
316 }
317 if sf.Type.Key() != tString {
318 return nil, errors.New("(struct " + t.String() + ") inline map must have a string keys")
319 }
320 sd.inlineMap = description.idx
321 case reflect.Struct:
322 inlinesf, err := sc.describeStruct(r, sf.Type)
323 if err != nil {
324 return nil, err
325 }
326 for _, fd := range inlinesf.fl {
327 if _, exists := sd.fm[fd.name]; exists {
328 return nil, fmt.Errorf("(struct %s) duplicated key %s", t.String(), fd.name)
329 }
330 if fd.inline == nil {
331 fd.inline = []int{i, fd.idx}
332 } else {
333 fd.inline = append([]int{i}, fd.inline...)
334 }
335 sd.fm[fd.name] = fd
336 sd.fl = append(sd.fl, fd)
337 }
338 default:
339 return nil, fmt.Errorf("(struct %s) inline fields must be either a struct or a map", t.String())
340 }
341 continue
342 }
343
344 if _, exists := sd.fm[description.name]; exists {
345 return nil, fmt.Errorf("struct %s) duplicated key %s", t.String(), description.name)
346 }
347
348 sd.fm[description.name] = description
349 sd.fl = append(sd.fl, description)
350 }
351
352 sc.l.Lock()
353 sc.cache[t] = sd
354 sc.l.Unlock()
355
356 return sd, nil
357}