blob: 4dba2b969972908bd31d2f9e9de8fe9cfab3209b [file] [log] [blame]
David K. Bainbridgee05cf0c2021-08-19 03:16:50 +00001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8 "bytes"
9 "math"
10 "reflect"
11
12 "google.golang.org/protobuf/encoding/protowire"
13 pref "google.golang.org/protobuf/reflect/protoreflect"
14)
15
16// Equal reports whether two messages are equal.
17// If two messages marshal to the same bytes under deterministic serialization,
18// then Equal is guaranteed to report true.
19//
20// Two messages are equal if they belong to the same message descriptor,
21// have the same set of populated known and extension field values,
22// and the same set of unknown fields values. If either of the top-level
23// messages are invalid, then Equal reports true only if both are invalid.
24//
25// Scalar values are compared with the equivalent of the == operator in Go,
26// except bytes values which are compared using bytes.Equal and
27// floating point values which specially treat NaNs as equal.
28// Message values are compared by recursively calling Equal.
29// Lists are equal if each element value is also equal.
30// Maps are equal if they have the same set of keys, where the pair of values
31// for each key is also equal.
32func Equal(x, y Message) bool {
33 if x == nil || y == nil {
34 return x == nil && y == nil
35 }
36 mx := x.ProtoReflect()
37 my := y.ProtoReflect()
38 if mx.IsValid() != my.IsValid() {
39 return false
40 }
41 return equalMessage(mx, my)
42}
43
44// equalMessage compares two messages.
45func equalMessage(mx, my pref.Message) bool {
46 if mx.Descriptor() != my.Descriptor() {
47 return false
48 }
49
50 nx := 0
51 equal := true
52 mx.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool {
53 nx++
54 vy := my.Get(fd)
55 equal = my.Has(fd) && equalField(fd, vx, vy)
56 return equal
57 })
58 if !equal {
59 return false
60 }
61 ny := 0
62 my.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool {
63 ny++
64 return true
65 })
66 if nx != ny {
67 return false
68 }
69
70 return equalUnknown(mx.GetUnknown(), my.GetUnknown())
71}
72
73// equalField compares two fields.
74func equalField(fd pref.FieldDescriptor, x, y pref.Value) bool {
75 switch {
76 case fd.IsList():
77 return equalList(fd, x.List(), y.List())
78 case fd.IsMap():
79 return equalMap(fd, x.Map(), y.Map())
80 default:
81 return equalValue(fd, x, y)
82 }
83}
84
85// equalMap compares two maps.
86func equalMap(fd pref.FieldDescriptor, x, y pref.Map) bool {
87 if x.Len() != y.Len() {
88 return false
89 }
90 equal := true
91 x.Range(func(k pref.MapKey, vx pref.Value) bool {
92 vy := y.Get(k)
93 equal = y.Has(k) && equalValue(fd.MapValue(), vx, vy)
94 return equal
95 })
96 return equal
97}
98
99// equalList compares two lists.
100func equalList(fd pref.FieldDescriptor, x, y pref.List) bool {
101 if x.Len() != y.Len() {
102 return false
103 }
104 for i := x.Len() - 1; i >= 0; i-- {
105 if !equalValue(fd, x.Get(i), y.Get(i)) {
106 return false
107 }
108 }
109 return true
110}
111
112// equalValue compares two singular values.
113func equalValue(fd pref.FieldDescriptor, x, y pref.Value) bool {
114 switch fd.Kind() {
115 case pref.BoolKind:
116 return x.Bool() == y.Bool()
117 case pref.EnumKind:
118 return x.Enum() == y.Enum()
119 case pref.Int32Kind, pref.Sint32Kind,
120 pref.Int64Kind, pref.Sint64Kind,
121 pref.Sfixed32Kind, pref.Sfixed64Kind:
122 return x.Int() == y.Int()
123 case pref.Uint32Kind, pref.Uint64Kind,
124 pref.Fixed32Kind, pref.Fixed64Kind:
125 return x.Uint() == y.Uint()
126 case pref.FloatKind, pref.DoubleKind:
127 fx := x.Float()
128 fy := y.Float()
129 if math.IsNaN(fx) || math.IsNaN(fy) {
130 return math.IsNaN(fx) && math.IsNaN(fy)
131 }
132 return fx == fy
133 case pref.StringKind:
134 return x.String() == y.String()
135 case pref.BytesKind:
136 return bytes.Equal(x.Bytes(), y.Bytes())
137 case pref.MessageKind, pref.GroupKind:
138 return equalMessage(x.Message(), y.Message())
139 default:
140 return x.Interface() == y.Interface()
141 }
142}
143
144// equalUnknown compares unknown fields by direct comparison on the raw bytes
145// of each individual field number.
146func equalUnknown(x, y pref.RawFields) bool {
147 if len(x) != len(y) {
148 return false
149 }
150 if bytes.Equal([]byte(x), []byte(y)) {
151 return true
152 }
153
154 mx := make(map[pref.FieldNumber]pref.RawFields)
155 my := make(map[pref.FieldNumber]pref.RawFields)
156 for len(x) > 0 {
157 fnum, _, n := protowire.ConsumeField(x)
158 mx[fnum] = append(mx[fnum], x[:n]...)
159 x = x[n:]
160 }
161 for len(y) > 0 {
162 fnum, _, n := protowire.ConsumeField(y)
163 my[fnum] = append(my[fnum], y[:n]...)
164 y = y[n:]
165 }
166 return reflect.DeepEqual(mx, my)
167}