blob: 5fbcc245c45d8fe50359560f3f14a6e36b6c57db [file] [log] [blame]
Zack Williamse940c7a2019-08-21 14:25:39 -07001package dynamic
2
3import (
4 "bytes"
5 "reflect"
6
7 "github.com/golang/protobuf/proto"
8
9 "github.com/jhump/protoreflect/desc"
10)
11
12// Equal returns true if the given two dynamic messages are equal. Two messages are equal when they
13// have the same message type and same fields set to equal values. For proto3 messages, fields set
14// to their zero value are considered unset.
15func Equal(a, b *Message) bool {
16 if a.md.GetFullyQualifiedName() != b.md.GetFullyQualifiedName() {
17 return false
18 }
19 if len(a.values) != len(b.values) {
20 return false
21 }
22 if len(a.unknownFields) != len(b.unknownFields) {
23 return false
24 }
25 for tag, aval := range a.values {
26 bval, ok := b.values[tag]
27 if !ok {
28 return false
29 }
30 if !fieldsEqual(aval, bval) {
31 return false
32 }
33 }
34 for tag, au := range a.unknownFields {
35 bu, ok := b.unknownFields[tag]
36 if !ok {
37 return false
38 }
39 if len(au) != len(bu) {
40 return false
41 }
42 for i, aval := range au {
43 bval := bu[i]
44 if aval.Encoding != bval.Encoding {
45 return false
46 }
47 if aval.Encoding == proto.WireBytes || aval.Encoding == proto.WireStartGroup {
48 if !bytes.Equal(aval.Contents, bval.Contents) {
49 return false
50 }
51 } else if aval.Value != bval.Value {
52 return false
53 }
54 }
55 }
56 // all checks pass!
57 return true
58}
59
60func fieldsEqual(aval, bval interface{}) bool {
61 arv := reflect.ValueOf(aval)
62 brv := reflect.ValueOf(bval)
63 if arv.Type() != brv.Type() {
64 // it is possible that one is a dynamic message and one is not
65 apm, ok := aval.(proto.Message)
66 if !ok {
67 return false
68 }
69 bpm, ok := bval.(proto.Message)
70 if !ok {
71 return false
72 }
73 return MessagesEqual(apm, bpm)
74
75 } else {
76 switch arv.Kind() {
77 case reflect.Ptr:
78 apm, ok := aval.(proto.Message)
79 if !ok {
80 // Don't know how to compare pointer values that aren't messages!
81 // Maybe this should panic?
82 return false
83 }
84 bpm := bval.(proto.Message) // we know it will succeed because we know a and b have same type
85 return MessagesEqual(apm, bpm)
86
87 case reflect.Map:
88 return mapsEqual(arv, brv)
89
90 case reflect.Slice:
91 if arv.Type() == typeOfBytes {
92 return bytes.Equal(aval.([]byte), bval.([]byte))
93 } else {
94 return slicesEqual(arv, brv)
95 }
96
97 default:
98 return aval == bval
99 }
100 }
101}
102
103func slicesEqual(a, b reflect.Value) bool {
104 if a.Len() != b.Len() {
105 return false
106 }
107 for i := 0; i < a.Len(); i++ {
108 ai := a.Index(i)
109 bi := b.Index(i)
110 if !fieldsEqual(ai.Interface(), bi.Interface()) {
111 return false
112 }
113 }
114 return true
115}
116
117// MessagesEqual returns true if the given two messages are equal. Use this instead of proto.Equal
118// when one or both of the messages might be a dynamic message.
119func MessagesEqual(a, b proto.Message) bool {
120 da, aok := a.(*Message)
121 db, bok := b.(*Message)
122 // Both dynamic messages
123 if aok && bok {
124 return Equal(da, db)
125 }
126 // Neither dynamic messages
127 if !aok && !bok {
128 return proto.Equal(a, b)
129 }
130 // Mixed
131 if aok {
132 md, err := desc.LoadMessageDescriptorForMessage(b)
133 if err != nil {
134 return false
135 }
136 db = NewMessageWithMessageFactory(md, da.mf)
137 if db.ConvertFrom(b) != nil {
138 return false
139 }
140 return Equal(da, db)
141 } else {
142 md, err := desc.LoadMessageDescriptorForMessage(a)
143 if err != nil {
144 return false
145 }
146 da = NewMessageWithMessageFactory(md, db.mf)
147 if da.ConvertFrom(a) != nil {
148 return false
149 }
150 return Equal(da, db)
151 }
152}