blob: 8a5d0e8d35b413436a1f860fc7a83d0501827c64 [file] [log] [blame]
Matteo Scandoloabf872d2020-12-14 08:22:06 -10001package source // import "gotest.tools/internal/source"
2
3import (
4 "bytes"
5 "fmt"
6 "go/ast"
7 "go/format"
8 "go/parser"
9 "go/token"
10 "os"
11 "runtime"
12 "strconv"
13 "strings"
14
15 "github.com/pkg/errors"
16)
17
// baseStackIndex is added to every stack index before it is handed to
// runtime.Caller by CallExprArgs. runtime.Caller(0) identifies the function
// that calls runtime.Caller, so an offset of 1 makes index 0 refer to the
// caller of CallExprArgs.
const baseStackIndex = 1
19
20// FormattedCallExprArg returns the argument from an ast.CallExpr at the
21// index in the call stack. The argument is formatted using FormatNode.
22func FormattedCallExprArg(stackIndex int, argPos int) (string, error) {
23 args, err := CallExprArgs(stackIndex + 1)
24 if err != nil {
25 return "", err
26 }
27 if argPos >= len(args) {
28 return "", errors.New("failed to find expression")
29 }
30 return FormatNode(args[argPos])
31}
32
33// CallExprArgs returns the ast.Expr slice for the args of an ast.CallExpr at
34// the index in the call stack.
35func CallExprArgs(stackIndex int) ([]ast.Expr, error) {
36 _, filename, lineNum, ok := runtime.Caller(baseStackIndex + stackIndex)
37 if !ok {
38 return nil, errors.New("failed to get call stack")
39 }
40 debug("call stack position: %s:%d", filename, lineNum)
41
42 node, err := getNodeAtLine(filename, lineNum)
43 if err != nil {
44 return nil, err
45 }
46 debug("found node: %s", debugFormatNode{node})
47
48 return getCallExprArgs(node)
49}
50
51func getNodeAtLine(filename string, lineNum int) (ast.Node, error) {
52 fileset := token.NewFileSet()
53 astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors)
54 if err != nil {
55 return nil, errors.Wrapf(err, "failed to parse source file: %s", filename)
56 }
57
58 if node := scanToLine(fileset, astFile, lineNum); node != nil {
59 return node, nil
60 }
61 if node := scanToDeferLine(fileset, astFile, lineNum); node != nil {
62 node, err := guessDefer(node)
63 if err != nil || node != nil {
64 return node, err
65 }
66 }
67 return nil, errors.Errorf(
68 "failed to find an expression on line %d in %s", lineNum, filename)
69}
70
71func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node {
72 var matchedNode ast.Node
73 ast.Inspect(node, func(node ast.Node) bool {
74 switch {
75 case node == nil || matchedNode != nil:
76 return false
77 case nodePosition(fileset, node).Line == lineNum:
78 matchedNode = node
79 return false
80 }
81 return true
82 })
83 return matchedNode
84}
85
86// In golang 1.9 the line number changed from being the line where the statement
87// ended to the line where the statement began.
88func nodePosition(fileset *token.FileSet, node ast.Node) token.Position {
89 if goVersionBefore19 {
90 return fileset.Position(node.End())
91 }
92 return fileset.Position(node.Pos())
93}
94
// goVersionBefore19 reports whether the running Go runtime is a 1.x version
// older than 1.9. It is evaluated once, at package initialisation, from
// runtime.Version.
var goVersionBefore19 = goVersionStringBefore19(runtime.Version())

// goVersionStringBefore19 reports whether version, in the format produced by
// runtime.Version (for example "go1.8.3"), names a Go 1.x release or
// pre-release older than 1.9. Any string that cannot be interpreted that way
// yields false.
func goVersionStringBefore19(version string) bool {
	// not a release version
	if !strings.HasPrefix(version, "go") {
		return false
	}
	parts := strings.Split(strings.TrimPrefix(version, "go"), ".")
	if len(parts) < 2 {
		return false
	}

	// Pre-releases glue a suffix onto the minor number ("go1.8rc1",
	// "go1.8beta2"). Keep only the leading digits so that these versions are
	// still recognised instead of failing to parse.
	minorDigits := parts[1]
	notDigit := func(r rune) bool { return r < '0' || r > '9' }
	if end := strings.IndexFunc(minorDigits, notDigit); end >= 0 {
		minorDigits = minorDigits[:end]
	}

	minor, err := strconv.ParseInt(minorDigits, 10, 32)
	return err == nil && parts[0] == "1" && minor < 9
}
109
110func getCallExprArgs(node ast.Node) ([]ast.Expr, error) {
111 visitor := &callExprVisitor{}
112 ast.Walk(visitor, node)
113 if visitor.expr == nil {
114 return nil, errors.New("failed to find call expression")
115 }
116 debug("callExpr: %s", debugFormatNode{visitor.expr})
117 return visitor.expr.Args, nil
118}
119
120type callExprVisitor struct {
121 expr *ast.CallExpr
122}
123
124func (v *callExprVisitor) Visit(node ast.Node) ast.Visitor {
125 if v.expr != nil || node == nil {
126 return nil
127 }
128 debug("visit: %s", debugFormatNode{node})
129
130 switch typed := node.(type) {
131 case *ast.CallExpr:
132 v.expr = typed
133 return nil
134 case *ast.DeferStmt:
135 ast.Walk(v, typed.Call.Fun)
136 return nil
137 }
138 return v
139}
140
// FormatNode renders node as Go source using go/format.Node with a fresh,
// empty file set. Whatever was written before a failure is returned together
// with the error.
func FormatNode(node ast.Node) (string, error) {
	var out bytes.Buffer
	fset := token.NewFileSet()
	err := format.Node(&out, fset, node)
	return out.String(), err
}
147
// debugEnabled is true when the GOTESTTOOLS_DEBUG environment variable was set
// to a non-empty value at package initialisation.
var debugEnabled = os.Getenv("GOTESTTOOLS_DEBUG") != ""

// debug writes a printf-style message to os.Stderr, prefixed with "DEBUG: "
// and terminated by a newline. It does nothing unless debugEnabled is true.
func debug(format string, args ...interface{}) {
	if !debugEnabled {
		return
	}
	fmt.Fprintf(os.Stderr, "DEBUG: "+format+"\n", args...)
}
155
156type debugFormatNode struct {
157 ast.Node
158}
159
160func (n debugFormatNode) String() string {
161 out, err := FormatNode(n.Node)
162 if err != nil {
163 return fmt.Sprintf("failed to format %s: %s", n.Node, err)
164 }
165 return fmt.Sprintf("(%T) %s", n.Node, out)
166}