// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package reflect is a fork of the deep-equality check from Go's standard
// library reflect package, extended so that callers can register custom
// equality functions for specific types.
| 7 | package reflect |
| 8 | |
| 9 | import ( |
| 10 | "fmt" |
| 11 | "reflect" |
| 12 | "strings" |
| 13 | ) |
| 14 | |
// Equalities maps a reflect.Type to a function value that reports whether
// two values of that type should be considered equal. Entries are added
// through AddFunc or AddFuncs.
type Equalities map[reflect.Type]reflect.Value
| 18 | |
| 19 | // For convenience, panics on errrors |
| 20 | func EqualitiesOrDie(funcs ...interface{}) Equalities { |
| 21 | e := Equalities{} |
| 22 | if err := e.AddFuncs(funcs...); err != nil { |
| 23 | panic(err) |
| 24 | } |
| 25 | return e |
| 26 | } |
| 27 | |
| 28 | // AddFuncs is a shortcut for multiple calls to AddFunc. |
| 29 | func (e Equalities) AddFuncs(funcs ...interface{}) error { |
| 30 | for _, f := range funcs { |
| 31 | if err := e.AddFunc(f); err != nil { |
| 32 | return err |
| 33 | } |
| 34 | } |
| 35 | return nil |
| 36 | } |
| 37 | |
| 38 | // AddFunc uses func as an equality function: it must take |
| 39 | // two parameters of the same type, and return a boolean. |
| 40 | func (e Equalities) AddFunc(eqFunc interface{}) error { |
| 41 | fv := reflect.ValueOf(eqFunc) |
| 42 | ft := fv.Type() |
| 43 | if ft.Kind() != reflect.Func { |
| 44 | return fmt.Errorf("expected func, got: %v", ft) |
| 45 | } |
| 46 | if ft.NumIn() != 2 { |
| 47 | return fmt.Errorf("expected two 'in' params, got: %v", ft) |
| 48 | } |
| 49 | if ft.NumOut() != 1 { |
| 50 | return fmt.Errorf("expected one 'out' param, got: %v", ft) |
| 51 | } |
| 52 | if ft.In(0) != ft.In(1) { |
| 53 | return fmt.Errorf("expected arg 1 and 2 to have same type, but got %v", ft) |
| 54 | } |
| 55 | var forReturnType bool |
| 56 | boolType := reflect.TypeOf(forReturnType) |
| 57 | if ft.Out(0) != boolType { |
| 58 | return fmt.Errorf("expected bool return, got: %v", ft) |
| 59 | } |
| 60 | e[ft.In(0)] = fv |
| 61 | return nil |
| 62 | } |
| 63 | |
// Below here is forked from Go's reflect/deepequal.go.
| 65 | |
// visit identifies one comparison that the deep-comparison routines have
// already started: the addresses of the two operands, stored with the lower
// address first, together with their shared type. Each in-progress check is
// recorded in a map keyed by visit. When a check is met again, the algorithm
// assumes it holds, which lets comparisons of cyclic data terminate.
type visit struct {
	a1, a2 uintptr
	typ    reflect.Type
}
| 75 | |
// unexportedTypePanic is the panic value raised when a deep comparison
// reaches a value whose contents cannot be read through reflection
// (CanInterface reports false). While the panic unwinds, each enclosing
// level prepends its own type, so the slice lists types from outermost to
// innermost. It signals a programmer error, so it is unexported: callers
// cannot name it and therefore cannot recover from it selectively.
type unexportedTypePanic []reflect.Type

// Error implements the error interface by delegating to String.
func (u unexportedTypePanic) Error() string { return u.String() }

