blob: 889c22e273c4a0a5a4de2655b1fb1440492d1da5 [file] [log] [blame]
khenaidooffe076b2019-01-15 16:08:08 -05001package cobra
2
3import (
4 "bytes"
5 "fmt"
6 "io"
7 "os"
8 "strings"
9)
10
11// GenZshCompletionFile generates zsh completion file.
12func (c *Command) GenZshCompletionFile(filename string) error {
13 outFile, err := os.Create(filename)
14 if err != nil {
15 return err
16 }
17 defer outFile.Close()
18
19 return c.GenZshCompletion(outFile)
20}
21
22// GenZshCompletion generates a zsh completion file and writes to the passed writer.
23func (c *Command) GenZshCompletion(w io.Writer) error {
24 buf := new(bytes.Buffer)
25
26 writeHeader(buf, c)
27 maxDepth := maxDepth(c)
28 writeLevelMapping(buf, maxDepth)
29 writeLevelCases(buf, maxDepth, c)
30
31 _, err := buf.WriteTo(w)
32 return err
33}
34
35func writeHeader(w io.Writer, cmd *Command) {
36 fmt.Fprintf(w, "#compdef %s\n\n", cmd.Name())
37}
38
39func maxDepth(c *Command) int {
40 if len(c.Commands()) == 0 {
41 return 0
42 }
43 maxDepthSub := 0
44 for _, s := range c.Commands() {
45 subDepth := maxDepth(s)
46 if subDepth > maxDepthSub {
47 maxDepthSub = subDepth
48 }
49 }
50 return 1 + maxDepthSub
51}
52
// writeLevelMapping writes the top-level `_arguments` call. It maps
// positional word 1..numLevels to the state "level1".."levelN", and the
// word right after the deepest level to `_files`.
func writeLevelMapping(w io.Writer, numLevels int) {
	fmt.Fprintln(w, `_arguments \`)
	for level := 1; level <= numLevels; level++ {
		// Each spec line ends with a shell line-continuation backslash.
		fmt.Fprintf(w, " '%d: :->level%d' \\\n", level, level)
	}
	fmt.Fprintf(w, " '%d: :_files'\n", numLevels+1)
}
62
63func writeLevelCases(w io.Writer, maxDepth int, root *Command) {
64 fmt.Fprintln(w, "case $state in")
65 defer fmt.Fprintln(w, "esac")
66
67 for i := 1; i <= maxDepth; i++ {
68 fmt.Fprintf(w, " level%d)\n", i)
69 writeLevel(w, root, i)
70 fmt.Fprintln(w, " ;;")
71 }
72 fmt.Fprintln(w, " *)")
73 fmt.Fprintln(w, " _arguments '*: :_files'")
74 fmt.Fprintln(w, " ;;")
75}
76
77func writeLevel(w io.Writer, root *Command, i int) {
78 fmt.Fprintf(w, " case $words[%d] in\n", i)
79 defer fmt.Fprintln(w, " esac")
80
81 commands := filterByLevel(root, i)
82 byParent := groupByParent(commands)
83
84 for p, c := range byParent {
85 names := names(c)
86 fmt.Fprintf(w, " %s)\n", p)
87 fmt.Fprintf(w, " _arguments '%d: :(%s)'\n", i, strings.Join(names, " "))
88 fmt.Fprintln(w, " ;;")
89 }
90 fmt.Fprintln(w, " *)")
91 fmt.Fprintln(w, " _arguments '*: :_files'")
92 fmt.Fprintln(w, " ;;")
93
94}
95
96func filterByLevel(c *Command, l int) []*Command {
97 cs := make([]*Command, 0)
98 if l == 0 {
99 cs = append(cs, c)
100 return cs
101 }
102 for _, s := range c.Commands() {
103 cs = append(cs, filterByLevel(s, l-1)...)
104 }
105 return cs
106}
107
108func groupByParent(commands []*Command) map[string][]*Command {
109 m := make(map[string][]*Command)
110 for _, c := range commands {
111 parent := c.Parent()
112 if parent == nil {
113 continue
114 }
115 m[parent.Name()] = append(m[parent.Name()], c)
116 }
117 return m
118}
119
120func names(commands []*Command) []string {
121 ns := make([]string, len(commands))
122 for i, c := range commands {
123 ns[i] = c.Name()
124 }
125 return ns
126}