blob: c938c97bec147c1e6532117a8254cc51ab813944 [file] [log] [blame]
package format
2
3import (
4 "bytes"
5 "fmt"
6 "strings"
7 "unicode"
8
9 "gotest.tools/internal/difflib"
10)
11
// contextLines is the context size handed to the matcher's
// GetGroupedOpCodes when grouping changes for a unified diff.
const contextLines = 2
15
// DiffConfig holds the inputs for a unified diff.
type DiffConfig struct {
	// A and B are the two texts being compared; A is rendered as the
	// "-" side and B as the "+" side.
	A, B string
	// From and To are the labels written on the "---" and "+++" header
	// lines. The header is omitted when both are empty.
	From, To string
}
23
24// UnifiedDiff is a modified version of difflib.WriteUnifiedDiff with better
25// support for showing the whitespace differences.
26func UnifiedDiff(conf DiffConfig) string {
27 a := strings.SplitAfter(conf.A, "\n")
28 b := strings.SplitAfter(conf.B, "\n")
29 groups := difflib.NewMatcher(a, b).GetGroupedOpCodes(contextLines)
30 if len(groups) == 0 {
31 return ""
32 }
33
34 buf := new(bytes.Buffer)
35 writeFormat := func(format string, args ...interface{}) {
36 buf.WriteString(fmt.Sprintf(format, args...))
37 }
38 writeLine := func(prefix string, s string) {
39 buf.WriteString(prefix + s)
40 }
41 if hasWhitespaceDiffLines(groups, a, b) {
42 writeLine = visibleWhitespaceLine(writeLine)
43 }
44 formatHeader(writeFormat, conf)
45 for _, group := range groups {
46 formatRangeLine(writeFormat, group)
47 for _, opCode := range group {
48 in, out := a[opCode.I1:opCode.I2], b[opCode.J1:opCode.J2]
49 switch opCode.Tag {
50 case 'e':
51 formatLines(writeLine, " ", in)
52 case 'r':
53 formatLines(writeLine, "-", in)
54 formatLines(writeLine, "+", out)
55 case 'd':
56 formatLines(writeLine, "-", in)
57 case 'i':
58 formatLines(writeLine, "+", out)
59 }
60 }
61 }
62 return buf.String()
63}
64
65// hasWhitespaceDiffLines returns true if any diff groups is only different
66// because of whitespace characters.
67func hasWhitespaceDiffLines(groups [][]difflib.OpCode, a, b []string) bool {
68 for _, group := range groups {
69 in, out := new(bytes.Buffer), new(bytes.Buffer)
70 for _, opCode := range group {
71 if opCode.Tag == 'e' {
72 continue
73 }
74 for _, line := range a[opCode.I1:opCode.I2] {
75 in.WriteString(line)
76 }
77 for _, line := range b[opCode.J1:opCode.J2] {
78 out.WriteString(line)
79 }
80 }
81 if removeWhitespace(in.String()) == removeWhitespace(out.String()) {
82 return true
83 }
84 }
85 return false
86}
87
// removeWhitespace returns s with every rune for which unicode.IsSpace
// reports true removed.
func removeWhitespace(s string) string {
	return strings.Map(func(r rune) rune {
		if unicode.IsSpace(r) {
			return -1 // a negative result tells strings.Map to drop the rune
		}
		return r
	}, s)
}
97
// visibleWhitespaceLine wraps the line writer ws so that whitespace in the
// line text is replaced with visible glyphs before being written. The
// prefix is passed through untouched, and newlines are left as they are.
func visibleWhitespaceLine(ws func(string, string)) func(string, string) {
	toVisible := func(r rune) rune {
		// Newlines terminate diff lines and must survive unchanged.
		if r == '\n' {
			return r
		}
		switch r {
		case ' ':
			return '·'
		case '\t':
			return '▷'
		case '\v':
			return '▽'
		case '\r':
			return '↵'
		case '\f':
			return '↓'
		}
		// Any other whitespace rune gets a generic placeholder.
		if unicode.IsSpace(r) {
			return '�'
		}
		return r
	}
	return func(prefix, s string) {
		visible := strings.Map(toVisible, s)
		ws(prefix, visible)
	}
}
123
124func formatHeader(wf func(string, ...interface{}), conf DiffConfig) {
125 if conf.From != "" || conf.To != "" {
126 wf("--- %s\n", conf.From)
127 wf("+++ %s\n", conf.To)
128 }
129}
130
131func formatRangeLine(wf func(string, ...interface{}), group []difflib.OpCode) {
132 first, last := group[0], group[len(group)-1]
133 range1 := formatRangeUnified(first.I1, last.I2)
134 range2 := formatRangeUnified(first.J1, last.J2)
135 wf("@@ -%s +%s @@\n", range1, range2)
136}
137
// formatRangeUnified converts the half-open, zero-based line range
// [start, stop) to the "ed" format used in unified diff range lines.
//
// Per the diff spec at http://www.unix.org/single_unix_specification/
func formatRangeUnified(start, stop int) string {
	count := stop - start
	first := start + 1 // line numbering starts at one
	switch count {
	case 1:
		// A single line is written as just its line number.
		return fmt.Sprintf("%d", first)
	case 0:
		// Empty ranges begin at the line just before the range.
		first = start
	}
	return fmt.Sprintf("%d,%d", first, count)
}
151
// formatLines writes each entry of lines through writeLine, tagged with
// prefix. If the final entry does not end in a newline, a bare newline is
// written afterwards (with an empty prefix) so that the diff displays
// properly. An empty or nil lines slice writes nothing.
func formatLines(writeLine func(string, string), prefix string, lines []string) {
	// Guard against indexing the last element of an empty slice below.
	if len(lines) == 0 {
		return
	}
	for _, line := range lines {
		writeLine(prefix, line)
	}
	// Add a newline if the last line is missing one so that the diff displays
	// properly.
	if !strings.HasSuffix(lines[len(lines)-1], "\n") {
		writeLine("", "\n")
	}
}
161}