blob: cf48d887acea88257a6a6fe2f8e04cb4dfc2fe26 [file] [log] [blame]
Matteo Scandoloa6a3aee2019-11-26 13:30:14 -07001/*Package cmp provides Comparisons for Assert and Check*/
2package cmp // import "gotest.tools/assert/cmp"
3
4import (
5 "fmt"
6 "reflect"
7 "regexp"
8 "strings"
9
10 "github.com/google/go-cmp/cmp"
11 "gotest.tools/internal/format"
12)
13
14// Comparison is a function which compares values and returns ResultSuccess if
15// the actual value matches the expected value. If the values do not match the
16// Result will contain a message about why it failed.
17type Comparison func() Result
18
19// DeepEqual compares two values using google/go-cmp (http://bit.do/go-cmp)
20// and succeeds if the values are equal.
21//
22// The comparison can be customized using comparison Options.
23// Package https://godoc.org/gotest.tools/assert/opt provides some additional
24// commonly used Options.
25func DeepEqual(x, y interface{}, opts ...cmp.Option) Comparison {
26 return func() (result Result) {
27 defer func() {
28 if panicmsg, handled := handleCmpPanic(recover()); handled {
29 result = ResultFailure(panicmsg)
30 }
31 }()
32 diff := cmp.Diff(x, y, opts...)
33 if diff == "" {
34 return ResultSuccess
35 }
36 return multiLineDiffResult(diff)
37 }
38}
39
// handleCmpPanic inspects a value obtained from recover().
//
// It returns ("", false) when r is nil (nothing panicked). When r is a string
// beginning with "cannot handle unexported field", it returns that string and
// true so the caller can report a failure instead of crashing. Any other panic
// value is re-panicked unchanged.
func handleCmpPanic(r interface{}) (string, bool) {
	if r == nil {
		return "", false
	}
	if msg, isString := r.(string); isString && strings.HasPrefix(msg, "cannot handle unexported field") {
		return msg, true
	}
	panic(r)
}
54
55func toResult(success bool, msg string) Result {
56 if success {
57 return ResultSuccess
58 }
59 return ResultFailure(msg)
60}
61
62// RegexOrPattern may be either a *regexp.Regexp or a string that is a valid
63// regexp pattern.
64type RegexOrPattern interface{}
65
66// Regexp succeeds if value v matches regular expression re.
67//
68// Example:
69// assert.Assert(t, cmp.Regexp("^[0-9a-f]{32}$", str))
70// r := regexp.MustCompile("^[0-9a-f]{32}$")
71// assert.Assert(t, cmp.Regexp(r, str))
72func Regexp(re RegexOrPattern, v string) Comparison {
73 match := func(re *regexp.Regexp) Result {
74 return toResult(
75 re.MatchString(v),
76 fmt.Sprintf("value %q does not match regexp %q", v, re.String()))
77 }
78
79 return func() Result {
80 switch regex := re.(type) {
81 case *regexp.Regexp:
82 return match(regex)
83 case string:
84 re, err := regexp.Compile(regex)
85 if err != nil {
86 return ResultFailure(err.Error())
87 }
88 return match(re)
89 default:
90 return ResultFailure(fmt.Sprintf("invalid type %T for regex pattern", regex))
91 }
92 }
93}
94
95// Equal succeeds if x == y. See assert.Equal for full documentation.
96func Equal(x, y interface{}) Comparison {
97 return func() Result {
98 switch {
99 case x == y:
100 return ResultSuccess
101 case isMultiLineStringCompare(x, y):
102 diff := format.UnifiedDiff(format.DiffConfig{A: x.(string), B: y.(string)})
103 return multiLineDiffResult(diff)
104 }
105 return ResultFailureTemplate(`
106 {{- .Data.x}} (
107 {{- with callArg 0 }}{{ formatNode . }} {{end -}}
108 {{- printf "%T" .Data.x -}}
109 ) != {{ .Data.y}} (
110 {{- with callArg 1 }}{{ formatNode . }} {{end -}}
111 {{- printf "%T" .Data.y -}}
112 )`,
113 map[string]interface{}{"x": x, "y": y})
114 }
115}
116
// isMultiLineStringCompare reports whether x and y are both strings and at
// least one of them contains a newline. Equal uses it to decide when to report
// a unified diff instead of printing the values inline.
func isMultiLineStringCompare(x, y interface{}) bool {
	left, leftIsString := x.(string)
	right, rightIsString := y.(string)
	if !leftIsString || !rightIsString {
		return false
	}
	return strings.ContainsRune(left, '\n') || strings.ContainsRune(right, '\n')
}
128
129func multiLineDiffResult(diff string) Result {
130 return ResultFailureTemplate(`
131--- {{ with callArg 0 }}{{ formatNode . }}{{else}}←{{end}}
132+++ {{ with callArg 1 }}{{ formatNode . }}{{else}}→{{end}}
133{{ .Data.diff }}`,
134 map[string]interface{}{"diff": diff})
135}
136
137// Len succeeds if the sequence has the expected length.
138func Len(seq interface{}, expected int) Comparison {
139 return func() (result Result) {
140 defer func() {
141 if e := recover(); e != nil {
142 result = ResultFailure(fmt.Sprintf("type %T does not have a length", seq))
143 }
144 }()
145 value := reflect.ValueOf(seq)
146 length := value.Len()
147 if length == expected {
148 return ResultSuccess
149 }
150 msg := fmt.Sprintf("expected %s (length %d) to have length %d", seq, length, expected)
151 return ResultFailure(msg)
152 }
153}
154
155// Contains succeeds if item is in collection. Collection may be a string, map,
156// slice, or array.
157//
158// If collection is a string, item must also be a string, and is compared using
159// strings.Contains().
160// If collection is a Map, contains will succeed if item is a key in the map.
161// If collection is a slice or array, item is compared to each item in the
162// sequence using reflect.DeepEqual().
163func Contains(collection interface{}, item interface{}) Comparison {
164 return func() Result {
165 colValue := reflect.ValueOf(collection)
166 if !colValue.IsValid() {
167 return ResultFailure(fmt.Sprintf("nil does not contain items"))
168 }
169 msg := fmt.Sprintf("%v does not contain %v", collection, item)
170
171 itemValue := reflect.ValueOf(item)
172 switch colValue.Type().Kind() {
173 case reflect.String:
174 if itemValue.Type().Kind() != reflect.String {
175 return ResultFailure("string may only contain strings")
176 }
177 return toResult(
178 strings.Contains(colValue.String(), itemValue.String()),
179 fmt.Sprintf("string %q does not contain %q", collection, item))
180
181 case reflect.Map:
182 if itemValue.Type() != colValue.Type().Key() {
183 return ResultFailure(fmt.Sprintf(
184 "%v can not contain a %v key", colValue.Type(), itemValue.Type()))
185 }
186 return toResult(colValue.MapIndex(itemValue).IsValid(), msg)
187
188 case reflect.Slice, reflect.Array:
189 for i := 0; i < colValue.Len(); i++ {
190 if reflect.DeepEqual(colValue.Index(i).Interface(), item) {
191 return ResultSuccess
192 }
193 }
194 return ResultFailure(msg)
195 default:
196 return ResultFailure(fmt.Sprintf("type %T does not contain items", collection))
197 }
198 }
199}
200
201// Panics succeeds if f() panics.
202func Panics(f func()) Comparison {
203 return func() (result Result) {
204 defer func() {
205 if err := recover(); err != nil {
206 result = ResultSuccess
207 }
208 }()
209 f()
210 return ResultFailure("did not panic")
211 }
212}
213
214// Error succeeds if err is a non-nil error, and the error message equals the
215// expected message.
216func Error(err error, message string) Comparison {
217 return func() Result {
218 switch {
219 case err == nil:
220 return ResultFailure("expected an error, got nil")
221 case err.Error() != message:
222 return ResultFailure(fmt.Sprintf(
223 "expected error %q, got %s", message, formatErrorMessage(err)))
224 }
225 return ResultSuccess
226 }
227}
228
229// ErrorContains succeeds if err is a non-nil error, and the error message contains
230// the expected substring.
231func ErrorContains(err error, substring string) Comparison {
232 return func() Result {
233 switch {
234 case err == nil:
235 return ResultFailure("expected an error, got nil")
236 case !strings.Contains(err.Error(), substring):
237 return ResultFailure(fmt.Sprintf(
238 "expected error to contain %q, got %s", substring, formatErrorMessage(err)))
239 }
240 return ResultSuccess
241 }
242}
243
// formatErrorMessage renders err for a failure message.
//
// If err has a Cause() error method (as errors wrapped with
// github.com/pkg/errors do), the quoted message is followed by a newline and
// the %+v rendering of err. Otherwise only the quoted message is returned.
func formatErrorMessage(err error) string {
	type causer interface {
		Cause() error
	}
	if _, wrapped := err.(causer); !wrapped {
		// Not wrapped with github.com/pkg/errors: the message alone suffices.
		return fmt.Sprintf("%q", err)
	}
	return fmt.Sprintf("%q\n%+v", err, err)
}
253
254// Nil succeeds if obj is a nil interface, pointer, or function.
255//
256// Use NilError() for comparing errors. Use Len(obj, 0) for comparing slices,
257// maps, and channels.
258func Nil(obj interface{}) Comparison {
259 msgFunc := func(value reflect.Value) string {
260 return fmt.Sprintf("%v (type %s) is not nil", reflect.Indirect(value), value.Type())
261 }
262 return isNil(obj, msgFunc)
263}
264
265func isNil(obj interface{}, msgFunc func(reflect.Value) string) Comparison {
266 return func() Result {
267 if obj == nil {
268 return ResultSuccess
269 }
270 value := reflect.ValueOf(obj)
271 kind := value.Type().Kind()
272 if kind >= reflect.Chan && kind <= reflect.Slice {
273 if value.IsNil() {
274 return ResultSuccess
275 }
276 return ResultFailure(msgFunc(value))
277 }
278
279 return ResultFailure(fmt.Sprintf("%v (type %s) can not be nil", value, value.Type()))
280 }
281}
282
283// ErrorType succeeds if err is not nil and is of the expected type.
284//
285// Expected can be one of:
286// a func(error) bool which returns true if the error is the expected type,
287// an instance of (or a pointer to) a struct of the expected type,
288// a pointer to an interface the error is expected to implement,
289// a reflect.Type of the expected struct or interface.
290func ErrorType(err error, expected interface{}) Comparison {
291 return func() Result {
292 switch expectedType := expected.(type) {
293 case func(error) bool:
294 return cmpErrorTypeFunc(err, expectedType)
295 case reflect.Type:
296 if expectedType.Kind() == reflect.Interface {
297 return cmpErrorTypeImplementsType(err, expectedType)
298 }
299 return cmpErrorTypeEqualType(err, expectedType)
300 case nil:
301 return ResultFailure(fmt.Sprintf("invalid type for expected: nil"))
302 }
303
304 expectedType := reflect.TypeOf(expected)
305 switch {
306 case expectedType.Kind() == reflect.Struct, isPtrToStruct(expectedType):
307 return cmpErrorTypeEqualType(err, expectedType)
308 case isPtrToInterface(expectedType):
309 return cmpErrorTypeImplementsType(err, expectedType.Elem())
310 }
311 return ResultFailure(fmt.Sprintf("invalid type for expected: %T", expected))
312 }
313}
314
315func cmpErrorTypeFunc(err error, f func(error) bool) Result {
316 if f(err) {
317 return ResultSuccess
318 }
319 actual := "nil"
320 if err != nil {
321 actual = fmt.Sprintf("%s (%T)", err, err)
322 }
323 return ResultFailureTemplate(`error is {{ .Data.actual }}
324 {{- with callArg 1 }}, not {{ formatNode . }}{{end -}}`,
325 map[string]interface{}{"actual": actual})
326}
327
328func cmpErrorTypeEqualType(err error, expectedType reflect.Type) Result {
329 if err == nil {
330 return ResultFailure(fmt.Sprintf("error is nil, not %s", expectedType))
331 }
332 errValue := reflect.ValueOf(err)
333 if errValue.Type() == expectedType {
334 return ResultSuccess
335 }
336 return ResultFailure(fmt.Sprintf("error is %s (%T), not %s", err, err, expectedType))
337}
338
339func cmpErrorTypeImplementsType(err error, expectedType reflect.Type) Result {
340 if err == nil {
341 return ResultFailure(fmt.Sprintf("error is nil, not %s", expectedType))
342 }
343 errValue := reflect.ValueOf(err)
344 if errValue.Type().Implements(expectedType) {
345 return ResultSuccess
346 }
347 return ResultFailure(fmt.Sprintf("error is %s (%T), not %s", err, err, expectedType))
348}
349
// isPtrToInterface reports whether typ is a pointer whose element type is an
// interface, such as the type of (*error)(nil).
func isPtrToInterface(typ reflect.Type) bool {
	if typ.Kind() != reflect.Ptr {
		return false
	}
	return typ.Elem().Kind() == reflect.Interface
}
353
// isPtrToStruct reports whether typ is a pointer whose element type is a
// struct, such as the type of &struct{}{}.
func isPtrToStruct(typ reflect.Type) bool {
	if typ.Kind() != reflect.Ptr {
		return false
	}
	return typ.Elem().Kind() == reflect.Struct
}