blob: 3bb22a9718eb84a29946d20c7acaf3fad5d7952f [file] [log] [blame]
Elia Battistonc8d0d462022-02-22 16:30:51 +01001package assert
2
3import (
4 "fmt"
5 "reflect"
Elia Battiston4750d3c2022-07-14 13:24:56 +00006 "time"
Elia Battistonc8d0d462022-02-22 16:30:51 +01007)
8
// CompareType is the outcome of ordering two values: one of compareLess,
// compareEqual or compareGreater.
type CompareType int

// Results produced by compare. The numeric value has the same sign as the
// ordering of the first operand relative to the second.
const (
	compareLess    CompareType = -1
	compareEqual   CompareType = 0
	compareGreater CompareType = 1
)
16
// Cached reflect.Type values for the builtin types that compare can order.
// compare converts values of named types with a matching kind to these
// before comparing them.
var (
	// Signed integers.
	intType   = reflect.TypeOf(int(0))
	int8Type  = reflect.TypeOf(int8(0))
	int16Type = reflect.TypeOf(int16(0))
	int32Type = reflect.TypeOf(int32(0))
	int64Type = reflect.TypeOf(int64(0))

	// Unsigned integers.
	uintType   = reflect.TypeOf(uint(0))
	uint8Type  = reflect.TypeOf(uint8(0))
	uint16Type = reflect.TypeOf(uint16(0))
	uint32Type = reflect.TypeOf(uint32(0))
	uint64Type = reflect.TypeOf(uint64(0))

	// Floating point.
	float32Type = reflect.TypeOf(float32(0))
	float64Type = reflect.TypeOf(float64(0))

	// Strings.
	stringType = reflect.TypeOf("")

	// Time instants (the only struct type compare knows how to order).
	timeType = reflect.TypeOf(time.Time{})
)
37
38func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
39 obj1Value := reflect.ValueOf(obj1)
40 obj2Value := reflect.ValueOf(obj2)
41
42 // throughout this switch we try and avoid calling .Convert() if possible,
43 // as this has a pretty big performance impact
44 switch kind {
45 case reflect.Int:
46 {
47 intobj1, ok := obj1.(int)
48 if !ok {
49 intobj1 = obj1Value.Convert(intType).Interface().(int)
50 }
51 intobj2, ok := obj2.(int)
52 if !ok {
53 intobj2 = obj2Value.Convert(intType).Interface().(int)
54 }
55 if intobj1 > intobj2 {
56 return compareGreater, true
57 }
58 if intobj1 == intobj2 {
59 return compareEqual, true
60 }
61 if intobj1 < intobj2 {
62 return compareLess, true
63 }
64 }
65 case reflect.Int8:
66 {
67 int8obj1, ok := obj1.(int8)
68 if !ok {
69 int8obj1 = obj1Value.Convert(int8Type).Interface().(int8)
70 }
71 int8obj2, ok := obj2.(int8)
72 if !ok {
73 int8obj2 = obj2Value.Convert(int8Type).Interface().(int8)
74 }
75 if int8obj1 > int8obj2 {
76 return compareGreater, true
77 }
78 if int8obj1 == int8obj2 {
79 return compareEqual, true
80 }
81 if int8obj1 < int8obj2 {
82 return compareLess, true
83 }
84 }
85 case reflect.Int16:
86 {
87 int16obj1, ok := obj1.(int16)
88 if !ok {
89 int16obj1 = obj1Value.Convert(int16Type).Interface().(int16)
90 }
91 int16obj2, ok := obj2.(int16)
92 if !ok {
93 int16obj2 = obj2Value.Convert(int16Type).Interface().(int16)
94 }
95 if int16obj1 > int16obj2 {
96 return compareGreater, true
97 }
98 if int16obj1 == int16obj2 {
99 return compareEqual, true
100 }
101 if int16obj1 < int16obj2 {
102 return compareLess, true
103 }
104 }
105 case reflect.Int32:
106 {
107 int32obj1, ok := obj1.(int32)
108 if !ok {
109 int32obj1 = obj1Value.Convert(int32Type).Interface().(int32)
110 }
111 int32obj2, ok := obj2.(int32)
112 if !ok {
113 int32obj2 = obj2Value.Convert(int32Type).Interface().(int32)
114 }
115 if int32obj1 > int32obj2 {
116 return compareGreater, true
117 }
118 if int32obj1 == int32obj2 {
119 return compareEqual, true
120 }
121 if int32obj1 < int32obj2 {
122 return compareLess, true
123 }
124 }
125 case reflect.Int64:
126 {
127 int64obj1, ok := obj1.(int64)
128 if !ok {
129 int64obj1 = obj1Value.Convert(int64Type).Interface().(int64)
130 }
131 int64obj2, ok := obj2.(int64)
132 if !ok {
133 int64obj2 = obj2Value.Convert(int64Type).Interface().(int64)
134 }
135 if int64obj1 > int64obj2 {
136 return compareGreater, true
137 }
138 if int64obj1 == int64obj2 {
139 return compareEqual, true
140 }
141 if int64obj1 < int64obj2 {
142 return compareLess, true
143 }
144 }
145 case reflect.Uint:
146 {
147 uintobj1, ok := obj1.(uint)
148 if !ok {
149 uintobj1 = obj1Value.Convert(uintType).Interface().(uint)
150 }
151 uintobj2, ok := obj2.(uint)
152 if !ok {
153 uintobj2 = obj2Value.Convert(uintType).Interface().(uint)
154 }
155 if uintobj1 > uintobj2 {
156 return compareGreater, true
157 }
158 if uintobj1 == uintobj2 {
159 return compareEqual, true
160 }
161 if uintobj1 < uintobj2 {
162 return compareLess, true
163 }
164 }
165 case reflect.Uint8:
166 {
167 uint8obj1, ok := obj1.(uint8)
168 if !ok {
169 uint8obj1 = obj1Value.Convert(uint8Type).Interface().(uint8)
170 }
171 uint8obj2, ok := obj2.(uint8)
172 if !ok {
173 uint8obj2 = obj2Value.Convert(uint8Type).Interface().(uint8)
174 }
175 if uint8obj1 > uint8obj2 {
176 return compareGreater, true
177 }
178 if uint8obj1 == uint8obj2 {
179 return compareEqual, true
180 }
181 if uint8obj1 < uint8obj2 {
182 return compareLess, true
183 }
184 }
185 case reflect.Uint16:
186 {
187 uint16obj1, ok := obj1.(uint16)
188 if !ok {
189 uint16obj1 = obj1Value.Convert(uint16Type).Interface().(uint16)
190 }
191 uint16obj2, ok := obj2.(uint16)
192 if !ok {
193 uint16obj2 = obj2Value.Convert(uint16Type).Interface().(uint16)
194 }
195 if uint16obj1 > uint16obj2 {
196 return compareGreater, true
197 }
198 if uint16obj1 == uint16obj2 {
199 return compareEqual, true
200 }
201 if uint16obj1 < uint16obj2 {
202 return compareLess, true
203 }
204 }
205 case reflect.Uint32:
206 {
207 uint32obj1, ok := obj1.(uint32)
208 if !ok {
209 uint32obj1 = obj1Value.Convert(uint32Type).Interface().(uint32)
210 }
211 uint32obj2, ok := obj2.(uint32)
212 if !ok {
213 uint32obj2 = obj2Value.Convert(uint32Type).Interface().(uint32)
214 }
215 if uint32obj1 > uint32obj2 {
216 return compareGreater, true
217 }
218 if uint32obj1 == uint32obj2 {
219 return compareEqual, true
220 }
221 if uint32obj1 < uint32obj2 {
222 return compareLess, true
223 }
224 }
225 case reflect.Uint64:
226 {
227 uint64obj1, ok := obj1.(uint64)
228 if !ok {
229 uint64obj1 = obj1Value.Convert(uint64Type).Interface().(uint64)
230 }
231 uint64obj2, ok := obj2.(uint64)
232 if !ok {
233 uint64obj2 = obj2Value.Convert(uint64Type).Interface().(uint64)
234 }
235 if uint64obj1 > uint64obj2 {
236 return compareGreater, true
237 }
238 if uint64obj1 == uint64obj2 {
239 return compareEqual, true
240 }
241 if uint64obj1 < uint64obj2 {
242 return compareLess, true
243 }
244 }
245 case reflect.Float32:
246 {
247 float32obj1, ok := obj1.(float32)
248 if !ok {
249 float32obj1 = obj1Value.Convert(float32Type).Interface().(float32)
250 }
251 float32obj2, ok := obj2.(float32)
252 if !ok {
253 float32obj2 = obj2Value.Convert(float32Type).Interface().(float32)
254 }
255 if float32obj1 > float32obj2 {
256 return compareGreater, true
257 }
258 if float32obj1 == float32obj2 {
259 return compareEqual, true
260 }
261 if float32obj1 < float32obj2 {
262 return compareLess, true
263 }
264 }
265 case reflect.Float64:
266 {
267 float64obj1, ok := obj1.(float64)
268 if !ok {
269 float64obj1 = obj1Value.Convert(float64Type).Interface().(float64)
270 }
271 float64obj2, ok := obj2.(float64)
272 if !ok {
273 float64obj2 = obj2Value.Convert(float64Type).Interface().(float64)
274 }
275 if float64obj1 > float64obj2 {
276 return compareGreater, true
277 }
278 if float64obj1 == float64obj2 {
279 return compareEqual, true
280 }
281 if float64obj1 < float64obj2 {
282 return compareLess, true
283 }
284 }
285 case reflect.String:
286 {
287 stringobj1, ok := obj1.(string)
288 if !ok {
289 stringobj1 = obj1Value.Convert(stringType).Interface().(string)
290 }
291 stringobj2, ok := obj2.(string)
292 if !ok {
293 stringobj2 = obj2Value.Convert(stringType).Interface().(string)
294 }
295 if stringobj1 > stringobj2 {
296 return compareGreater, true
297 }
298 if stringobj1 == stringobj2 {
299 return compareEqual, true
300 }
301 if stringobj1 < stringobj2 {
302 return compareLess, true
303 }
304 }
Elia Battiston4750d3c2022-07-14 13:24:56 +0000305 // Check for known struct types we can check for compare results.
306 case reflect.Struct:
307 {
308 // All structs enter here. We're not interested in most types.
309 if !canConvert(obj1Value, timeType) {
310 break
311 }
312
313 // time.Time can compared!
314 timeObj1, ok := obj1.(time.Time)
315 if !ok {
316 timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time)
317 }
318
319 timeObj2, ok := obj2.(time.Time)
320 if !ok {
321 timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time)
322 }
323
324 return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64)
325 }
Elia Battistonc8d0d462022-02-22 16:30:51 +0100326 }
327
328 return compareEqual, false
329}
330
331// Greater asserts that the first element is greater than the second
332//
333// assert.Greater(t, 2, 1)
334// assert.Greater(t, float64(2), float64(1))
335// assert.Greater(t, "b", "a")
336func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
Elia Battiston4750d3c2022-07-14 13:24:56 +0000337 if h, ok := t.(tHelper); ok {
338 h.Helper()
339 }
340 return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
Elia Battistonc8d0d462022-02-22 16:30:51 +0100341}
342
343// GreaterOrEqual asserts that the first element is greater than or equal to the second
344//
345// assert.GreaterOrEqual(t, 2, 1)
346// assert.GreaterOrEqual(t, 2, 2)
347// assert.GreaterOrEqual(t, "b", "a")
348// assert.GreaterOrEqual(t, "b", "b")
349func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
Elia Battiston4750d3c2022-07-14 13:24:56 +0000350 if h, ok := t.(tHelper); ok {
351 h.Helper()
352 }
353 return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
Elia Battistonc8d0d462022-02-22 16:30:51 +0100354}
355
356// Less asserts that the first element is less than the second
357//
358// assert.Less(t, 1, 2)
359// assert.Less(t, float64(1), float64(2))
360// assert.Less(t, "a", "b")
361func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
Elia Battiston4750d3c2022-07-14 13:24:56 +0000362 if h, ok := t.(tHelper); ok {
363 h.Helper()
364 }
365 return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
Elia Battistonc8d0d462022-02-22 16:30:51 +0100366}
367
368// LessOrEqual asserts that the first element is less than or equal to the second
369//
370// assert.LessOrEqual(t, 1, 2)
371// assert.LessOrEqual(t, 2, 2)
372// assert.LessOrEqual(t, "a", "b")
373// assert.LessOrEqual(t, "b", "b")
374func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
Elia Battiston4750d3c2022-07-14 13:24:56 +0000375 if h, ok := t.(tHelper); ok {
376 h.Helper()
377 }
378 return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
Elia Battistonc8d0d462022-02-22 16:30:51 +0100379}
380
381// Positive asserts that the specified element is positive
382//
383// assert.Positive(t, 1)
384// assert.Positive(t, 1.23)
385func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
Elia Battiston4750d3c2022-07-14 13:24:56 +0000386 if h, ok := t.(tHelper); ok {
387 h.Helper()
388 }
Elia Battistonc8d0d462022-02-22 16:30:51 +0100389 zero := reflect.Zero(reflect.TypeOf(e))
Elia Battiston4750d3c2022-07-14 13:24:56 +0000390 return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs...)
Elia Battistonc8d0d462022-02-22 16:30:51 +0100391}
392
393// Negative asserts that the specified element is negative
394//
395// assert.Negative(t, -1)
396// assert.Negative(t, -1.23)
397func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
Elia Battiston4750d3c2022-07-14 13:24:56 +0000398 if h, ok := t.(tHelper); ok {
399 h.Helper()
400 }
Elia Battistonc8d0d462022-02-22 16:30:51 +0100401 zero := reflect.Zero(reflect.TypeOf(e))
Elia Battiston4750d3c2022-07-14 13:24:56 +0000402 return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs...)
Elia Battistonc8d0d462022-02-22 16:30:51 +0100403}
404
405func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool {
406 if h, ok := t.(tHelper); ok {
407 h.Helper()
408 }
409
410 e1Kind := reflect.ValueOf(e1).Kind()
411 e2Kind := reflect.ValueOf(e2).Kind()
412 if e1Kind != e2Kind {
413 return Fail(t, "Elements should be the same type", msgAndArgs...)
414 }
415
416 compareResult, isComparable := compare(e1, e2, e1Kind)
417 if !isComparable {
418 return Fail(t, fmt.Sprintf("Can not compare type \"%s\"", reflect.TypeOf(e1)), msgAndArgs...)
419 }
420
421 if !containsValue(allowedComparesResults, compareResult) {
422 return Fail(t, fmt.Sprintf(failMessage, e1, e2), msgAndArgs...)
423 }
424
425 return true
426}
427
428func containsValue(values []CompareType, value CompareType) bool {
429 for _, v := range values {
430 if v == value {
431 return true
432 }
433 }
434
435 return false
436}