blob: 95d8e59da69bf8e207db7af447a2f792db6f5625 [file] [log] [blame]
vinokumaf7605fc2023-06-02 18:08:01 +05301package assert
2
3import (
4 "bytes"
5 "fmt"
6 "reflect"
7 "time"
8)
9
// CompareType is the outcome of ordering two values: compareLess (-1),
// compareEqual (0) or compareGreater (1).
type CompareType int

const (
	compareLess    CompareType = -1 // the first operand orders before the second
	compareEqual   CompareType = 0  // the operands are equal
	compareGreater CompareType = 1  // the first operand orders after the second
)
17
// Cached reflect.Type values for the builtin operand types this file knows
// how to order. As package-level variables they are built once at package
// initialisation rather than on every comparison.
var (
	// Signed integers.
	intType   = reflect.TypeOf(int(0))
	int8Type  = reflect.TypeOf(int8(0))
	int16Type = reflect.TypeOf(int16(0))
	int32Type = reflect.TypeOf(int32(0))
	int64Type = reflect.TypeOf(int64(0))

	// Unsigned integers.
	uintType   = reflect.TypeOf(uint(0))
	uint8Type  = reflect.TypeOf(uint8(0))
	uint16Type = reflect.TypeOf(uint16(0))
	uint32Type = reflect.TypeOf(uint32(0))
	uint64Type = reflect.TypeOf(uint64(0))

	// Floating point.
	float32Type = reflect.TypeOf(float32(0))
	float64Type = reflect.TypeOf(float64(0))

	// Strings, time.Time and byte slices.
	stringType = reflect.TypeOf(string(""))
	timeType   = reflect.TypeOf(time.Time{})
	bytesType  = reflect.TypeOf([]byte(nil))
)
39
40func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
41 obj1Value := reflect.ValueOf(obj1)
42 obj2Value := reflect.ValueOf(obj2)
43
44 // throughout this switch we try and avoid calling .Convert() if possible,
45 // as this has a pretty big performance impact
46 switch kind {
47 case reflect.Int:
48 {
49 intobj1, ok := obj1.(int)
50 if !ok {
51 intobj1 = obj1Value.Convert(intType).Interface().(int)
52 }
53 intobj2, ok := obj2.(int)
54 if !ok {
55 intobj2 = obj2Value.Convert(intType).Interface().(int)
56 }
57 if intobj1 > intobj2 {
58 return compareGreater, true
59 }
60 if intobj1 == intobj2 {
61 return compareEqual, true
62 }
63 if intobj1 < intobj2 {
64 return compareLess, true
65 }
66 }
67 case reflect.Int8:
68 {
69 int8obj1, ok := obj1.(int8)
70 if !ok {
71 int8obj1 = obj1Value.Convert(int8Type).Interface().(int8)
72 }
73 int8obj2, ok := obj2.(int8)
74 if !ok {
75 int8obj2 = obj2Value.Convert(int8Type).Interface().(int8)
76 }
77 if int8obj1 > int8obj2 {
78 return compareGreater, true
79 }
80 if int8obj1 == int8obj2 {
81 return compareEqual, true
82 }
83 if int8obj1 < int8obj2 {
84 return compareLess, true
85 }
86 }
87 case reflect.Int16:
88 {
89 int16obj1, ok := obj1.(int16)
90 if !ok {
91 int16obj1 = obj1Value.Convert(int16Type).Interface().(int16)
92 }
93 int16obj2, ok := obj2.(int16)
94 if !ok {
95 int16obj2 = obj2Value.Convert(int16Type).Interface().(int16)
96 }
97 if int16obj1 > int16obj2 {
98 return compareGreater, true
99 }
100 if int16obj1 == int16obj2 {
101 return compareEqual, true
102 }
103 if int16obj1 < int16obj2 {
104 return compareLess, true
105 }
106 }
107 case reflect.Int32:
108 {
109 int32obj1, ok := obj1.(int32)
110 if !ok {
111 int32obj1 = obj1Value.Convert(int32Type).Interface().(int32)
112 }
113 int32obj2, ok := obj2.(int32)
114 if !ok {
115 int32obj2 = obj2Value.Convert(int32Type).Interface().(int32)
116 }
117 if int32obj1 > int32obj2 {
118 return compareGreater, true
119 }
120 if int32obj1 == int32obj2 {
121 return compareEqual, true
122 }
123 if int32obj1 < int32obj2 {
124 return compareLess, true
125 }
126 }
127 case reflect.Int64:
128 {
129 int64obj1, ok := obj1.(int64)
130 if !ok {
131 int64obj1 = obj1Value.Convert(int64Type).Interface().(int64)
132 }
133 int64obj2, ok := obj2.(int64)
134 if !ok {
135 int64obj2 = obj2Value.Convert(int64Type).Interface().(int64)
136 }
137 if int64obj1 > int64obj2 {
138 return compareGreater, true
139 }
140 if int64obj1 == int64obj2 {
141 return compareEqual, true
142 }
143 if int64obj1 < int64obj2 {
144 return compareLess, true
145 }
146 }
147 case reflect.Uint:
148 {
149 uintobj1, ok := obj1.(uint)
150 if !ok {
151 uintobj1 = obj1Value.Convert(uintType).Interface().(uint)
152 }
153 uintobj2, ok := obj2.(uint)
154 if !ok {
155 uintobj2 = obj2Value.Convert(uintType).Interface().(uint)
156 }
157 if uintobj1 > uintobj2 {
158 return compareGreater, true
159 }
160 if uintobj1 == uintobj2 {
161 return compareEqual, true
162 }
163 if uintobj1 < uintobj2 {
164 return compareLess, true
165 }
166 }
167 case reflect.Uint8:
168 {
169 uint8obj1, ok := obj1.(uint8)
170 if !ok {
171 uint8obj1 = obj1Value.Convert(uint8Type).Interface().(uint8)
172 }
173 uint8obj2, ok := obj2.(uint8)
174 if !ok {
175 uint8obj2 = obj2Value.Convert(uint8Type).Interface().(uint8)
176 }
177 if uint8obj1 > uint8obj2 {
178 return compareGreater, true
179 }
180 if uint8obj1 == uint8obj2 {
181 return compareEqual, true
182 }
183 if uint8obj1 < uint8obj2 {
184 return compareLess, true
185 }
186 }
187 case reflect.Uint16:
188 {
189 uint16obj1, ok := obj1.(uint16)
190 if !ok {
191 uint16obj1 = obj1Value.Convert(uint16Type).Interface().(uint16)
192 }
193 uint16obj2, ok := obj2.(uint16)
194 if !ok {
195 uint16obj2 = obj2Value.Convert(uint16Type).Interface().(uint16)
196 }
197 if uint16obj1 > uint16obj2 {
198 return compareGreater, true
199 }
200 if uint16obj1 == uint16obj2 {
201 return compareEqual, true
202 }
203 if uint16obj1 < uint16obj2 {
204 return compareLess, true
205 }
206 }
207 case reflect.Uint32:
208 {
209 uint32obj1, ok := obj1.(uint32)
210 if !ok {
211 uint32obj1 = obj1Value.Convert(uint32Type).Interface().(uint32)
212 }
213 uint32obj2, ok := obj2.(uint32)
214 if !ok {
215 uint32obj2 = obj2Value.Convert(uint32Type).Interface().(uint32)
216 }
217 if uint32obj1 > uint32obj2 {
218 return compareGreater, true
219 }
220 if uint32obj1 == uint32obj2 {
221 return compareEqual, true
222 }
223 if uint32obj1 < uint32obj2 {
224 return compareLess, true
225 }
226 }
227 case reflect.Uint64:
228 {
229 uint64obj1, ok := obj1.(uint64)
230 if !ok {
231 uint64obj1 = obj1Value.Convert(uint64Type).Interface().(uint64)
232 }
233 uint64obj2, ok := obj2.(uint64)
234 if !ok {
235 uint64obj2 = obj2Value.Convert(uint64Type).Interface().(uint64)
236 }
237 if uint64obj1 > uint64obj2 {
238 return compareGreater, true
239 }
240 if uint64obj1 == uint64obj2 {
241 return compareEqual, true
242 }
243 if uint64obj1 < uint64obj2 {
244 return compareLess, true
245 }
246 }
247 case reflect.Float32:
248 {
249 float32obj1, ok := obj1.(float32)
250 if !ok {
251 float32obj1 = obj1Value.Convert(float32Type).Interface().(float32)
252 }
253 float32obj2, ok := obj2.(float32)
254 if !ok {
255 float32obj2 = obj2Value.Convert(float32Type).Interface().(float32)
256 }
257 if float32obj1 > float32obj2 {
258 return compareGreater, true
259 }
260 if float32obj1 == float32obj2 {
261 return compareEqual, true
262 }
263 if float32obj1 < float32obj2 {
264 return compareLess, true
265 }
266 }
267 case reflect.Float64:
268 {
269 float64obj1, ok := obj1.(float64)
270 if !ok {
271 float64obj1 = obj1Value.Convert(float64Type).Interface().(float64)
272 }
273 float64obj2, ok := obj2.(float64)
274 if !ok {
275 float64obj2 = obj2Value.Convert(float64Type).Interface().(float64)
276 }
277 if float64obj1 > float64obj2 {
278 return compareGreater, true
279 }
280 if float64obj1 == float64obj2 {
281 return compareEqual, true
282 }
283 if float64obj1 < float64obj2 {
284 return compareLess, true
285 }
286 }
287 case reflect.String:
288 {
289 stringobj1, ok := obj1.(string)
290 if !ok {
291 stringobj1 = obj1Value.Convert(stringType).Interface().(string)
292 }
293 stringobj2, ok := obj2.(string)
294 if !ok {
295 stringobj2 = obj2Value.Convert(stringType).Interface().(string)
296 }
297 if stringobj1 > stringobj2 {
298 return compareGreater, true
299 }
300 if stringobj1 == stringobj2 {
301 return compareEqual, true
302 }
303 if stringobj1 < stringobj2 {
304 return compareLess, true
305 }
306 }
307 // Check for known struct types we can check for compare results.
308 case reflect.Struct:
309 {
310 // All structs enter here. We're not interested in most types.
311 if !canConvert(obj1Value, timeType) {
312 break
313 }
314
315 // time.Time can compared!
316 timeObj1, ok := obj1.(time.Time)
317 if !ok {
318 timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time)
319 }
320
321 timeObj2, ok := obj2.(time.Time)
322 if !ok {
323 timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time)
324 }
325
326 return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64)
327 }
328 case reflect.Slice:
329 {
330 // We only care about the []byte type.
331 if !canConvert(obj1Value, bytesType) {
332 break
333 }
334
335 // []byte can be compared!
336 bytesObj1, ok := obj1.([]byte)
337 if !ok {
338 bytesObj1 = obj1Value.Convert(bytesType).Interface().([]byte)
339
340 }
341 bytesObj2, ok := obj2.([]byte)
342 if !ok {
343 bytesObj2 = obj2Value.Convert(bytesType).Interface().([]byte)
344 }
345
346 return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true
347 }
348 }
349
350 return compareEqual, false
351}
352
353// Greater asserts that the first element is greater than the second
354//
355// assert.Greater(t, 2, 1)
356// assert.Greater(t, float64(2), float64(1))
357// assert.Greater(t, "b", "a")
358func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
359 if h, ok := t.(tHelper); ok {
360 h.Helper()
361 }
362 return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
363}
364
365// GreaterOrEqual asserts that the first element is greater than or equal to the second
366//
367// assert.GreaterOrEqual(t, 2, 1)
368// assert.GreaterOrEqual(t, 2, 2)
369// assert.GreaterOrEqual(t, "b", "a")
370// assert.GreaterOrEqual(t, "b", "b")
371func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
372 if h, ok := t.(tHelper); ok {
373 h.Helper()
374 }
375 return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
376}
377
378// Less asserts that the first element is less than the second
379//
380// assert.Less(t, 1, 2)
381// assert.Less(t, float64(1), float64(2))
382// assert.Less(t, "a", "b")
383func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
384 if h, ok := t.(tHelper); ok {
385 h.Helper()
386 }
387 return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
388}
389
390// LessOrEqual asserts that the first element is less than or equal to the second
391//
392// assert.LessOrEqual(t, 1, 2)
393// assert.LessOrEqual(t, 2, 2)
394// assert.LessOrEqual(t, "a", "b")
395// assert.LessOrEqual(t, "b", "b")
396func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
397 if h, ok := t.(tHelper); ok {
398 h.Helper()
399 }
400 return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
401}
402
403// Positive asserts that the specified element is positive
404//
405// assert.Positive(t, 1)
406// assert.Positive(t, 1.23)
407func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
408 if h, ok := t.(tHelper); ok {
409 h.Helper()
410 }
411 zero := reflect.Zero(reflect.TypeOf(e))
412 return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs...)
413}
414
415// Negative asserts that the specified element is negative
416//
417// assert.Negative(t, -1)
418// assert.Negative(t, -1.23)
419func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
420 if h, ok := t.(tHelper); ok {
421 h.Helper()
422 }
423 zero := reflect.Zero(reflect.TypeOf(e))
424 return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs...)
425}
426
427func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool {
428 if h, ok := t.(tHelper); ok {
429 h.Helper()
430 }
431
432 e1Kind := reflect.ValueOf(e1).Kind()
433 e2Kind := reflect.ValueOf(e2).Kind()
434 if e1Kind != e2Kind {
435 return Fail(t, "Elements should be the same type", msgAndArgs...)
436 }
437
438 compareResult, isComparable := compare(e1, e2, e1Kind)
439 if !isComparable {
440 return Fail(t, fmt.Sprintf("Can not compare type \"%s\"", reflect.TypeOf(e1)), msgAndArgs...)
441 }
442
443 if !containsValue(allowedComparesResults, compareResult) {
444 return Fail(t, fmt.Sprintf(failMessage, e1, e2), msgAndArgs...)
445 }
446
447 return true
448}
449
450func containsValue(values []CompareType, value CompareType) bool {
451 for _, v := range values {
452 if v == value {
453 return true
454 }
455 }
456
457 return false
458}