blob: 949d93961990961c57b3b09dae00fddd411c6b5f [file] [log] [blame]
Matteo Scandoloa6a3aee2019-11-26 13:30:14 -07001package assert
2
3import (
4 "fmt"
5 "go/ast"
6
7 "gotest.tools/assert/cmp"
8 "gotest.tools/internal/format"
9 "gotest.tools/internal/source"
10)
11
12func runComparison(
13 t TestingT,
14 argSelector argSelector,
15 f cmp.Comparison,
16 msgAndArgs ...interface{},
17) bool {
18 if ht, ok := t.(helperT); ok {
19 ht.Helper()
20 }
21 result := f()
22 if result.Success() {
23 return true
24 }
25
26 var message string
27 switch typed := result.(type) {
28 case resultWithComparisonArgs:
29 const stackIndex = 3 // Assert/Check, assert, runComparison
30 args, err := source.CallExprArgs(stackIndex)
31 if err != nil {
32 t.Log(err.Error())
33 }
34 message = typed.FailureMessage(filterPrintableExpr(argSelector(args)))
35 case resultBasic:
36 message = typed.FailureMessage()
37 default:
38 message = fmt.Sprintf("comparison returned invalid Result type: %T", result)
39 }
40
41 t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...))
42 return false
43}
44
type (
	// resultWithComparisonArgs is implemented by a comparison Result whose
	// failure message is built from the source expressions of the compared
	// arguments. runComparison passes those expressions through argSelector
	// and filterPrintableExpr first, so individual entries may be nil.
	resultWithComparisonArgs interface {
		FailureMessage(args []ast.Expr) string
	}

	// resultBasic is implemented by a comparison Result whose failure message
	// needs no information about the source expressions of the arguments.
	resultBasic interface {
		FailureMessage() string
	}
)
52
// filterPrintableExpr returns a slice with the same length as args. Each
// position holds either an expression that is short and readable when printed
// and carries context relevant to an assertion, or nil. The positions are kept
// (rather than compacting the slice), presumably so that a FailureMessage
// implementation can pair entry i with argument i.
//
// For each arg:
//   - Expressions accepted by isShortPrintableExpr are kept as-is. These are
//     identifiers, selectors, index, slice, binary and unary expressions.
//     Variable and field names print nicely and add context to the value.
//   - A dereference *X is replaced by its operand X. NOTE(review): X is not
//     checked itself, so for *f() the call f() is kept. Confirm that is
//     intended.
//   - Everything else becomes nil. That includes BasicLit and CompositeLit,
//     whose source is equivalent to their value, which is already available,
//     as well as calls, parenthesized expressions and type assertions.
//
// Other expression kinds could be added if they turn out to be relevant.
func filterPrintableExpr(args []ast.Expr) []ast.Expr {
	result := make([]ast.Expr, len(args))
	for i, arg := range args {
		if isShortPrintableExpr(arg) {
			result[i] = arg
			continue
		}
		if starExpr, ok := arg.(*ast.StarExpr); ok {
			result[i] = starExpr.X
		}
		// Any other kind leaves result[i] as nil.
	}
	return result
}

// isShortPrintableExpr reports whether expr is an identifier, selector,
// index, slice, binary or unary expression. filterPrintableExpr keeps only
// these kinds of expression unchanged. A nil expr reports false.
func isShortPrintableExpr(expr ast.Expr) bool {
	switch expr.(type) {
	case *ast.Ident, *ast.SelectorExpr, *ast.IndexExpr, *ast.SliceExpr,
		*ast.BinaryExpr, *ast.UnaryExpr:
		return true
	default:
		// e.g. BasicLit, CompositeLit, CallExpr, ParenExpr, TypeAssertExpr,
		// KeyValueExpr, StarExpr.
		return false
	}
}
88
// argSelector narrows the source-level arguments of an assertion call down to
// the expressions that describe the compared values (see argsAfterT and
// argsFromComparisonCall).
type argSelector func(args []ast.Expr) []ast.Expr
90
// argsAfterT is an argSelector that skips the first argument of the assertion
// call and returns the rest. NOTE(review): the first argument is presumably
// the TestingT value, going by the name. It returns nil when args is empty.
// For a single argument it returns an empty slice.
func argsAfterT(args []ast.Expr) []ast.Expr {
	if len(args) > 0 {
		return args[1:]
	}
	return nil
}
97
// argsFromComparisonCall is an argSelector for assertions whose second
// argument is itself a function call. NOTE(review): that call is presumably a
// cmp comparison such as cmp.Equal(x, y); confirm at the call sites. The
// function returns that call's arguments.
//
// It returns nil when args has fewer than two elements, or when the second
// element is not a call expression.
func argsFromComparisonCall(args []ast.Expr) []ast.Expr {
	// args[1] is read below, so at least two arguments are required. With a
	// single argument there is no comparison call to inspect.
	if len(args) < 2 {
		return nil
	}
	if callExpr, ok := args[1].(*ast.CallExpr); ok {
		return callExpr.Args
	}
	return nil
}
106}