blob: 10902bd851eb4724eec344a09f3d55aac3a10f84 [file] [log] [blame]
Matteo Scandoloa4285862020-12-01 18:10:10 -08001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8 "bytes"
9 "math"
10 "reflect"
11
12 "google.golang.org/protobuf/encoding/protowire"
13 pref "google.golang.org/protobuf/reflect/protoreflect"
14)
15
16// Equal reports whether two messages are equal.
17// If two messages marshal to the same bytes under deterministic serialization,
18// then Equal is guaranteed to report true.
19//
20// Two messages are equal if they belong to the same message descriptor,
21// have the same set of populated known and extension field values,
22// and the same set of unknown fields values. If either of the top-level
23// messages are invalid, then Equal reports true only if both are invalid.
24//
25// Scalar values are compared with the equivalent of the == operator in Go,
26// except bytes values which are compared using bytes.Equal and
27// floating point values which specially treat NaNs as equal.
28// Message values are compared by recursively calling Equal.
29// Lists are equal if each element value is also equal.
30// Maps are equal if they have the same set of keys, where the pair of values
31// for each key is also equal.
32func Equal(x, y Message) bool {
33 if x == nil || y == nil {
34 return x == nil && y == nil
35 }
36 mx := x.ProtoReflect()
37 my := y.ProtoReflect()
38 if mx.IsValid() != my.IsValid() {
39 return false
40 }
41 return equalMessage(mx, my)
42}
43
44// equalMessage compares two messages.
45func equalMessage(mx, my pref.Message) bool {
46 if mx.Descriptor() != my.Descriptor() {
47 return false
48 }
49
50 nx := 0
51 equal := true
52 mx.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool {
53 nx++
54 vy := my.Get(fd)
55 equal = my.Has(fd) && equalField(fd, vx, vy)
56 return equal
57 })
58 if !equal {
59 return false
60 }
61 ny := 0
62 my.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool {
63 ny++
64 return true
65 })
66 if nx != ny {
67 return false
68 }
69
70 return equalUnknown(mx.GetUnknown(), my.GetUnknown())
71}
72
73// equalField compares two fields.
74func equalField(fd pref.FieldDescriptor, x, y pref.Value) bool {
75 switch {
76 case fd.IsList():
77 return equalList(fd, x.List(), y.List())
78 case fd.IsMap():
79 return equalMap(fd, x.Map(), y.Map())
80 default:
81 return equalValue(fd, x, y)
82 }
83}
84
85// equalMap compares two maps.
86func equalMap(fd pref.FieldDescriptor, x, y pref.Map) bool {
87 if x.Len() != y.Len() {
88 return false
89 }
90 equal := true
91 x.Range(func(k pref.MapKey, vx pref.Value) bool {
92 vy := y.Get(k)
93 equal = y.Has(k) && equalValue(fd.MapValue(), vx, vy)
94 return equal
95 })
96 return equal
97}
98
99// equalList compares two lists.
100func equalList(fd pref.FieldDescriptor, x, y pref.List) bool {
101 if x.Len() != y.Len() {
102 return false
103 }
104 for i := x.Len() - 1; i >= 0; i-- {
105 if !equalValue(fd, x.Get(i), y.Get(i)) {
106 return false
107 }
108 }
109 return true
110}
111
112// equalValue compares two singular values.
113func equalValue(fd pref.FieldDescriptor, x, y pref.Value) bool {
114 switch {
115 case fd.Message() != nil:
116 return equalMessage(x.Message(), y.Message())
117 case fd.Kind() == pref.BytesKind:
118 return bytes.Equal(x.Bytes(), y.Bytes())
119 case fd.Kind() == pref.FloatKind, fd.Kind() == pref.DoubleKind:
120 fx := x.Float()
121 fy := y.Float()
122 if math.IsNaN(fx) || math.IsNaN(fy) {
123 return math.IsNaN(fx) && math.IsNaN(fy)
124 }
125 return fx == fy
126 default:
127 return x.Interface() == y.Interface()
128 }
129}
130
131// equalUnknown compares unknown fields by direct comparison on the raw bytes
132// of each individual field number.
133func equalUnknown(x, y pref.RawFields) bool {
134 if len(x) != len(y) {
135 return false
136 }
137 if bytes.Equal([]byte(x), []byte(y)) {
138 return true
139 }
140
141 mx := make(map[pref.FieldNumber]pref.RawFields)
142 my := make(map[pref.FieldNumber]pref.RawFields)
143 for len(x) > 0 {
144 fnum, _, n := protowire.ConsumeField(x)
145 mx[fnum] = append(mx[fnum], x[:n]...)
146 x = x[n:]
147 }
148 for len(y) > 0 {
149 fnum, _, n := protowire.ConsumeField(y)
150 my[fnum] = append(my[fnum], y[:n]...)
151 y = y[n:]
152 }
153 return reflect.DeepEqual(mx, my)
154}