// String renders the nesting chain as "outer -> ... -> inner" after a fixed
// explanatory prefix.
func (u unexportedTypePanic) String() string {
	var b strings.Builder
	b.WriteString("an unexported field was encountered, nested like this: ")
	for i, typ := range u {
		if i > 0 {
			b.WriteString(" -> ")
		}
		b.WriteString(fmt.Sprint(typ))
	}
	return b.String()
}
| 89 | |
| 90 | func makeUsefulPanic(v reflect.Value) { |
| 91 | if x := recover(); x != nil { |
| 92 | if u, ok := x.(unexportedTypePanic); ok { |
| 93 | u = append(unexportedTypePanic{v.Type()}, u...) |
| 94 | x = u |
| 95 | } |
| 96 | panic(x) |
| 97 | } |
| 98 | } |
| 99 | |
| 100 | // Tests for deep equality using reflected types. The map argument tracks |
| 101 | // comparisons that have already been seen, which allows short circuiting on |
| 102 | // recursive types. |
| 103 | func (e Equalities) deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool { |
| 104 | defer makeUsefulPanic(v1) |
| 105 | |
| 106 | if !v1.IsValid() || !v2.IsValid() { |
| 107 | return v1.IsValid() == v2.IsValid() |
| 108 | } |
| 109 | if v1.Type() != v2.Type() { |
| 110 | return false |
| 111 | } |
| 112 | if fv, ok := e[v1.Type()]; ok { |
| 113 | return fv.Call([]reflect.Value{v1, v2})[0].Bool() |
| 114 | } |
| 115 | |
| 116 | hard := func(k reflect.Kind) bool { |
| 117 | switch k { |
| 118 | case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct: |
| 119 | return true |
| 120 | } |
| 121 | return false |
| 122 | } |
| 123 | |
| 124 | if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) { |
| 125 | addr1 := v1.UnsafeAddr() |
| 126 | addr2 := v2.UnsafeAddr() |
| 127 | if addr1 > addr2 { |
| 128 | // Canonicalize order to reduce number of entries in visited. |
| 129 | addr1, addr2 = addr2, addr1 |
| 130 | } |
| 131 | |
| 132 | // Short circuit if references are identical ... |
| 133 | if addr1 == addr2 { |
| 134 | return true |
| 135 | } |
| 136 | |
| 137 | // ... or already seen |
| 138 | typ := v1.Type() |
| 139 | v := visit{addr1, addr2, typ} |
| 140 | if visited[v] { |
| 141 | return true |
| 142 | } |
| 143 | |
| 144 | // Remember for later. |
| 145 | visited[v] = true |
| 146 | } |
| 147 | |
| 148 | switch v1.Kind() { |
| 149 | case reflect.Array: |
| 150 | // We don't need to check length here because length is part of |
| 151 | // an array's type, which has already been filtered for. |
| 152 | for i := 0; i < v1.Len(); i++ { |
| 153 | if !e.deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) { |
| 154 | return false |
| 155 | } |
| 156 | } |
| 157 | return true |
| 158 | case reflect.Slice: |
| 159 | if (v1.IsNil() || v1.Len() == 0) != (v2.IsNil() || v2.Len() == 0) { |
| 160 | return false |
| 161 | } |
| 162 | if v1.IsNil() || v1.Len() == 0 { |
| 163 | return true |
| 164 | } |
| 165 | if v1.Len() != v2.Len() { |
| 166 | return false |
| 167 | } |
| 168 | if v1.Pointer() == v2.Pointer() { |
| 169 | return true |
| 170 | } |
| 171 | for i := 0; i < v1.Len(); i++ { |
| 172 | if !e.deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) { |
| 173 | return false |
| 174 | } |
| 175 | } |
| 176 | return true |
| 177 | case reflect.Interface: |
| 178 | if v1.IsNil() || v2.IsNil() { |
| 179 | return v1.IsNil() == v2.IsNil() |
| 180 | } |
| 181 | return e.deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1) |
| 182 | case reflect.Ptr: |
| 183 | return e.deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1) |
| 184 | case reflect.Struct: |
| 185 | for i, n := 0, v1.NumField(); i < n; i++ { |
| 186 | if !e.deepValueEqual(v1.Field(i), v2.Field(i), visited, depth+1) { |
| 187 | return false |
| 188 | } |
| 189 | } |
| 190 | return true |
| 191 | case reflect.Map: |
| 192 | if (v1.IsNil() || v1.Len() == 0) != (v2.IsNil() || v2.Len() == 0) { |
| 193 | return false |
| 194 | } |
| 195 | if v1.IsNil() || v1.Len() == 0 { |
| 196 | return true |
| 197 | } |
| 198 | if v1.Len() != v2.Len() { |
| 199 | return false |
| 200 | } |
| 201 | if v1.Pointer() == v2.Pointer() { |
| 202 | return true |
| 203 | } |
| 204 | for _, k := range v1.MapKeys() { |
| 205 | if !e.deepValueEqual(v1.MapIndex(k), v2.MapIndex(k), visited, depth+1) { |
| 206 | return false |
| 207 | } |
| 208 | } |
| 209 | return true |
| 210 | case reflect.Func: |
| 211 | if v1.IsNil() && v2.IsNil() { |
| 212 | return true |
| 213 | } |
| 214 | // Can't do better than this: |
| 215 | return false |
| 216 | default: |
| 217 | // Normal equality suffices |
| 218 | if !v1.CanInterface() || !v2.CanInterface() { |
| 219 | panic(unexportedTypePanic{}) |
| 220 | } |
| 221 | return v1.Interface() == v2.Interface() |
| 222 | } |
| 223 | } |
| 224 | |
| 225 | // DeepEqual is like reflect.DeepEqual, but focused on semantic equality |
| 226 | // instead of memory equality. |
| 227 | // |
| 228 | // It will use e's equality functions if it finds types that match. |
| 229 | // |
| 230 | // An empty slice *is* equal to a nil slice for our purposes; same for maps. |
| 231 | // |
| 232 | // Unexported field members cannot be compared and will cause an imformative panic; you must add an Equality |
| 233 | // function for these types. |
| 234 | func (e Equalities) DeepEqual(a1, a2 interface{}) bool { |
| 235 | if a1 == nil || a2 == nil { |
| 236 | return a1 == a2 |
| 237 | } |
| 238 | v1 := reflect.ValueOf(a1) |
| 239 | v2 := reflect.ValueOf(a2) |
| 240 | if v1.Type() != v2.Type() { |
| 241 | return false |
| 242 | } |
| 243 | return e.deepValueEqual(v1, v2, make(map[visit]bool), 0) |
| 244 | } |
| 245 | |
| 246 | func (e Equalities) deepValueDerive(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool { |
| 247 | defer makeUsefulPanic(v1) |
| 248 | |
| 249 | if !v1.IsValid() || !v2.IsValid() { |
| 250 | return v1.IsValid() == v2.IsValid() |
| 251 | } |
| 252 | if v1.Type() != v2.Type() { |
| 253 | return false |
| 254 | } |
| 255 | if fv, ok := e[v1.Type()]; ok { |
| 256 | return fv.Call([]reflect.Value{v1, v2})[0].Bool() |
| 257 | } |
| 258 | |
| 259 | hard := func(k reflect.Kind) bool { |
| 260 | switch k { |
| 261 | case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct: |
| 262 | return true |
| 263 | } |
| 264 | return false |
| 265 | } |
| 266 | |
| 267 | if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) { |
| 268 | addr1 := v1.UnsafeAddr() |
| 269 | addr2 := v2.UnsafeAddr() |
| 270 | if addr1 > addr2 { |
| 271 | // Canonicalize order to reduce number of entries in visited. |
| 272 | addr1, addr2 = addr2, addr1 |
| 273 | } |
| 274 | |
| 275 | // Short circuit if references are identical ... |
| 276 | if addr1 == addr2 { |
| 277 | return true |
| 278 | } |
| 279 | |
| 280 | // ... or already seen |
| 281 | typ := v1.Type() |
| 282 | v := visit{addr1, addr2, typ} |
| 283 | if visited[v] { |
| 284 | return true |
| 285 | } |
| 286 | |
| 287 | // Remember for later. |
| 288 | visited[v] = true |
| 289 | } |
| 290 | |
| 291 | switch v1.Kind() { |
| 292 | case reflect.Array: |
| 293 | // We don't need to check length here because length is part of |
| 294 | // an array's type, which has already been filtered for. |
| 295 | for i := 0; i < v1.Len(); i++ { |
| 296 | if !e.deepValueDerive(v1.Index(i), v2.Index(i), visited, depth+1) { |
| 297 | return false |
| 298 | } |
| 299 | } |
| 300 | return true |
| 301 | case reflect.Slice: |
| 302 | if v1.IsNil() || v1.Len() == 0 { |
| 303 | return true |
| 304 | } |
| 305 | if v1.Len() > v2.Len() { |
| 306 | return false |
| 307 | } |
| 308 | if v1.Pointer() == v2.Pointer() { |
| 309 | return true |
| 310 | } |
| 311 | for i := 0; i < v1.Len(); i++ { |
| 312 | if !e.deepValueDerive(v1.Index(i), v2.Index(i), visited, depth+1) { |
| 313 | return false |
| 314 | } |
| 315 | } |
| 316 | return true |
| 317 | case reflect.String: |
| 318 | if v1.Len() == 0 { |
| 319 | return true |
| 320 | } |
| 321 | if v1.Len() > v2.Len() { |
| 322 | return false |
| 323 | } |
| 324 | return v1.String() == v2.String() |
| 325 | case reflect.Interface: |
| 326 | if v1.IsNil() { |
| 327 | return true |
| 328 | } |
| 329 | return e.deepValueDerive(v1.Elem(), v2.Elem(), visited, depth+1) |
| 330 | case reflect.Ptr: |
| 331 | if v1.IsNil() { |
| 332 | return true |
| 333 | } |
| 334 | return e.deepValueDerive(v1.Elem(), v2.Elem(), visited, depth+1) |
| 335 | case reflect.Struct: |
| 336 | for i, n := 0, v1.NumField(); i < n; i++ { |
| 337 | if !e.deepValueDerive(v1.Field(i), v2.Field(i), visited, depth+1) { |
| 338 | return false |
| 339 | } |
| 340 | } |
| 341 | return true |
| 342 | case reflect.Map: |
| 343 | if v1.IsNil() || v1.Len() == 0 { |
| 344 | return true |
| 345 | } |
| 346 | if v1.Len() > v2.Len() { |
| 347 | return false |
| 348 | } |
| 349 | if v1.Pointer() == v2.Pointer() { |
| 350 | return true |
| 351 | } |
| 352 | for _, k := range v1.MapKeys() { |
| 353 | if !e.deepValueDerive(v1.MapIndex(k), v2.MapIndex(k), visited, depth+1) { |
| 354 | return false |
| 355 | } |
| 356 | } |
| 357 | return true |
| 358 | case reflect.Func: |
| 359 | if v1.IsNil() && v2.IsNil() { |
| 360 | return true |
| 361 | } |
| 362 | // Can't do better than this: |
| 363 | return false |
| 364 | default: |
| 365 | // Normal equality suffices |
| 366 | if !v1.CanInterface() || !v2.CanInterface() { |
| 367 | panic(unexportedTypePanic{}) |
| 368 | } |
| 369 | return v1.Interface() == v2.Interface() |
| 370 | } |
| 371 | } |
| 372 | |
| 373 | // DeepDerivative is similar to DeepEqual except that unset fields in a1 are |
| 374 | // ignored (not compared). This allows us to focus on the fields that matter to |
| 375 | // the semantic comparison. |
| 376 | // |
| 377 | // The unset fields include a nil pointer and an empty string. |
| 378 | func (e Equalities) DeepDerivative(a1, a2 interface{}) bool { |
| 379 | if a1 == nil { |
| 380 | return true |
| 381 | } |
| 382 | v1 := reflect.ValueOf(a1) |
| 383 | v2 := reflect.ValueOf(a2) |
| 384 | if v1.Type() != v2.Type() { |
| 385 | return false |
| 386 | } |
| 387 | return e.deepValueDerive(v1, v2, make(map[visit]bool), 0) |
| 388 | } |