blob: 7ed1d1cffec7d4cd723c9b910e96f7542eb73066 [file] [log] [blame]
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
// Package reflect is a fork of Go's standard library reflection package that
// supports deep equality comparisons using caller-supplied equality functions.
7package reflect
8
9import (
10 "fmt"
11 "reflect"
12 "strings"
13)
14
// Equalities maps a type to the equality function used for values of that
// type. Each entry wraps a func(a, b T) bool as registered by AddFunc, and is
// consulted by DeepEqual and DeepDerivative before their built-in rules.
type Equalities map[reflect.Type]reflect.Value
18
19// For convenience, panics on errrors
20func EqualitiesOrDie(funcs ...interface{}) Equalities {
21 e := Equalities{}
22 if err := e.AddFuncs(funcs...); err != nil {
23 panic(err)
24 }
25 return e
26}
27
28// AddFuncs is a shortcut for multiple calls to AddFunc.
29func (e Equalities) AddFuncs(funcs ...interface{}) error {
30 for _, f := range funcs {
31 if err := e.AddFunc(f); err != nil {
32 return err
33 }
34 }
35 return nil
36}
37
38// AddFunc uses func as an equality function: it must take
39// two parameters of the same type, and return a boolean.
40func (e Equalities) AddFunc(eqFunc interface{}) error {
41 fv := reflect.ValueOf(eqFunc)
42 ft := fv.Type()
43 if ft.Kind() != reflect.Func {
44 return fmt.Errorf("expected func, got: %v", ft)
45 }
46 if ft.NumIn() != 2 {
47 return fmt.Errorf("expected two 'in' params, got: %v", ft)
48 }
49 if ft.NumOut() != 1 {
50 return fmt.Errorf("expected one 'out' param, got: %v", ft)
51 }
52 if ft.In(0) != ft.In(1) {
53 return fmt.Errorf("expected arg 1 and 2 to have same type, but got %v", ft)
54 }
55 var forReturnType bool
56 boolType := reflect.TypeOf(forReturnType)
57 if ft.Out(0) != boolType {
58 return fmt.Errorf("expected bool return, got: %v", ft)
59 }
60 e[ft.In(0)] = fv
61 return nil
62}
63
// Below here is forked from Go's reflect/deepequal.go.
65
// visit identifies one comparison that is in progress: the addresses of the
// two values being compared and their common type. deepValueEqual and
// deepValueDerive record visits so that a comparison reached again through a
// cycle is assumed to hold rather than being recursed into forever.
type visit struct {
	a1, a2 uintptr
	typ    reflect.Type
}
75
// unexportedTypePanic is the panic value raised when DeepEqual or
// DeepDerivative reaches a value it cannot inspect because it was read through
// an unexported struct field. The slice lists the types on the path to that
// value, outermost first. Such a panic indicates a programming error (a
// missing equality function), and because the type is unexported, callers
// outside this package cannot match on it by type.
type unexportedTypePanic []reflect.Type

// Error implements the error interface by delegating to String.
func (u unexportedTypePanic) Error() string {
	return u.String()
}

