blob: 60dcf70d1e6c6dcd9a34cd6c9120691de0b37294 [file] [log] [blame]
Takahiro Suzukid7bf8202020-12-17 20:21:59 +09001// Go support for Protocol Buffers - Google's data interchange format
2//
3// Copyright 2016 The Go Authors. All rights reserved.
4// https://github.com/golang/protobuf
5//
6// Redistribution and use in source and binary forms, with or without
7// modification, are permitted provided that the following conditions are
8// met:
9//
10// * Redistributions of source code must retain the above copyright
11// notice, this list of conditions and the following disclaimer.
12// * Redistributions in binary form must reproduce the above
13// copyright notice, this list of conditions and the following disclaimer
14// in the documentation and/or other materials provided with the
15// distribution.
16// * Neither the name of Google Inc. nor the names of its
17// contributors may be used to endorse or promote products derived from
18// this software without specific prior written permission.
19//
20// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
32package proto
33
34import (
35 "fmt"
36 "reflect"
37 "strings"
38 "sync"
39 "sync/atomic"
40)
41
42// Merge merges the src message into dst.
43// This assumes that dst and src of the same type and are non-nil.
44func (a *InternalMessageInfo) Merge(dst, src Message) {
45 mi := atomicLoadMergeInfo(&a.merge)
46 if mi == nil {
47 mi = getMergeInfo(reflect.TypeOf(dst).Elem())
48 atomicStoreMergeInfo(&a.merge, mi)
49 }
50 mi.merge(toPointer(&dst), toPointer(&src))
51}
52
// mergeInfo holds the per-message-type state used to merge values of typ.
// A single instance per type is shared through mergeInfoMap (see
// getMergeInfo); everything except typ is filled in lazily by
// computeMergeInfo.
type mergeInfo struct {
	typ reflect.Type // Struct type being merged (the pointed-to message type).

	// initialized is read with atomic.LoadInt32 in merge and set to 1 with
	// atomic.StoreInt32 at the end of computeMergeInfo.
	initialized int32 // 0: only typ is valid, 1: everything is valid
	lock        sync.Mutex // Held for the duration of computeMergeInfo.

	fields       []mergeFieldInfo // One entry per non-XXX_ struct field, in declaration order.
	unrecognized field            // Offset of XXX_unrecognized, or invalidField if the type has none.
}
62
// mergeFieldInfo describes how to merge one struct field of a message.
// isPointer and basicWidth are zero-value hints that merge uses (only when
// unsafeAllowed is set) to skip fields whose source value cannot change dst.
type mergeFieldInfo struct {
	field field // Offset of field, guaranteed to be valid

	// isPointer reports whether the value in the field is a pointer.
	// This is true for the following situations:
	// * Pointer to struct
	// * Pointer to basic type (proto2 only)
	// * Slice (first value in slice header is a pointer)
	// * String (first value in string header is a pointer)
	isPointer bool

	// basicWidth reports the width of the field assuming that it is directly
	// embedded in the struct (as is the case for basic types in proto3).
	// The possible values are:
	// 0: invalid
	// 1: bool
	// 4: int32, uint32, float32
	// 8: int64, uint64, float64
	basicWidth int

	// merge merges the field value at src into the field value at dst.
	// dst and src point directly at the field (the struct pointer already
	// offset by field), not at the enclosing message.
	merge func(dst, src pointer)
}
86
// mergeInfoMap caches one *mergeInfo per message struct type so that every
// caller merging a given type shares the same lazily-built merge table.
// All access to the map goes through getMergeInfo while holding mergeInfoLock.
var (
	mergeInfoMap  = map[reflect.Type]*mergeInfo{}
	mergeInfoLock sync.Mutex
)
91
92func getMergeInfo(t reflect.Type) *mergeInfo {
93 mergeInfoLock.Lock()
94 defer mergeInfoLock.Unlock()
95 mi := mergeInfoMap[t]
96 if mi == nil {
97 mi = &mergeInfo{typ: t}
98 mergeInfoMap[t] = mi
99 }
100 return mi
101}
102
103// merge merges src into dst assuming they are both of type *mi.typ.
104func (mi *mergeInfo) merge(dst, src pointer) {
105 if dst.isNil() {
106 panic("proto: nil destination")
107 }
108 if src.isNil() {
109 return // Nothing to do.
110 }
111
112 if atomic.LoadInt32(&mi.initialized) == 0 {
113 mi.computeMergeInfo()
114 }
115
116 for _, fi := range mi.fields {
117 sfp := src.offset(fi.field)
118
119 // As an optimization, we can avoid the merge function call cost
120 // if we know for sure that the source will have no effect
121 // by checking if it is the zero value.
122 if unsafeAllowed {
123 if fi.isPointer && sfp.getPointer().isNil() { // Could be slice or string
124 continue
125 }
126 if fi.basicWidth > 0 {
127 switch {
128 case fi.basicWidth == 1 && !*sfp.toBool():
129 continue
130 case fi.basicWidth == 4 && *sfp.toUint32() == 0:
131 continue
132 case fi.basicWidth == 8 && *sfp.toUint64() == 0:
133 continue
134 }
135 }
136 }
137
138 dfp := dst.offset(fi.field)
139 fi.merge(dfp, sfp)
140 }
141
142 // TODO: Make this faster?
143 out := dst.asPointerTo(mi.typ).Elem()
144 in := src.asPointerTo(mi.typ).Elem()
145 if emIn, err := extendable(in.Addr().Interface()); err == nil {
146 emOut, _ := extendable(out.Addr().Interface())
147 mIn, muIn := emIn.extensionsRead()
148 if mIn != nil {
149 mOut := emOut.extensionsWrite()
150 muIn.Lock()
151 mergeExtension(mOut, mIn)
152 muIn.Unlock()
153 }
154 }
155
156 if mi.unrecognized.IsValid() {
157 if b := *src.offset(mi.unrecognized).toBytes(); len(b) > 0 {
158 *dst.offset(mi.unrecognized).toBytes() = append([]byte(nil), b...)
159 }
160 }
161}
162
163func (mi *mergeInfo) computeMergeInfo() {
164 mi.lock.Lock()
165 defer mi.lock.Unlock()
166 if mi.initialized != 0 {
167 return
168 }
169 t := mi.typ
170 n := t.NumField()
171
172 props := GetProperties(t)
173 for i := 0; i < n; i++ {
174 f := t.Field(i)
175 if strings.HasPrefix(f.Name, "XXX_") {
176 continue
177 }
178
179 mfi := mergeFieldInfo{field: toField(&f)}
180 tf := f.Type
181
182 // As an optimization, we can avoid the merge function call cost
183 // if we know for sure that the source will have no effect
184 // by checking if it is the zero value.
185 if unsafeAllowed {
186 switch tf.Kind() {
187 case reflect.Ptr, reflect.Slice, reflect.String:
188 // As a special case, we assume slices and strings are pointers
189 // since we know that the first field in the SliceSlice or
190 // StringHeader is a data pointer.
191 mfi.isPointer = true
192 case reflect.Bool:
193 mfi.basicWidth = 1
194 case reflect.Int32, reflect.Uint32, reflect.Float32:
195 mfi.basicWidth = 4
196 case reflect.Int64, reflect.Uint64, reflect.Float64:
197 mfi.basicWidth = 8
198 }
199 }
200
201 // Unwrap tf to get at its most basic type.
202 var isPointer, isSlice bool
203 if tf.Kind() == reflect.Slice && tf.Elem().Kind() != reflect.Uint8 {
204 isSlice = true
205 tf = tf.Elem()
206 }
207 if tf.Kind() == reflect.Ptr {
208 isPointer = true
209 tf = tf.Elem()
210 }
211 if isPointer && isSlice && tf.Kind() != reflect.Struct {
212 panic("both pointer and slice for basic type in " + tf.Name())
213 }
214
215 switch tf.Kind() {
216 case reflect.Int32:
217 switch {
218 case isSlice: // E.g., []int32
219 mfi.merge = func(dst, src pointer) {
220 // NOTE: toInt32Slice is not defined (see pointer_reflect.go).
221 /*
222 sfsp := src.toInt32Slice()
223 if *sfsp != nil {
224 dfsp := dst.toInt32Slice()
225 *dfsp = append(*dfsp, *sfsp...)
226 if *dfsp == nil {
227 *dfsp = []int64{}
228 }
229 }
230 */
231 sfs := src.getInt32Slice()
232 if sfs != nil {
233 dfs := dst.getInt32Slice()
234 dfs = append(dfs, sfs...)
235 if dfs == nil {
236 dfs = []int32{}
237 }
238 dst.setInt32Slice(dfs)
239 }
240 }
241 case isPointer: // E.g., *int32
242 mfi.merge = func(dst, src pointer) {
243 // NOTE: toInt32Ptr is not defined (see pointer_reflect.go).
244 /*
245 sfpp := src.toInt32Ptr()
246 if *sfpp != nil {
247 dfpp := dst.toInt32Ptr()
248 if *dfpp == nil {
249 *dfpp = Int32(**sfpp)
250 } else {
251 **dfpp = **sfpp
252 }
253 }
254 */
255 sfp := src.getInt32Ptr()
256 if sfp != nil {
257 dfp := dst.getInt32Ptr()
258 if dfp == nil {
259 dst.setInt32Ptr(*sfp)
260 } else {
261 *dfp = *sfp
262 }
263 }
264 }
265 default: // E.g., int32
266 mfi.merge = func(dst, src pointer) {
267 if v := *src.toInt32(); v != 0 {
268 *dst.toInt32() = v
269 }
270 }
271 }
272 case reflect.Int64:
273 switch {
274 case isSlice: // E.g., []int64
275 mfi.merge = func(dst, src pointer) {
276 sfsp := src.toInt64Slice()
277 if *sfsp != nil {
278 dfsp := dst.toInt64Slice()
279 *dfsp = append(*dfsp, *sfsp...)
280 if *dfsp == nil {
281 *dfsp = []int64{}
282 }
283 }
284 }
285 case isPointer: // E.g., *int64
286 mfi.merge = func(dst, src pointer) {
287 sfpp := src.toInt64Ptr()
288 if *sfpp != nil {
289 dfpp := dst.toInt64Ptr()
290 if *dfpp == nil {
291 *dfpp = Int64(**sfpp)
292 } else {
293 **dfpp = **sfpp
294 }
295 }
296 }
297 default: // E.g., int64
298 mfi.merge = func(dst, src pointer) {
299 if v := *src.toInt64(); v != 0 {
300 *dst.toInt64() = v
301 }
302 }
303 }
304 case reflect.Uint32:
305 switch {
306 case isSlice: // E.g., []uint32
307 mfi.merge = func(dst, src pointer) {
308 sfsp := src.toUint32Slice()
309 if *sfsp != nil {
310 dfsp := dst.toUint32Slice()
311 *dfsp = append(*dfsp, *sfsp...)
312 if *dfsp == nil {
313 *dfsp = []uint32{}
314 }
315 }
316 }
317 case isPointer: // E.g., *uint32
318 mfi.merge = func(dst, src pointer) {
319 sfpp := src.toUint32Ptr()
320 if *sfpp != nil {
321 dfpp := dst.toUint32Ptr()
322 if *dfpp == nil {
323 *dfpp = Uint32(**sfpp)
324 } else {
325 **dfpp = **sfpp
326 }
327 }
328 }
329 default: // E.g., uint32
330 mfi.merge = func(dst, src pointer) {
331 if v := *src.toUint32(); v != 0 {
332 *dst.toUint32() = v
333 }
334 }
335 }
336 case reflect.Uint64:
337 switch {
338 case isSlice: // E.g., []uint64
339 mfi.merge = func(dst, src pointer) {
340 sfsp := src.toUint64Slice()
341 if *sfsp != nil {
342 dfsp := dst.toUint64Slice()
343 *dfsp = append(*dfsp, *sfsp...)
344 if *dfsp == nil {
345 *dfsp = []uint64{}
346 }
347 }
348 }
349 case isPointer: // E.g., *uint64
350 mfi.merge = func(dst, src pointer) {
351 sfpp := src.toUint64Ptr()
352 if *sfpp != nil {
353 dfpp := dst.toUint64Ptr()
354 if *dfpp == nil {
355 *dfpp = Uint64(**sfpp)
356 } else {
357 **dfpp = **sfpp
358 }
359 }
360 }
361 default: // E.g., uint64
362 mfi.merge = func(dst, src pointer) {
363 if v := *src.toUint64(); v != 0 {
364 *dst.toUint64() = v
365 }
366 }
367 }
368 case reflect.Float32:
369 switch {
370 case isSlice: // E.g., []float32
371 mfi.merge = func(dst, src pointer) {
372 sfsp := src.toFloat32Slice()
373 if *sfsp != nil {
374 dfsp := dst.toFloat32Slice()
375 *dfsp = append(*dfsp, *sfsp...)
376 if *dfsp == nil {
377 *dfsp = []float32{}
378 }
379 }
380 }
381 case isPointer: // E.g., *float32
382 mfi.merge = func(dst, src pointer) {
383 sfpp := src.toFloat32Ptr()
384 if *sfpp != nil {
385 dfpp := dst.toFloat32Ptr()
386 if *dfpp == nil {
387 *dfpp = Float32(**sfpp)
388 } else {
389 **dfpp = **sfpp
390 }
391 }
392 }
393 default: // E.g., float32
394 mfi.merge = func(dst, src pointer) {
395 if v := *src.toFloat32(); v != 0 {
396 *dst.toFloat32() = v
397 }
398 }
399 }
400 case reflect.Float64:
401 switch {
402 case isSlice: // E.g., []float64
403 mfi.merge = func(dst, src pointer) {
404 sfsp := src.toFloat64Slice()
405 if *sfsp != nil {
406 dfsp := dst.toFloat64Slice()
407 *dfsp = append(*dfsp, *sfsp...)
408 if *dfsp == nil {
409 *dfsp = []float64{}
410 }
411 }
412 }
413 case isPointer: // E.g., *float64
414 mfi.merge = func(dst, src pointer) {
415 sfpp := src.toFloat64Ptr()
416 if *sfpp != nil {
417 dfpp := dst.toFloat64Ptr()
418 if *dfpp == nil {
419 *dfpp = Float64(**sfpp)
420 } else {
421 **dfpp = **sfpp
422 }
423 }
424 }
425 default: // E.g., float64
426 mfi.merge = func(dst, src pointer) {
427 if v := *src.toFloat64(); v != 0 {
428 *dst.toFloat64() = v
429 }
430 }
431 }
432 case reflect.Bool:
433 switch {
434 case isSlice: // E.g., []bool
435 mfi.merge = func(dst, src pointer) {
436 sfsp := src.toBoolSlice()
437 if *sfsp != nil {
438 dfsp := dst.toBoolSlice()
439 *dfsp = append(*dfsp, *sfsp...)
440 if *dfsp == nil {
441 *dfsp = []bool{}
442 }
443 }
444 }
445 case isPointer: // E.g., *bool
446 mfi.merge = func(dst, src pointer) {
447 sfpp := src.toBoolPtr()
448 if *sfpp != nil {
449 dfpp := dst.toBoolPtr()
450 if *dfpp == nil {
451 *dfpp = Bool(**sfpp)
452 } else {
453 **dfpp = **sfpp
454 }
455 }
456 }
457 default: // E.g., bool
458 mfi.merge = func(dst, src pointer) {
459 if v := *src.toBool(); v {
460 *dst.toBool() = v
461 }
462 }
463 }
464 case reflect.String:
465 switch {
466 case isSlice: // E.g., []string
467 mfi.merge = func(dst, src pointer) {
468 sfsp := src.toStringSlice()
469 if *sfsp != nil {
470 dfsp := dst.toStringSlice()
471 *dfsp = append(*dfsp, *sfsp...)
472 if *dfsp == nil {
473 *dfsp = []string{}
474 }
475 }
476 }
477 case isPointer: // E.g., *string
478 mfi.merge = func(dst, src pointer) {
479 sfpp := src.toStringPtr()
480 if *sfpp != nil {
481 dfpp := dst.toStringPtr()
482 if *dfpp == nil {
483 *dfpp = String(**sfpp)
484 } else {
485 **dfpp = **sfpp
486 }
487 }
488 }
489 default: // E.g., string
490 mfi.merge = func(dst, src pointer) {
491 if v := *src.toString(); v != "" {
492 *dst.toString() = v
493 }
494 }
495 }
496 case reflect.Slice:
497 isProto3 := props.Prop[i].proto3
498 switch {
499 case isPointer:
500 panic("bad pointer in byte slice case in " + tf.Name())
501 case tf.Elem().Kind() != reflect.Uint8:
502 panic("bad element kind in byte slice case in " + tf.Name())
503 case isSlice: // E.g., [][]byte
504 mfi.merge = func(dst, src pointer) {
505 sbsp := src.toBytesSlice()
506 if *sbsp != nil {
507 dbsp := dst.toBytesSlice()
508 for _, sb := range *sbsp {
509 if sb == nil {
510 *dbsp = append(*dbsp, nil)
511 } else {
512 *dbsp = append(*dbsp, append([]byte{}, sb...))
513 }
514 }
515 if *dbsp == nil {
516 *dbsp = [][]byte{}
517 }
518 }
519 }
520 default: // E.g., []byte
521 mfi.merge = func(dst, src pointer) {
522 sbp := src.toBytes()
523 if *sbp != nil {
524 dbp := dst.toBytes()
525 if !isProto3 || len(*sbp) > 0 {
526 *dbp = append([]byte{}, *sbp...)
527 }
528 }
529 }
530 }
531 case reflect.Struct:
532 switch {
533 case isSlice && !isPointer: // E.g. []pb.T
534 mergeInfo := getMergeInfo(tf)
535 zero := reflect.Zero(tf)
536 mfi.merge = func(dst, src pointer) {
537 // TODO: Make this faster?
538 dstsp := dst.asPointerTo(f.Type)
539 dsts := dstsp.Elem()
540 srcs := src.asPointerTo(f.Type).Elem()
541 for i := 0; i < srcs.Len(); i++ {
542 dsts = reflect.Append(dsts, zero)
543 srcElement := srcs.Index(i).Addr()
544 dstElement := dsts.Index(dsts.Len() - 1).Addr()
545 mergeInfo.merge(valToPointer(dstElement), valToPointer(srcElement))
546 }
547 if dsts.IsNil() {
548 dsts = reflect.MakeSlice(f.Type, 0, 0)
549 }
550 dstsp.Elem().Set(dsts)
551 }
552 case !isPointer:
553 mergeInfo := getMergeInfo(tf)
554 mfi.merge = func(dst, src pointer) {
555 mergeInfo.merge(dst, src)
556 }
557 case isSlice: // E.g., []*pb.T
558 mergeInfo := getMergeInfo(tf)
559 mfi.merge = func(dst, src pointer) {
560 sps := src.getPointerSlice()
561 if sps != nil {
562 dps := dst.getPointerSlice()
563 for _, sp := range sps {
564 var dp pointer
565 if !sp.isNil() {
566 dp = valToPointer(reflect.New(tf))
567 mergeInfo.merge(dp, sp)
568 }
569 dps = append(dps, dp)
570 }
571 if dps == nil {
572 dps = []pointer{}
573 }
574 dst.setPointerSlice(dps)
575 }
576 }
577 default: // E.g., *pb.T
578 mergeInfo := getMergeInfo(tf)
579 mfi.merge = func(dst, src pointer) {
580 sp := src.getPointer()
581 if !sp.isNil() {
582 dp := dst.getPointer()
583 if dp.isNil() {
584 dp = valToPointer(reflect.New(tf))
585 dst.setPointer(dp)
586 }
587 mergeInfo.merge(dp, sp)
588 }
589 }
590 }
591 case reflect.Map:
592 switch {
593 case isPointer || isSlice:
594 panic("bad pointer or slice in map case in " + tf.Name())
595 default: // E.g., map[K]V
596 mfi.merge = func(dst, src pointer) {
597 sm := src.asPointerTo(tf).Elem()
598 if sm.Len() == 0 {
599 return
600 }
601 dm := dst.asPointerTo(tf).Elem()
602 if dm.IsNil() {
603 dm.Set(reflect.MakeMap(tf))
604 }
605
606 switch tf.Elem().Kind() {
607 case reflect.Ptr: // Proto struct (e.g., *T)
608 for _, key := range sm.MapKeys() {
609 val := sm.MapIndex(key)
610 val = reflect.ValueOf(Clone(val.Interface().(Message)))
611 dm.SetMapIndex(key, val)
612 }
613 case reflect.Slice: // E.g. Bytes type (e.g., []byte)
614 for _, key := range sm.MapKeys() {
615 val := sm.MapIndex(key)
616 val = reflect.ValueOf(append([]byte{}, val.Bytes()...))
617 dm.SetMapIndex(key, val)
618 }
619 default: // Basic type (e.g., string)
620 for _, key := range sm.MapKeys() {
621 val := sm.MapIndex(key)
622 dm.SetMapIndex(key, val)
623 }
624 }
625 }
626 }
627 case reflect.Interface:
628 // Must be oneof field.
629 switch {
630 case isPointer || isSlice:
631 panic("bad pointer or slice in interface case in " + tf.Name())
632 default: // E.g., interface{}
633 // TODO: Make this faster?
634 mfi.merge = func(dst, src pointer) {
635 su := src.asPointerTo(tf).Elem()
636 if !su.IsNil() {
637 du := dst.asPointerTo(tf).Elem()
638 typ := su.Elem().Type()
639 if du.IsNil() || du.Elem().Type() != typ {
640 du.Set(reflect.New(typ.Elem())) // Initialize interface if empty
641 }
642 sv := su.Elem().Elem().Field(0)
643 if sv.Kind() == reflect.Ptr && sv.IsNil() {
644 return
645 }
646 dv := du.Elem().Elem().Field(0)
647 if dv.Kind() == reflect.Ptr && dv.IsNil() {
648 dv.Set(reflect.New(sv.Type().Elem())) // Initialize proto message if empty
649 }
650 switch sv.Type().Kind() {
651 case reflect.Ptr: // Proto struct (e.g., *T)
652 Merge(dv.Interface().(Message), sv.Interface().(Message))
653 case reflect.Slice: // E.g. Bytes type (e.g., []byte)
654 dv.Set(reflect.ValueOf(append([]byte{}, sv.Bytes()...)))
655 default: // Basic type (e.g., string)
656 dv.Set(sv)
657 }
658 }
659 }
660 }
661 default:
662 panic(fmt.Sprintf("merger not found for type:%s", tf))
663 }
664 mi.fields = append(mi.fields, mfi)
665 }
666
667 mi.unrecognized = invalidField
668 if f, ok := t.FieldByName("XXX_unrecognized"); ok {
669 if f.Type != reflect.TypeOf([]byte{}) {
670 panic("expected XXX_unrecognized to be of type []byte")
671 }
672 mi.unrecognized = toField(&f)
673 }
674
675 atomic.StoreInt32(&mi.initialized, 1)
676}