blob: f58cdf45820137dd617fd57ca34f2fc1bcaea858 [file] [log] [blame]
Prince Pereirac1c21d62021-04-22 08:38:15 +00001package copier
2
3import (
4 "database/sql"
5 "database/sql/driver"
6 "fmt"
7 "reflect"
8 "strings"
9)
10
// Bit flags describing how the `copier` struct tag of a destination field
// affects copying. Each flag occupies its own bit of a uint8 so that several
// can be combined in one value.
const (
	// tagMust marks a destination field that has to receive a value. When it
	// does not, the copy panics unless tagNoPanic is set as well.
	tagMust uint8 = 1 << iota

	// tagNoPanic modifies tagMust: a field that was not copied makes the copy
	// return an error rather than panic.
	tagNoPanic

	// tagIgnore excludes a destination field from being copied to.
	tagIgnore

	// hasCopied records that the field has been copied.
	hasCopied
)
26
// Option holds the switches that tune a copy operation.
type Option struct {
	// IgnoreEmpty, when true, skips every source field that holds its zero
	// value. This includes bools, and structs whose fields are all set to
	// their respective zero values (see IsZero() in reflect/value.go).
	IgnoreEmpty bool

	// DeepCopy, when true, makes struct, map and slice values get copied
	// recursively instead of being assigned directly.
	DeepCopy bool
}
34
35// Copy copy things
36func Copy(toValue interface{}, fromValue interface{}) (err error) {
37 return copier(toValue, fromValue, Option{})
38}
39
40// CopyWithOption copy with option
41func CopyWithOption(toValue interface{}, fromValue interface{}, opt Option) (err error) {
42 return copier(toValue, fromValue, opt)
43}
44
45func copier(toValue interface{}, fromValue interface{}, opt Option) (err error) {
46 var (
47 isSlice bool
48 amount = 1
49 from = indirect(reflect.ValueOf(fromValue))
50 to = indirect(reflect.ValueOf(toValue))
51 )
52
53 if !to.CanAddr() {
54 return ErrInvalidCopyDestination
55 }
56
57 // Return is from value is invalid
58 if !from.IsValid() {
59 return ErrInvalidCopyFrom
60 }
61
62 fromType, isPtrFrom := indirectType(from.Type())
63 toType, _ := indirectType(to.Type())
64
65 if fromType.Kind() == reflect.Interface {
66 fromType = reflect.TypeOf(from.Interface())
67 }
68
69 if toType.Kind() == reflect.Interface {
70 toType, _ = indirectType(reflect.TypeOf(to.Interface()))
71 oldTo := to
72 to = reflect.New(reflect.TypeOf(to.Interface())).Elem()
73 defer func() {
74 oldTo.Set(to)
75 }()
76 }
77
78 // Just set it if possible to assign for normal types
79 if from.Kind() != reflect.Slice && from.Kind() != reflect.Struct && from.Kind() != reflect.Map && (from.Type().AssignableTo(to.Type()) || from.Type().ConvertibleTo(to.Type())) {
80 if !isPtrFrom || !opt.DeepCopy {
81 to.Set(from.Convert(to.Type()))
82 } else {
83 fromCopy := reflect.New(from.Type())
84 fromCopy.Set(from.Elem())
85 to.Set(fromCopy.Convert(to.Type()))
86 }
87 return
88 }
89
90 if fromType.Kind() == reflect.Map && toType.Kind() == reflect.Map {
91 if !fromType.Key().ConvertibleTo(toType.Key()) {
92 return ErrMapKeyNotMatch
93 }
94
95 if to.IsNil() {
96 to.Set(reflect.MakeMapWithSize(toType, from.Len()))
97 }
98
99 for _, k := range from.MapKeys() {
100 toKey := indirect(reflect.New(toType.Key()))
101 if !set(toKey, k, opt.DeepCopy) {
102 return fmt.Errorf("%w map, old key: %v, new key: %v", ErrNotSupported, k.Type(), toType.Key())
103 }
104
105 elemType, _ := indirectType(toType.Elem())
106 toValue := indirect(reflect.New(elemType))
107 if !set(toValue, from.MapIndex(k), opt.DeepCopy) {
108 if err = copier(toValue.Addr().Interface(), from.MapIndex(k).Interface(), opt); err != nil {
109 return err
110 }
111 }
112
113 for {
114 if elemType == toType.Elem() {
115 to.SetMapIndex(toKey, toValue)
116 break
117 }
118 elemType = reflect.PtrTo(elemType)
119 toValue = toValue.Addr()
120 }
121 }
122 return
123 }
124
125 if from.Kind() == reflect.Slice && to.Kind() == reflect.Slice && fromType.ConvertibleTo(toType) {
126 if to.IsNil() {
127 slice := reflect.MakeSlice(reflect.SliceOf(to.Type().Elem()), from.Len(), from.Cap())
128 to.Set(slice)
129 }
130
131 for i := 0; i < from.Len(); i++ {
132 if to.Len() < i+1 {
133 to.Set(reflect.Append(to, reflect.New(to.Type().Elem()).Elem()))
134 }
135
136 if !set(to.Index(i), from.Index(i), opt.DeepCopy) {
137 err = CopyWithOption(to.Index(i).Addr().Interface(), from.Index(i).Interface(), opt)
138 if err != nil {
139 continue
140 }
141 }
142 }
143 return
144 }
145
146 if fromType.Kind() != reflect.Struct || toType.Kind() != reflect.Struct {
147 // skip not supported type
148 return
149 }
150
151 if to.Kind() == reflect.Slice {
152 isSlice = true
153 if from.Kind() == reflect.Slice {
154 amount = from.Len()
155 }
156 }
157
158 for i := 0; i < amount; i++ {
159 var dest, source reflect.Value
160
161 if isSlice {
162 // source
163 if from.Kind() == reflect.Slice {
164 source = indirect(from.Index(i))
165 } else {
166 source = indirect(from)
167 }
168 // dest
169 dest = indirect(reflect.New(toType).Elem())
170 } else {
171 source = indirect(from)
172 dest = indirect(to)
173 }
174
175 destKind := dest.Kind()
176 initDest := false
177 if destKind == reflect.Interface {
178 initDest = true
179 dest = indirect(reflect.New(toType))
180 }
181
182 // Get tag options
183 tagBitFlags := map[string]uint8{}
184 if dest.IsValid() {
185 tagBitFlags = getBitFlags(toType)
186 }
187
188 // check source
189 if source.IsValid() {
190 // Copy from source field to dest field or method
191 fromTypeFields := deepFields(fromType)
192 for _, field := range fromTypeFields {
193 name := field.Name
194
195 // Get bit flags for field
196 fieldFlags, _ := tagBitFlags[name]
197
198 // Check if we should ignore copying
199 if (fieldFlags & tagIgnore) != 0 {
200 continue
201 }
202
203 if fromField := source.FieldByName(name); fromField.IsValid() && !shouldIgnore(fromField, opt.IgnoreEmpty) {
204 // process for nested anonymous field
205 destFieldNotSet := false
206 if f, ok := dest.Type().FieldByName(name); ok {
207 for idx := range f.Index {
208 destField := dest.FieldByIndex(f.Index[:idx+1])
209
210 if destField.Kind() != reflect.Ptr {
211 continue
212 }
213
214 if !destField.IsNil() {
215 continue
216 }
217 if !destField.CanSet() {
218 destFieldNotSet = true
219 break
220 }
221
222 // destField is a nil pointer that can be set
223 newValue := reflect.New(destField.Type().Elem())
224 destField.Set(newValue)
225 }
226 }
227
228 if destFieldNotSet {
229 break
230 }
231
232 toField := dest.FieldByName(name)
233 if toField.IsValid() {
234 if toField.CanSet() {
235 if !set(toField, fromField, opt.DeepCopy) {
236 if err := copier(toField.Addr().Interface(), fromField.Interface(), opt); err != nil {
237 return err
238 }
239 }
240 if fieldFlags != 0 {
241 // Note that a copy was made
242 tagBitFlags[name] = fieldFlags | hasCopied
243 }
244 }
245 } else {
246 // try to set to method
247 var toMethod reflect.Value
248 if dest.CanAddr() {
249 toMethod = dest.Addr().MethodByName(name)
250 } else {
251 toMethod = dest.MethodByName(name)
252 }
253
254 if toMethod.IsValid() && toMethod.Type().NumIn() == 1 && fromField.Type().AssignableTo(toMethod.Type().In(0)) {
255 toMethod.Call([]reflect.Value{fromField})
256 }
257 }
258 }
259 }
260
261 // Copy from from method to dest field
262 for _, field := range deepFields(toType) {
263 name := field.Name
264
265 var fromMethod reflect.Value
266 if source.CanAddr() {
267 fromMethod = source.Addr().MethodByName(name)
268 } else {
269 fromMethod = source.MethodByName(name)
270 }
271
272 if fromMethod.IsValid() && fromMethod.Type().NumIn() == 0 && fromMethod.Type().NumOut() == 1 && !shouldIgnore(fromMethod, opt.IgnoreEmpty) {
273 if toField := dest.FieldByName(name); toField.IsValid() && toField.CanSet() {
274 values := fromMethod.Call([]reflect.Value{})
275 if len(values) >= 1 {
276 set(toField, values[0], opt.DeepCopy)
277 }
278 }
279 }
280 }
281 }
282
283 if isSlice {
284 if dest.Addr().Type().AssignableTo(to.Type().Elem()) {
285 if to.Len() < i+1 {
286 to.Set(reflect.Append(to, dest.Addr()))
287 } else {
288 set(to.Index(i), dest.Addr(), opt.DeepCopy)
289 }
290 } else if dest.Type().AssignableTo(to.Type().Elem()) {
291 if to.Len() < i+1 {
292 to.Set(reflect.Append(to, dest))
293 } else {
294 set(to.Index(i), dest, opt.DeepCopy)
295 }
296 }
297 } else if initDest {
298 to.Set(dest)
299 }
300
301 err = checkBitFlags(tagBitFlags)
302 }
303
304 return
305}
306
// shouldIgnore reports whether v has to be skipped: that is the case only
// when ignoreEmpty is on and v holds the zero value of its type. v is not
// inspected at all when ignoreEmpty is off.
func shouldIgnore(v reflect.Value, ignoreEmpty bool) bool {
	return ignoreEmpty && v.IsZero()
}
314
315func deepFields(reflectType reflect.Type) []reflect.StructField {
316 if reflectType, _ = indirectType(reflectType); reflectType.Kind() == reflect.Struct {
317 fields := make([]reflect.StructField, 0, reflectType.NumField())
318
319 for i := 0; i < reflectType.NumField(); i++ {
320 v := reflectType.Field(i)
321 if v.Anonymous {
322 fields = append(fields, deepFields(v.Type)...)
323 } else {
324 fields = append(fields, v)
325 }
326 }
327
328 return fields
329 }
330
331 return nil
332}
333
// indirect follows pointers until it reaches a value whose kind is not
// reflect.Ptr and returns that value. Following a nil pointer yields the
// invalid zero reflect.Value.
func indirect(reflectValue reflect.Value) reflect.Value {
	current := reflectValue
	for {
		if current.Kind() != reflect.Ptr {
			return current
		}
		current = current.Elem()
	}
}
340
// indirectType unwraps pointer and slice types until it reaches a type that
// is neither, and returns that type. isPtr reports whether at least one
// level (pointer or slice) was unwrapped.
func indirectType(reflectType reflect.Type) (_ reflect.Type, isPtr bool) {
	for {
		switch reflectType.Kind() {
		case reflect.Ptr, reflect.Slice:
			reflectType = reflectType.Elem()
			isPtr = true
		default:
			return reflectType, isPtr
		}
	}
}
348
349func set(to, from reflect.Value, deepCopy bool) bool {
350 if from.IsValid() {
351 if to.Kind() == reflect.Ptr {
352 // set `to` to nil if from is nil
353 if from.Kind() == reflect.Ptr && from.IsNil() {
354 to.Set(reflect.Zero(to.Type()))
355 return true
356 } else if to.IsNil() {
357 // `from` -> `to`
358 // sql.NullString -> *string
359 if fromValuer, ok := driverValuer(from); ok {
360 v, err := fromValuer.Value()
361 if err != nil {
362 return false
363 }
364 // if `from` is not valid do nothing with `to`
365 if v == nil {
366 return true
367 }
368 }
369 // allocate new `to` variable with default value (eg. *string -> new(string))
370 to.Set(reflect.New(to.Type().Elem()))
371 }
372 // depointer `to`
373 to = to.Elem()
374 }
375
376 if deepCopy {
377 toKind := to.Kind()
378 if toKind == reflect.Interface && to.IsNil() {
379 if reflect.TypeOf(from.Interface()) != nil {
380 to.Set(reflect.New(reflect.TypeOf(from.Interface())).Elem())
381 toKind = reflect.TypeOf(to.Interface()).Kind()
382 }
383 }
384 if toKind == reflect.Struct || toKind == reflect.Map || toKind == reflect.Slice {
385 return false
386 }
387 }
388
389 if from.Type().ConvertibleTo(to.Type()) {
390 to.Set(from.Convert(to.Type()))
391 } else if toScanner, ok := to.Addr().Interface().(sql.Scanner); ok {
392 // `from` -> `to`
393 // *string -> sql.NullString
394 if from.Kind() == reflect.Ptr {
395 // if `from` is nil do nothing with `to`
396 if from.IsNil() {
397 return true
398 }
399 // depointer `from`
400 from = indirect(from)
401 }
402 // `from` -> `to`
403 // string -> sql.NullString
404 // set `to` by invoking method Scan(`from`)
405 err := toScanner.Scan(from.Interface())
406 if err != nil {
407 return false
408 }
409 } else if fromValuer, ok := driverValuer(from); ok {
410 // `from` -> `to`
411 // sql.NullString -> string
412 v, err := fromValuer.Value()
413 if err != nil {
414 return false
415 }
416 // if `from` is not valid do nothing with `to`
417 if v == nil {
418 return true
419 }
420 rv := reflect.ValueOf(v)
421 if rv.Type().AssignableTo(to.Type()) {
422 to.Set(rv)
423 }
424 } else if from.Kind() == reflect.Ptr {
425 return set(to, from.Elem(), deepCopy)
426 } else {
427 return false
428 }
429 }
430
431 return true
432}
433
434// parseTags Parses struct tags and returns uint8 bit flags.
435func parseTags(tag string) (flags uint8) {
436 for _, t := range strings.Split(tag, ",") {
437 switch t {
438 case "-":
439 flags = tagIgnore
440 return
441 case "must":
442 flags = flags | tagMust
443 case "nopanic":
444 flags = flags | tagNoPanic
445 }
446 }
447 return
448}
449
450// getBitFlags Parses struct tags for bit flags.
451func getBitFlags(toType reflect.Type) map[string]uint8 {
452 flags := map[string]uint8{}
453 toTypeFields := deepFields(toType)
454
455 // Get a list dest of tags
456 for _, field := range toTypeFields {
457 tags := field.Tag.Get("copier")
458 if tags != "" {
459 flags[field.Name] = parseTags(tags)
460 }
461 }
462 return flags
463}
464
465// checkBitFlags Checks flags for error or panic conditions.
466func checkBitFlags(flagsList map[string]uint8) (err error) {
467 // Check flag conditions were met
468 for name, flags := range flagsList {
469 if flags&hasCopied == 0 {
470 switch {
471 case flags&tagMust != 0 && flags&tagNoPanic != 0:
472 err = fmt.Errorf("field %s has must tag but was not copied", name)
473 return
474 case flags&(tagMust) != 0:
475 panic(fmt.Sprintf("Field %s has must tag but was not copied", name))
476 }
477 }
478 }
479 return
480}
481
// driverValuer reports whether v can act as a driver.Valuer and returns it
// as one. When v is addressable its address is examined, so that methods
// declared on the pointer receiver are found as well; otherwise the value
// itself is examined.
func driverValuer(v reflect.Value) (i driver.Valuer, ok bool) {
	candidate := v
	if v.CanAddr() {
		candidate = v.Addr()
	}
	i, ok = candidate.Interface().(driver.Valuer)
	return i, ok
}