blob: c65bbc0446ea8309cc5b2e6186138e172a5c6d34 [file] [log] [blame]
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "reflect"
10
11 "google.golang.org/protobuf/proto"
12 pref "google.golang.org/protobuf/reflect/protoreflect"
13 piface "google.golang.org/protobuf/runtime/protoiface"
14)
15
16type mergeOptions struct{}
17
18func (o mergeOptions) Merge(dst, src proto.Message) {
19 proto.Merge(dst, src)
20}
21
22// merge is protoreflect.Methods.Merge.
23func (mi *MessageInfo) merge(in piface.MergeInput) piface.MergeOutput {
24 dp, ok := mi.getPointer(in.Destination)
25 if !ok {
26 return piface.MergeOutput{}
27 }
28 sp, ok := mi.getPointer(in.Source)
29 if !ok {
30 return piface.MergeOutput{}
31 }
32 mi.mergePointer(dp, sp, mergeOptions{})
33 return piface.MergeOutput{Flags: piface.MergeComplete}
34}
35
36func (mi *MessageInfo) mergePointer(dst, src pointer, opts mergeOptions) {
37 mi.init()
38 if dst.IsNil() {
39 panic(fmt.Sprintf("invalid value: merging into nil message"))
40 }
41 if src.IsNil() {
42 return
43 }
44 for _, f := range mi.orderedCoderFields {
45 if f.funcs.merge == nil {
46 continue
47 }
48 sfptr := src.Apply(f.offset)
49 if f.isPointer && sfptr.Elem().IsNil() {
50 continue
51 }
52 f.funcs.merge(dst.Apply(f.offset), sfptr, f, opts)
53 }
54 if mi.extensionOffset.IsValid() {
55 sext := src.Apply(mi.extensionOffset).Extensions()
56 dext := dst.Apply(mi.extensionOffset).Extensions()
57 if *dext == nil {
58 *dext = make(map[int32]ExtensionField)
59 }
60 for num, sx := range *sext {
61 xt := sx.Type()
62 xi := getExtensionFieldInfo(xt)
63 if xi.funcs.merge == nil {
64 continue
65 }
66 dx := (*dext)[num]
67 var dv pref.Value
68 if dx.Type() == sx.Type() {
69 dv = dx.Value()
70 }
71 if !dv.IsValid() && xi.unmarshalNeedsValue {
72 dv = xt.New()
73 }
74 dv = xi.funcs.merge(dv, sx.Value(), opts)
75 dx.Set(sx.Type(), dv)
76 (*dext)[num] = dx
77 }
78 }
79 if mi.unknownOffset.IsValid() {
80 su := mi.getUnknownBytes(src)
81 if su != nil && len(*su) > 0 {
82 du := mi.mutableUnknownBytes(dst)
83 *du = append(*du, *su...)
84 }
85 }
86}
87
88func mergeScalarValue(dst, src pref.Value, opts mergeOptions) pref.Value {
89 return src
90}
91
92func mergeBytesValue(dst, src pref.Value, opts mergeOptions) pref.Value {
93 return pref.ValueOfBytes(append(emptyBuf[:], src.Bytes()...))
94}
95
96func mergeListValue(dst, src pref.Value, opts mergeOptions) pref.Value {
97 dstl := dst.List()
98 srcl := src.List()
99 for i, llen := 0, srcl.Len(); i < llen; i++ {
100 dstl.Append(srcl.Get(i))
101 }
102 return dst
103}
104
105func mergeBytesListValue(dst, src pref.Value, opts mergeOptions) pref.Value {
106 dstl := dst.List()
107 srcl := src.List()
108 for i, llen := 0, srcl.Len(); i < llen; i++ {
109 sb := srcl.Get(i).Bytes()
110 db := append(emptyBuf[:], sb...)
111 dstl.Append(pref.ValueOfBytes(db))
112 }
113 return dst
114}
115
116func mergeMessageListValue(dst, src pref.Value, opts mergeOptions) pref.Value {
117 dstl := dst.List()
118 srcl := src.List()
119 for i, llen := 0, srcl.Len(); i < llen; i++ {
120 sm := srcl.Get(i).Message()
121 dm := proto.Clone(sm.Interface()).ProtoReflect()
122 dstl.Append(pref.ValueOfMessage(dm))
123 }
124 return dst
125}
126
127func mergeMessageValue(dst, src pref.Value, opts mergeOptions) pref.Value {
128 opts.Merge(dst.Message().Interface(), src.Message().Interface())
129 return dst
130}
131
132func mergeMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
133 if f.mi != nil {
134 if dst.Elem().IsNil() {
135 dst.SetPointer(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())))
136 }
137 f.mi.mergePointer(dst.Elem(), src.Elem(), opts)
138 } else {
139 dm := dst.AsValueOf(f.ft).Elem()
140 sm := src.AsValueOf(f.ft).Elem()
141 if dm.IsNil() {
142 dm.Set(reflect.New(f.ft.Elem()))
143 }
144 opts.Merge(asMessage(dm), asMessage(sm))
145 }
146}
147
148func mergeMessageSlice(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
149 for _, sp := range src.PointerSlice() {
150 dm := reflect.New(f.ft.Elem().Elem())
151 if f.mi != nil {
152 f.mi.mergePointer(pointerOfValue(dm), sp, opts)
153 } else {
154 opts.Merge(asMessage(dm), asMessage(sp.AsValueOf(f.ft.Elem().Elem())))
155 }
156 dst.AppendPointerSlice(pointerOfValue(dm))
157 }
158}
159
160func mergeBytes(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
161 *dst.Bytes() = append(emptyBuf[:], *src.Bytes()...)
162}
163
164func mergeBytesNoZero(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
165 v := *src.Bytes()
166 if len(v) > 0 {
167 *dst.Bytes() = append(emptyBuf[:], v...)
168 }
169}
170
171func mergeBytesSlice(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
172 ds := dst.BytesSlice()
173 for _, v := range *src.BytesSlice() {
174 *ds = append(*ds, append(emptyBuf[:], v...))
175 }
176}