// String describes the nesting path, joining the types with " -> ".
func (u unexportedTypePanic) String() string {
	var b strings.Builder
	b.WriteString("an unexported field was encountered, nested like this: ")
	for i, typ := range u {
		if i > 0 {
			b.WriteString(" -> ")
		}
		b.WriteString(fmt.Sprint(typ))
	}
	return b.String()
}
89
90func makeUsefulPanic(v reflect.Value) {
91 if x := recover(); x != nil {
92 if u, ok := x.(unexportedTypePanic); ok {
93 u = append(unexportedTypePanic{v.Type()}, u...)
94 x = u
95 }
96 panic(x)
97 }
98}
99
100// Tests for deep equality using reflected types. The map argument tracks
101// comparisons that have already been seen, which allows short circuiting on
102// recursive types.
103func (e Equalities) deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
104 defer makeUsefulPanic(v1)
105
106 if !v1.IsValid() || !v2.IsValid() {
107 return v1.IsValid() == v2.IsValid()
108 }
109 if v1.Type() != v2.Type() {
110 return false
111 }
112 if fv, ok := e[v1.Type()]; ok {
113 return fv.Call([]reflect.Value{v1, v2})[0].Bool()
114 }
115
116 hard := func(k reflect.Kind) bool {
117 switch k {
118 case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct:
119 return true
120 }
121 return false
122 }
123
124 if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) {
125 addr1 := v1.UnsafeAddr()
126 addr2 := v2.UnsafeAddr()
127 if addr1 > addr2 {
128 // Canonicalize order to reduce number of entries in visited.
129 addr1, addr2 = addr2, addr1
130 }
131
132 // Short circuit if references are identical ...
133 if addr1 == addr2 {
134 return true
135 }
136
137 // ... or already seen
138 typ := v1.Type()
139 v := visit{addr1, addr2, typ}
140 if visited[v] {
141 return true
142 }
143
144 // Remember for later.
145 visited[v] = true
146 }
147
148 switch v1.Kind() {
149 case reflect.Array:
150 // We don't need to check length here because length is part of
151 // an array's type, which has already been filtered for.
152 for i := 0; i < v1.Len(); i++ {
153 if !e.deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) {
154 return false
155 }
156 }
157 return true
158 case reflect.Slice:
159 if (v1.IsNil() || v1.Len() == 0) != (v2.IsNil() || v2.Len() == 0) {
160 return false
161 }
162 if v1.IsNil() || v1.Len() == 0 {
163 return true
164 }
165 if v1.Len() != v2.Len() {
166 return false
167 }
168 if v1.Pointer() == v2.Pointer() {
169 return true
170 }
171 for i := 0; i < v1.Len(); i++ {
172 if !e.deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) {
173 return false
174 }
175 }
176 return true
177 case reflect.Interface:
178 if v1.IsNil() || v2.IsNil() {
179 return v1.IsNil() == v2.IsNil()
180 }
181 return e.deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1)
182 case reflect.Ptr:
183 return e.deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1)
184 case reflect.Struct:
185 for i, n := 0, v1.NumField(); i < n; i++ {
186 if !e.deepValueEqual(v1.Field(i), v2.Field(i), visited, depth+1) {
187 return false
188 }
189 }
190 return true
191 case reflect.Map:
192 if (v1.IsNil() || v1.Len() == 0) != (v2.IsNil() || v2.Len() == 0) {
193 return false
194 }
195 if v1.IsNil() || v1.Len() == 0 {
196 return true
197 }
198 if v1.Len() != v2.Len() {
199 return false
200 }
201 if v1.Pointer() == v2.Pointer() {
202 return true
203 }
204 for _, k := range v1.MapKeys() {
205 if !e.deepValueEqual(v1.MapIndex(k), v2.MapIndex(k), visited, depth+1) {
206 return false
207 }
208 }
209 return true
210 case reflect.Func:
211 if v1.IsNil() && v2.IsNil() {
212 return true
213 }
214 // Can't do better than this:
215 return false
216 default:
217 // Normal equality suffices
218 if !v1.CanInterface() || !v2.CanInterface() {
219 panic(unexportedTypePanic{})
220 }
221 return v1.Interface() == v2.Interface()
222 }
223}
224
225// DeepEqual is like reflect.DeepEqual, but focused on semantic equality
226// instead of memory equality.
227//
228// It will use e's equality functions if it finds types that match.
229//
230// An empty slice *is* equal to a nil slice for our purposes; same for maps.
231//
232// Unexported field members cannot be compared and will cause an imformative panic; you must add an Equality
233// function for these types.
234func (e Equalities) DeepEqual(a1, a2 interface{}) bool {
235 if a1 == nil || a2 == nil {
236 return a1 == a2
237 }
238 v1 := reflect.ValueOf(a1)
239 v2 := reflect.ValueOf(a2)
240 if v1.Type() != v2.Type() {
241 return false
242 }
243 return e.deepValueEqual(v1, v2, make(map[visit]bool), 0)
244}
245
246func (e Equalities) deepValueDerive(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
247 defer makeUsefulPanic(v1)
248
249 if !v1.IsValid() || !v2.IsValid() {
250 return v1.IsValid() == v2.IsValid()
251 }
252 if v1.Type() != v2.Type() {
253 return false
254 }
255 if fv, ok := e[v1.Type()]; ok {
256 return fv.Call([]reflect.Value{v1, v2})[0].Bool()
257 }
258
259 hard := func(k reflect.Kind) bool {
260 switch k {
261 case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct:
262 return true
263 }
264 return false
265 }
266
267 if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) {
268 addr1 := v1.UnsafeAddr()
269 addr2 := v2.UnsafeAddr()
270 if addr1 > addr2 {
271 // Canonicalize order to reduce number of entries in visited.
272 addr1, addr2 = addr2, addr1
273 }
274
275 // Short circuit if references are identical ...
276 if addr1 == addr2 {
277 return true
278 }
279
280 // ... or already seen
281 typ := v1.Type()
282 v := visit{addr1, addr2, typ}
283 if visited[v] {
284 return true
285 }
286
287 // Remember for later.
288 visited[v] = true
289 }
290
291 switch v1.Kind() {
292 case reflect.Array:
293 // We don't need to check length here because length is part of
294 // an array's type, which has already been filtered for.
295 for i := 0; i < v1.Len(); i++ {
296 if !e.deepValueDerive(v1.Index(i), v2.Index(i), visited, depth+1) {
297 return false
298 }
299 }
300 return true
301 case reflect.Slice:
302 if v1.IsNil() || v1.Len() == 0 {
303 return true
304 }
305 if v1.Len() > v2.Len() {
306 return false
307 }
308 if v1.Pointer() == v2.Pointer() {
309 return true
310 }
311 for i := 0; i < v1.Len(); i++ {
312 if !e.deepValueDerive(v1.Index(i), v2.Index(i), visited, depth+1) {
313 return false
314 }
315 }
316 return true
317 case reflect.String:
318 if v1.Len() == 0 {
319 return true
320 }
321 if v1.Len() > v2.Len() {
322 return false
323 }
324 return v1.String() == v2.String()
325 case reflect.Interface:
326 if v1.IsNil() {
327 return true
328 }
329 return e.deepValueDerive(v1.Elem(), v2.Elem(), visited, depth+1)
330 case reflect.Ptr:
331 if v1.IsNil() {
332 return true
333 }
334 return e.deepValueDerive(v1.Elem(), v2.Elem(), visited, depth+1)
335 case reflect.Struct:
336 for i, n := 0, v1.NumField(); i < n; i++ {
337 if !e.deepValueDerive(v1.Field(i), v2.Field(i), visited, depth+1) {
338 return false
339 }
340 }
341 return true
342 case reflect.Map:
343 if v1.IsNil() || v1.Len() == 0 {
344 return true
345 }
346 if v1.Len() > v2.Len() {
347 return false
348 }
349 if v1.Pointer() == v2.Pointer() {
350 return true
351 }
352 for _, k := range v1.MapKeys() {
353 if !e.deepValueDerive(v1.MapIndex(k), v2.MapIndex(k), visited, depth+1) {
354 return false
355 }
356 }
357 return true
358 case reflect.Func:
359 if v1.IsNil() && v2.IsNil() {
360 return true
361 }
362 // Can't do better than this:
363 return false
364 default:
365 // Normal equality suffices
366 if !v1.CanInterface() || !v2.CanInterface() {
367 panic(unexportedTypePanic{})
368 }
369 return v1.Interface() == v2.Interface()
370 }
371}
372
373// DeepDerivative is similar to DeepEqual except that unset fields in a1 are
374// ignored (not compared). This allows us to focus on the fields that matter to
375// the semantic comparison.
376//
377// The unset fields include a nil pointer and an empty string.
378func (e Equalities) DeepDerivative(a1, a2 interface{}) bool {
379 if a1 == nil {
380 return true
381 }
382 v1 := reflect.ValueOf(a1)
383 v2 := reflect.ValueOf(a2)
384 if v1.Type() != v2.Type() {
385 return false
386 }
387 return e.deepValueDerive(v1, v2, make(map[visit]bool), 0)
388}