blob: cda29b2f04b71af51b90eeef90368d9ce378c808 [file] [log] [blame]
Naveen Sampath04696f72022-06-13 15:19:14 +05301// Copyright 2012 Jesse van den Kieboom. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package flags
6
7import (
8 "fmt"
9 "reflect"
10 "strconv"
11 "strings"
12 "time"
13)
14
// Marshaler is implemented by flag value types that know how to render
// themselves as the string form of the flag.
type Marshaler interface {
	// MarshalFlag produces the string form of the flag value, together with
	// any error encountered while producing it.
	MarshalFlag() (string, error)
}
21
// Unmarshaler is implemented by flag value types that can parse a command-line
// argument into themselves. The argument string is passed through exactly as
// it appeared on the command line.
type Unmarshaler interface {
	// UnmarshalFlag parses value into the receiver. Because it must modify
	// the flag's storage, implementations normally use a pointer receiver.
	UnmarshalFlag(value string) error
}
30
// ValueValidator is the interface implemented by types that can validate a
// flag argument themselves. The provided value is directly passed from the
// command line.
type ValueValidator interface {
	// IsValidValue returns an error if the provided string value is not
	// valid for the flag, and nil if it is acceptable.
	IsValidValue(value string) error
}
39
40func getBase(options multiTag, base int) (int, error) {
41 sbase := options.Get("base")
42
43 var err error
44 var ivbase int64
45
46 if sbase != "" {
47 ivbase, err = strconv.ParseInt(sbase, 10, 32)
48 base = int(ivbase)
49 }
50
51 return base, err
52}
53
54func convertMarshal(val reflect.Value) (bool, string, error) {
55 // Check first for the Marshaler interface
56 if val.Type().NumMethod() > 0 && val.CanInterface() {
57 if marshaler, ok := val.Interface().(Marshaler); ok {
58 ret, err := marshaler.MarshalFlag()
59 return true, ret, err
60 }
61 }
62
63 return false, "", nil
64}
65
66func convertToString(val reflect.Value, options multiTag) (string, error) {
67 if ok, ret, err := convertMarshal(val); ok {
68 return ret, err
69 }
70
71 tp := val.Type()
72
73 // Support for time.Duration
74 if tp == reflect.TypeOf((*time.Duration)(nil)).Elem() {
75 stringer := val.Interface().(fmt.Stringer)
76 return stringer.String(), nil
77 }
78
79 switch tp.Kind() {
80 case reflect.String:
81 return val.String(), nil
82 case reflect.Bool:
83 if val.Bool() {
84 return "true", nil
85 }
86
87 return "false", nil
88 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
89 base, err := getBase(options, 10)
90
91 if err != nil {
92 return "", err
93 }
94
95 return strconv.FormatInt(val.Int(), base), nil
96 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
97 base, err := getBase(options, 10)
98
99 if err != nil {
100 return "", err
101 }
102
103 return strconv.FormatUint(val.Uint(), base), nil
104 case reflect.Float32, reflect.Float64:
105 return strconv.FormatFloat(val.Float(), 'g', -1, tp.Bits()), nil
106 case reflect.Slice:
107 if val.Len() == 0 {
108 return "", nil
109 }
110
111 ret := "["
112
113 for i := 0; i < val.Len(); i++ {
114 if i != 0 {
115 ret += ", "
116 }
117
118 item, err := convertToString(val.Index(i), options)
119
120 if err != nil {
121 return "", err
122 }
123
124 ret += item
125 }
126
127 return ret + "]", nil
128 case reflect.Map:
129 ret := "{"
130
131 for i, key := range val.MapKeys() {
132 if i != 0 {
133 ret += ", "
134 }
135
136 keyitem, err := convertToString(key, options)
137
138 if err != nil {
139 return "", err
140 }
141
142 item, err := convertToString(val.MapIndex(key), options)
143
144 if err != nil {
145 return "", err
146 }
147
148 ret += keyitem + ":" + item
149 }
150
151 return ret + "}", nil
152 case reflect.Ptr:
153 return convertToString(reflect.Indirect(val), options)
154 case reflect.Interface:
155 if !val.IsNil() {
156 return convertToString(val.Elem(), options)
157 }
158 }
159
160 return "", nil
161}
162
163func convertUnmarshal(val string, retval reflect.Value) (bool, error) {
164 if retval.Type().NumMethod() > 0 && retval.CanInterface() {
165 if unmarshaler, ok := retval.Interface().(Unmarshaler); ok {
166 if retval.IsNil() {
167 retval.Set(reflect.New(retval.Type().Elem()))
168
169 // Re-assign from the new value
170 unmarshaler = retval.Interface().(Unmarshaler)
171 }
172
173 return true, unmarshaler.UnmarshalFlag(val)
174 }
175 }
176
177 if retval.Type().Kind() != reflect.Ptr && retval.CanAddr() {
178 return convertUnmarshal(val, retval.Addr())
179 }
180
181 if retval.Type().Kind() == reflect.Interface && !retval.IsNil() {
182 return convertUnmarshal(val, retval.Elem())
183 }
184
185 return false, nil
186}
187
188func convert(val string, retval reflect.Value, options multiTag) error {
189 if ok, err := convertUnmarshal(val, retval); ok {
190 return err
191 }
192
193 tp := retval.Type()
194
195 // Support for time.Duration
196 if tp == reflect.TypeOf((*time.Duration)(nil)).Elem() {
197 parsed, err := time.ParseDuration(val)
198
199 if err != nil {
200 return err
201 }
202
203 retval.SetInt(int64(parsed))
204 return nil
205 }
206
207 switch tp.Kind() {
208 case reflect.String:
209 retval.SetString(val)
210 case reflect.Bool:
211 if val == "" {
212 retval.SetBool(true)
213 } else {
214 b, err := strconv.ParseBool(val)
215
216 if err != nil {
217 return err
218 }
219
220 retval.SetBool(b)
221 }
222 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
223 base, err := getBase(options, 10)
224
225 if err != nil {
226 return err
227 }
228
229 parsed, err := strconv.ParseInt(val, base, tp.Bits())
230
231 if err != nil {
232 return err
233 }
234
235 retval.SetInt(parsed)
236 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
237 base, err := getBase(options, 10)
238
239 if err != nil {
240 return err
241 }
242
243 parsed, err := strconv.ParseUint(val, base, tp.Bits())
244
245 if err != nil {
246 return err
247 }
248
249 retval.SetUint(parsed)
250 case reflect.Float32, reflect.Float64:
251 parsed, err := strconv.ParseFloat(val, tp.Bits())
252
253 if err != nil {
254 return err
255 }
256
257 retval.SetFloat(parsed)
258 case reflect.Slice:
259 elemtp := tp.Elem()
260
261 elemvalptr := reflect.New(elemtp)
262 elemval := reflect.Indirect(elemvalptr)
263
264 if err := convert(val, elemval, options); err != nil {
265 return err
266 }
267
268 retval.Set(reflect.Append(retval, elemval))
269 case reflect.Map:
270 parts := strings.SplitN(val, ":", 2)
271
272 key := parts[0]
273 var value string
274
275 if len(parts) == 2 {
276 value = parts[1]
277 }
278
279 keytp := tp.Key()
280 keyval := reflect.New(keytp)
281
282 if err := convert(key, keyval, options); err != nil {
283 return err
284 }
285
286 valuetp := tp.Elem()
287 valueval := reflect.New(valuetp)
288
289 if err := convert(value, valueval, options); err != nil {
290 return err
291 }
292
293 if retval.IsNil() {
294 retval.Set(reflect.MakeMap(tp))
295 }
296
297 retval.SetMapIndex(reflect.Indirect(keyval), reflect.Indirect(valueval))
298 case reflect.Ptr:
299 if retval.IsNil() {
300 retval.Set(reflect.New(retval.Type().Elem()))
301 }
302
303 return convert(val, reflect.Indirect(retval), options)
304 case reflect.Interface:
305 if !retval.IsNil() {
306 return convert(val, retval.Elem(), options)
307 }
308 }
309
310 return nil
311}
312
// isPrint reports whether every rune of s is printable according to
// strconv.IsPrint. The empty string counts as printable.
func isPrint(s string) bool {
	firstUnprintable := strings.IndexFunc(s, func(r rune) bool {
		return !strconv.IsPrint(r)
	})

	return firstUnprintable == -1
}
322
// quoteIfNeeded returns s unchanged when all of its runes are printable
// (per strconv.IsPrint); otherwise it returns s as a Go double-quoted string
// literal produced by strconv.Quote.
func quoteIfNeeded(s string) string {
	for _, r := range s {
		if !strconv.IsPrint(r) {
			return strconv.Quote(s)
		}
	}

	return s
}
330
331func quoteIfNeededV(s []string) []string {
332 ret := make([]string, len(s))
333
334 for i, v := range s {
335 ret[i] = quoteIfNeeded(v)
336 }
337
338 return ret
339}
340
// quoteV returns a new slice holding strconv.Quote of each element of s,
// in order; s is not modified.
func quoteV(s []string) []string {
	quoted := make([]string, 0, len(s))

	for _, item := range s {
		quoted = append(quoted, strconv.Quote(item))
	}

	return quoted
}
350
// unquoteIfPossible interprets s as a Go double-quoted string literal when it
// begins with '"', returning strconv.Unquote's result (including its error
// for a malformed literal). Any other string, including "", is returned
// unchanged with a nil error.
func unquoteIfPossible(s string) (string, error) {
	if strings.HasPrefix(s, `"`) {
		return strconv.Unquote(s)
	}

	return s, nil
}