blob: 984aac89cc0a88f0b794afae31b4856db19d6e8c [file] [log] [blame]
Zack Williamse940c7a2019-08-21 14:25:39 -07001// Copyright 2012 Jesse van den Kieboom. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package flags
6
7import (
8 "fmt"
9 "reflect"
10 "strconv"
11 "strings"
12 "time"
13)
14
// Marshaler is implemented by flag value types that know how to render
// themselves as the string form of a flag (for example, to display a default
// value).
type Marshaler interface {
	// MarshalFlag returns the string representation of the flag value, or
	// an error if the value cannot be represented.
	MarshalFlag() (string, error)
}
21
// Unmarshaler is implemented by flag value types that parse a command-line
// argument into themselves. UnmarshalFlag receives the raw argument text
// (for map-valued flags, the key or the value part of it), with no further
// interpretation applied first.
type Unmarshaler interface {
	// UnmarshalFlag parses value and stores the result in the receiver,
	// which therefore has to be a pointer receiver to have any effect.
	UnmarshalFlag(value string) error
}
30
31func getBase(options multiTag, base int) (int, error) {
32 sbase := options.Get("base")
33
34 var err error
35 var ivbase int64
36
37 if sbase != "" {
38 ivbase, err = strconv.ParseInt(sbase, 10, 32)
39 base = int(ivbase)
40 }
41
42 return base, err
43}
44
45func convertMarshal(val reflect.Value) (bool, string, error) {
46 // Check first for the Marshaler interface
47 if val.Type().NumMethod() > 0 && val.CanInterface() {
48 if marshaler, ok := val.Interface().(Marshaler); ok {
49 ret, err := marshaler.MarshalFlag()
50 return true, ret, err
51 }
52 }
53
54 return false, "", nil
55}
56
57func convertToString(val reflect.Value, options multiTag) (string, error) {
58 if ok, ret, err := convertMarshal(val); ok {
59 return ret, err
60 }
61
62 tp := val.Type()
63
64 // Support for time.Duration
65 if tp == reflect.TypeOf((*time.Duration)(nil)).Elem() {
66 stringer := val.Interface().(fmt.Stringer)
67 return stringer.String(), nil
68 }
69
70 switch tp.Kind() {
71 case reflect.String:
72 return val.String(), nil
73 case reflect.Bool:
74 if val.Bool() {
75 return "true", nil
76 }
77
78 return "false", nil
79 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
80 base, err := getBase(options, 10)
81
82 if err != nil {
83 return "", err
84 }
85
86 return strconv.FormatInt(val.Int(), base), nil
87 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
88 base, err := getBase(options, 10)
89
90 if err != nil {
91 return "", err
92 }
93
94 return strconv.FormatUint(val.Uint(), base), nil
95 case reflect.Float32, reflect.Float64:
96 return strconv.FormatFloat(val.Float(), 'g', -1, tp.Bits()), nil
97 case reflect.Slice:
98 if val.Len() == 0 {
99 return "", nil
100 }
101
102 ret := "["
103
104 for i := 0; i < val.Len(); i++ {
105 if i != 0 {
106 ret += ", "
107 }
108
109 item, err := convertToString(val.Index(i), options)
110
111 if err != nil {
112 return "", err
113 }
114
115 ret += item
116 }
117
118 return ret + "]", nil
119 case reflect.Map:
120 ret := "{"
121
122 for i, key := range val.MapKeys() {
123 if i != 0 {
124 ret += ", "
125 }
126
127 keyitem, err := convertToString(key, options)
128
129 if err != nil {
130 return "", err
131 }
132
133 item, err := convertToString(val.MapIndex(key), options)
134
135 if err != nil {
136 return "", err
137 }
138
139 ret += keyitem + ":" + item
140 }
141
142 return ret + "}", nil
143 case reflect.Ptr:
144 return convertToString(reflect.Indirect(val), options)
145 case reflect.Interface:
146 if !val.IsNil() {
147 return convertToString(val.Elem(), options)
148 }
149 }
150
151 return "", nil
152}
153
154func convertUnmarshal(val string, retval reflect.Value) (bool, error) {
155 if retval.Type().NumMethod() > 0 && retval.CanInterface() {
156 if unmarshaler, ok := retval.Interface().(Unmarshaler); ok {
157 if retval.IsNil() {
158 retval.Set(reflect.New(retval.Type().Elem()))
159
160 // Re-assign from the new value
161 unmarshaler = retval.Interface().(Unmarshaler)
162 }
163
164 return true, unmarshaler.UnmarshalFlag(val)
165 }
166 }
167
168 if retval.Type().Kind() != reflect.Ptr && retval.CanAddr() {
169 return convertUnmarshal(val, retval.Addr())
170 }
171
172 if retval.Type().Kind() == reflect.Interface && !retval.IsNil() {
173 return convertUnmarshal(val, retval.Elem())
174 }
175
176 return false, nil
177}
178
179func convert(val string, retval reflect.Value, options multiTag) error {
180 if ok, err := convertUnmarshal(val, retval); ok {
181 return err
182 }
183
184 tp := retval.Type()
185
186 // Support for time.Duration
187 if tp == reflect.TypeOf((*time.Duration)(nil)).Elem() {
188 parsed, err := time.ParseDuration(val)
189
190 if err != nil {
191 return err
192 }
193
194 retval.SetInt(int64(parsed))
195 return nil
196 }
197
198 switch tp.Kind() {
199 case reflect.String:
200 retval.SetString(val)
201 case reflect.Bool:
202 if val == "" {
203 retval.SetBool(true)
204 } else {
205 b, err := strconv.ParseBool(val)
206
207 if err != nil {
208 return err
209 }
210
211 retval.SetBool(b)
212 }
213 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
214 base, err := getBase(options, 10)
215
216 if err != nil {
217 return err
218 }
219
220 parsed, err := strconv.ParseInt(val, base, tp.Bits())
221
222 if err != nil {
223 return err
224 }
225
226 retval.SetInt(parsed)
227 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
228 base, err := getBase(options, 10)
229
230 if err != nil {
231 return err
232 }
233
234 parsed, err := strconv.ParseUint(val, base, tp.Bits())
235
236 if err != nil {
237 return err
238 }
239
240 retval.SetUint(parsed)
241 case reflect.Float32, reflect.Float64:
242 parsed, err := strconv.ParseFloat(val, tp.Bits())
243
244 if err != nil {
245 return err
246 }
247
248 retval.SetFloat(parsed)
249 case reflect.Slice:
250 elemtp := tp.Elem()
251
252 elemvalptr := reflect.New(elemtp)
253 elemval := reflect.Indirect(elemvalptr)
254
255 if err := convert(val, elemval, options); err != nil {
256 return err
257 }
258
259 retval.Set(reflect.Append(retval, elemval))
260 case reflect.Map:
261 parts := strings.SplitN(val, ":", 2)
262
263 key := parts[0]
264 var value string
265
266 if len(parts) == 2 {
267 value = parts[1]
268 }
269
270 keytp := tp.Key()
271 keyval := reflect.New(keytp)
272
273 if err := convert(key, keyval, options); err != nil {
274 return err
275 }
276
277 valuetp := tp.Elem()
278 valueval := reflect.New(valuetp)
279
280 if err := convert(value, valueval, options); err != nil {
281 return err
282 }
283
284 if retval.IsNil() {
285 retval.Set(reflect.MakeMap(tp))
286 }
287
288 retval.SetMapIndex(reflect.Indirect(keyval), reflect.Indirect(valueval))
289 case reflect.Ptr:
290 if retval.IsNil() {
291 retval.Set(reflect.New(retval.Type().Elem()))
292 }
293
294 return convert(val, reflect.Indirect(retval), options)
295 case reflect.Interface:
296 if !retval.IsNil() {
297 return convert(val, retval.Elem(), options)
298 }
299 }
300
301 return nil
302}
303
// isPrint reports whether every rune of s is printable as defined by
// strconv.IsPrint. Bytes that are not valid UTF-8 are treated as
// unprintable, so that callers such as quoteIfNeeded escape them instead of
// emitting them raw.
func isPrint(s string) bool {
	for i, c := range s {
		// Ranging over a string yields U+FFFD (utf8.RuneError) for every
		// invalid byte, and U+FFFD itself counts as printable. Tell an
		// invalid byte apart from a genuine, properly encoded U+FFFD by
		// looking at the bytes at this position.
		if c == '\uFFFD' && !strings.HasPrefix(s[i:], "\uFFFD") {
			return false
		}

		if !strconv.IsPrint(c) {
			return false
		}
	}

	return true
}
313
314func quoteIfNeeded(s string) string {
315 if !isPrint(s) {
316 return strconv.Quote(s)
317 }
318
319 return s
320}
321
322func quoteIfNeededV(s []string) []string {
323 ret := make([]string, len(s))
324
325 for i, v := range s {
326 ret[i] = quoteIfNeeded(v)
327 }
328
329 return ret
330}
331
// quoteV returns a new, non-nil slice holding strconv.Quote of each element
// of s, in the same order; s is not modified.
func quoteV(s []string) []string {
	out := make([]string, len(s))

	for i := range out {
		out[i] = strconv.Quote(s[i])
	}

	return out
}
341
// unquoteIfPossible treats s as a Go string literal when it starts with a
// double quote, returning the unquoted text or the error from
// strconv.Unquote. Any other string, including "", is returned unchanged with
// a nil error.
func unquoteIfPossible(s string) (string, error) {
	if !strings.HasPrefix(s, `"`) {
		return s, nil
	}

	return strconv.Unquote(s)
}