blob: e44c6c53c57a8b3477db9196e9bd66a1100a4766 [file] [log] [blame]
khenaidoof3333552021-12-15 16:52:31 -05001package dynamic
2
3import (
4 "bytes"
5 "reflect"
6
7 "github.com/golang/protobuf/proto"
8
9 "github.com/jhump/protoreflect/desc"
10)
11
12// Equal returns true if the given two dynamic messages are equal. Two messages are equal when they
13// have the same message type and same fields set to equal values. For proto3 messages, fields set
14// to their zero value are considered unset.
15func Equal(a, b *Message) bool {
16 if a == b {
17 return true
18 }
19 if (a == nil) != (b == nil) {
20 return false
21 }
22 if a.md.GetFullyQualifiedName() != b.md.GetFullyQualifiedName() {
23 return false
24 }
25 if len(a.values) != len(b.values) {
26 return false
27 }
28 if len(a.unknownFields) != len(b.unknownFields) {
29 return false
30 }
31 for tag, aval := range a.values {
32 bval, ok := b.values[tag]
33 if !ok {
34 return false
35 }
36 if !fieldsEqual(aval, bval) {
37 return false
38 }
39 }
40 for tag, au := range a.unknownFields {
41 bu, ok := b.unknownFields[tag]
42 if !ok {
43 return false
44 }
45 if len(au) != len(bu) {
46 return false
47 }
48 for i, aval := range au {
49 bval := bu[i]
50 if aval.Encoding != bval.Encoding {
51 return false
52 }
53 if aval.Encoding == proto.WireBytes || aval.Encoding == proto.WireStartGroup {
54 if !bytes.Equal(aval.Contents, bval.Contents) {
55 return false
56 }
57 } else if aval.Value != bval.Value {
58 return false
59 }
60 }
61 }
62 // all checks pass!
63 return true
64}
65
66func fieldsEqual(aval, bval interface{}) bool {
67 arv := reflect.ValueOf(aval)
68 brv := reflect.ValueOf(bval)
69 if arv.Type() != brv.Type() {
70 // it is possible that one is a dynamic message and one is not
71 apm, ok := aval.(proto.Message)
72 if !ok {
73 return false
74 }
75 bpm, ok := bval.(proto.Message)
76 if !ok {
77 return false
78 }
79 return MessagesEqual(apm, bpm)
80
81 } else {
82 switch arv.Kind() {
83 case reflect.Ptr:
84 apm, ok := aval.(proto.Message)
85 if !ok {
86 // Don't know how to compare pointer values that aren't messages!
87 // Maybe this should panic?
88 return false
89 }
90 bpm := bval.(proto.Message) // we know it will succeed because we know a and b have same type
91 return MessagesEqual(apm, bpm)
92
93 case reflect.Map:
94 return mapsEqual(arv, brv)
95
96 case reflect.Slice:
97 if arv.Type() == typeOfBytes {
98 return bytes.Equal(aval.([]byte), bval.([]byte))
99 } else {
100 return slicesEqual(arv, brv)
101 }
102
103 default:
104 return aval == bval
105 }
106 }
107}
108
109func slicesEqual(a, b reflect.Value) bool {
110 if a.Len() != b.Len() {
111 return false
112 }
113 for i := 0; i < a.Len(); i++ {
114 ai := a.Index(i)
115 bi := b.Index(i)
116 if !fieldsEqual(ai.Interface(), bi.Interface()) {
117 return false
118 }
119 }
120 return true
121}
122
123// MessagesEqual returns true if the given two messages are equal. Use this instead of proto.Equal
124// when one or both of the messages might be a dynamic message.
125func MessagesEqual(a, b proto.Message) bool {
126 da, aok := a.(*Message)
127 db, bok := b.(*Message)
128 // Both dynamic messages
129 if aok && bok {
130 return Equal(da, db)
131 }
132 // Neither dynamic messages
133 if !aok && !bok {
134 return proto.Equal(a, b)
135 }
136 // Mixed
137 if bok {
138 // we want a to be the dynamic one
139 b, da = a, db
140 }
141
142 // Instead of panic'ing below if we have a nil dynamic message, check
143 // now and return false if the input message is not also nil.
144 if da == nil {
145 return isNil(b)
146 }
147
148 md, err := desc.LoadMessageDescriptorForMessage(b)
149 if err != nil {
150 return false
151 }
152 db = NewMessageWithMessageFactory(md, da.mf)
153 if db.ConvertFrom(b) != nil {
154 return false
155 }
156 return Equal(da, db)
157}