Matteo Scandolo | a6a3aee | 2019-11-26 13:30:14 -0700 | [diff] [blame] | 1 | // Copyright 2013 The Go Authors. All rights reserved. |
| 2 | // Use of this source code is governed by a BSD-style |
| 3 | // license that can be found in the LICENSE file. |
| 4 | package yaml |
| 5 | |
| 6 | import ( |
| 7 | "bytes" |
| 8 | "encoding" |
| 9 | "encoding/json" |
| 10 | "reflect" |
| 11 | "sort" |
| 12 | "strings" |
| 13 | "sync" |
| 14 | "unicode" |
| 15 | "unicode/utf8" |
| 16 | ) |
| 17 | |
// indirect walks down v, allocating pointers as needed, until it reaches a
// non-pointer value.
//
// If it encounters a value implementing json.Unmarshaler or
// encoding.TextUnmarshaler along the way, it stops and returns that
// unmarshaler (json.Unmarshaler is checked first) together with a zero
// reflect.Value. Otherwise it returns two nil unmarshalers and the final value.
//
// If decodingNull is true, indirect stops at the last settable pointer so the
// caller can set it to nil.
func indirect(v reflect.Value, decodingNull bool) (json.Unmarshaler, encoding.TextUnmarshaler, reflect.Value) {
	// If v is a named type and is addressable, start with its address, so
	// that if the type has pointer methods, we find them.
	if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() {
		v = v.Addr()
	}
	for {
		// Load value from interface, but only if the result will be
		// usefully addressable.
		if v.Kind() == reflect.Interface && !v.IsNil() {
			e := v.Elem()
			if e.Kind() == reflect.Ptr && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Ptr) {
				v = e
				continue
			}
		}

		if v.Kind() != reflect.Ptr {
			break
		}

		if v.Elem().Kind() != reflect.Ptr && decodingNull && v.CanSet() {
			break
		}

		// Prevent an infinite loop when v points at an interface that holds
		// v itself:
		//
		//	var x interface{}
		//	x = &x
		//
		// Without this check the interface branch above and v.Elem() below
		// would bounce between x and &x forever. Stop at the interface so the
		// caller can overwrite it. (Only this single-level self-reference is
		// detected.)
		if inner := v.Elem(); inner.Kind() == reflect.Interface && !inner.IsNil() {
			if target := inner.Elem(); target.Kind() == reflect.Ptr &&
				target.Type() == v.Type() && target.Pointer() == v.Pointer() {
				v = inner
				break
			}
		}

		if v.IsNil() {
			if v.CanSet() {
				v.Set(reflect.New(v.Type().Elem()))
			} else {
				v = reflect.New(v.Type().Elem())
			}
		}
		if v.Type().NumMethod() > 0 {
			if u, ok := v.Interface().(json.Unmarshaler); ok {
				return u, nil, reflect.Value{}
			}
			if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
				return nil, u, reflect.Value{}
			}
		}
		v = v.Elem()
	}
	return nil, nil, v
}
| 66 | |
// A field represents a single field found in a struct.
type field struct {
	name      string                 // key name: the json tag name, or the Go field name when untagged
	nameBytes []byte                 // []byte(name)
	equalFold func(s, t []byte) bool // bytes.EqualFold or equivalent

	tag       bool         // true if name came from a json tag
	index     []int        // index path from the top-level struct through embedded structs to this field
	typ       reflect.Type // field type; typeFields follows one level of unnamed pointer
	omitEmpty bool         // json tag options contained "omitempty"
	quoted    bool         // json tag options contained "string"
}
| 79 | |
| 80 | func fillField(f field) field { |
| 81 | f.nameBytes = []byte(f.name) |
| 82 | f.equalFold = foldFunc(f.nameBytes) |
| 83 | return f |
| 84 | } |
| 85 | |
| 86 | // byName sorts field by name, breaking ties with depth, |
| 87 | // then breaking ties with "name came from json tag", then |
| 88 | // breaking ties with index sequence. |
| 89 | type byName []field |
| 90 | |
| 91 | func (x byName) Len() int { return len(x) } |
| 92 | |
| 93 | func (x byName) Swap(i, j int) { x[i], x[j] = x[j], x[i] } |
| 94 | |
| 95 | func (x byName) Less(i, j int) bool { |
| 96 | if x[i].name != x[j].name { |
| 97 | return x[i].name < x[j].name |
| 98 | } |
| 99 | if len(x[i].index) != len(x[j].index) { |
| 100 | return len(x[i].index) < len(x[j].index) |
| 101 | } |
| 102 | if x[i].tag != x[j].tag { |
| 103 | return x[i].tag |
| 104 | } |
| 105 | return byIndex(x).Less(i, j) |
| 106 | } |
| 107 | |
| 108 | // byIndex sorts field by index sequence. |
| 109 | type byIndex []field |
| 110 | |
| 111 | func (x byIndex) Len() int { return len(x) } |
| 112 | |
| 113 | func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] } |
| 114 | |
| 115 | func (x byIndex) Less(i, j int) bool { |
| 116 | for k, xik := range x[i].index { |
| 117 | if k >= len(x[j].index) { |
| 118 | return false |
| 119 | } |
| 120 | if xik != x[j].index[k] { |
| 121 | return xik < x[j].index[k] |
| 122 | } |
| 123 | } |
| 124 | return len(x[i].index) < len(x[j].index) |
| 125 | } |
| 126 | |
// typeFields returns a list of fields that JSON should recognize for the given type.
// The algorithm is breadth-first search over the set of structs to include - the top struct
// and then any reachable anonymous structs.
//
// The returned fields are sorted by index sequence. For each name only the
// dominant field (see dominantField) is kept; a name whose candidates
// conflict is dropped entirely.
//
// NOTE(review): t is presumably always a struct type here; f.typ.NumField is
// called on it and reflect panics on non-struct kinds — confirm with callers.
func typeFields(t reflect.Type) []field {
	// Anonymous fields to explore at the current level and the next.
	current := []field{}
	next := []field{{typ: t}}

	// Count of queued names for current level and the next.
	count := map[reflect.Type]int{}
	nextCount := map[reflect.Type]int{}

	// Types already visited at an earlier level.
	visited := map[reflect.Type]bool{}

	// Fields found.
	var fields []field

	for len(next) > 0 {
		// Advance one level. current and next alternate between two backing
		// arrays, so appending to next never writes into the slice being
		// ranged over below.
		current, next = next, current[:0]
		count, nextCount = nextCount, map[reflect.Type]int{}

		for _, f := range current {
			if visited[f.typ] {
				continue
			}
			visited[f.typ] = true

			// Scan f.typ for fields to include.
			for i := 0; i < f.typ.NumField(); i++ {
				sf := f.typ.Field(i)
				if sf.PkgPath != "" { // unexported
					continue
				}
				tag := sf.Tag.Get("json")
				if tag == "-" {
					continue
				}
				name, opts := parseTag(tag)
				if !isValidTag(name) {
					name = ""
				}
				// Full index path: the path to f followed by this field's position.
				index := make([]int, len(f.index)+1)
				copy(index, f.index)
				index[len(f.index)] = i

				ft := sf.Type
				if ft.Name() == "" && ft.Kind() == reflect.Ptr {
					// Follow pointer.
					ft = ft.Elem()
				}

				// Record found field and index sequence.
				// Tagged, named, or non-struct fields are leaves; only untagged
				// embedded structs are expanded further.
				if name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct {
					tagged := name != ""
					if name == "" {
						name = sf.Name
					}
					fields = append(fields, fillField(field{
						name:      name,
						tag:       tagged,
						index:     index,
						typ:       ft,
						omitEmpty: opts.Contains("omitempty"),
						quoted:    opts.Contains("string"),
					}))
					// count[f.typ] is how many times f.typ was queued at this
					// level, i.e. how many same-depth paths reach it.
					if count[f.typ] > 1 {
						// If there were multiple instances, add a second,
						// so that the annihilation code will see a duplicate.
						// It only cares about the distinction between 1 or 2,
						// so don't bother generating any more copies.
						fields = append(fields, fields[len(fields)-1])
					}
					continue
				}

				// Record new anonymous struct to explore in next round.
				// Queue each struct type once per level, but keep counting.
				nextCount[ft]++
				if nextCount[ft] == 1 {
					next = append(next, fillField(field{name: ft.Name(), index: index, typ: ft}))
				}
			}
		}
	}

	sort.Sort(byName(fields))

	// Delete all fields that are hidden by the Go rules for embedded fields,
	// except that fields with JSON tags are promoted.

	// The fields are sorted in primary order of name, secondary order
	// of field index length. Loop over names; for each name, delete
	// hidden fields by choosing the one dominant field that survives.
	// out reuses the backing array of fields. This is safe because out never
	// grows past the start of the name group currently being read.
	out := fields[:0]
	for advance, i := 0, 0; i < len(fields); i += advance {
		// One iteration per name.
		// Find the sequence of fields with the name of this first field.
		fi := fields[i]
		name := fi.name
		for advance = 1; i+advance < len(fields); advance++ {
			fj := fields[i+advance]
			if fj.name != name {
				break
			}
		}
		if advance == 1 { // Only one field with this name
			out = append(out, fi)
			continue
		}
		dominant, ok := dominantField(fields[i : i+advance])
		if ok {
			out = append(out, dominant)
		}
	}

	fields = out
	sort.Sort(byIndex(fields))

	return fields
}
| 247 | |
| 248 | // dominantField looks through the fields, all of which are known to |
| 249 | // have the same name, to find the single field that dominates the |
| 250 | // others using Go's embedding rules, modified by the presence of |
| 251 | // JSON tags. If there are multiple top-level fields, the boolean |
| 252 | // will be false: This condition is an error in Go and we skip all |
| 253 | // the fields. |
| 254 | func dominantField(fields []field) (field, bool) { |
| 255 | // The fields are sorted in increasing index-length order. The winner |
| 256 | // must therefore be one with the shortest index length. Drop all |
| 257 | // longer entries, which is easy: just truncate the slice. |
| 258 | length := len(fields[0].index) |
| 259 | tagged := -1 // Index of first tagged field. |
| 260 | for i, f := range fields { |
| 261 | if len(f.index) > length { |
| 262 | fields = fields[:i] |
| 263 | break |
| 264 | } |
| 265 | if f.tag { |
| 266 | if tagged >= 0 { |
| 267 | // Multiple tagged fields at the same level: conflict. |
| 268 | // Return no field. |
| 269 | return field{}, false |
| 270 | } |
| 271 | tagged = i |
| 272 | } |
| 273 | } |
| 274 | if tagged >= 0 { |
| 275 | return fields[tagged], true |
| 276 | } |
| 277 | // All remaining fields have the same length. If there's more than one, |
| 278 | // we have a conflict (two fields named "X" at the same level) and we |
| 279 | // return no field. |
| 280 | if len(fields) > 1 { |
| 281 | return field{}, false |
| 282 | } |
| 283 | return fields[0], true |
| 284 | } |
| 285 | |
| 286 | var fieldCache struct { |
| 287 | sync.RWMutex |
| 288 | m map[reflect.Type][]field |
| 289 | } |
| 290 | |
| 291 | // cachedTypeFields is like typeFields but uses a cache to avoid repeated work. |
| 292 | func cachedTypeFields(t reflect.Type) []field { |
| 293 | fieldCache.RLock() |
| 294 | f := fieldCache.m[t] |
| 295 | fieldCache.RUnlock() |
| 296 | if f != nil { |
| 297 | return f |
| 298 | } |
| 299 | |
| 300 | // Compute fields without lock. |
| 301 | // Might duplicate effort but won't hold other computations back. |
| 302 | f = typeFields(t) |
| 303 | if f == nil { |
| 304 | f = []field{} |
| 305 | } |
| 306 | |
| 307 | fieldCache.Lock() |
| 308 | if fieldCache.m == nil { |
| 309 | fieldCache.m = map[reflect.Type][]field{} |
| 310 | } |
| 311 | fieldCache.m[t] = f |
| 312 | fieldCache.Unlock() |
| 313 | return f |
| 314 | } |
| 315 | |
// isValidTag reports whether s may be used as a key name taken from a struct
// tag. s must be non-empty and consist only of Unicode letters, Unicode
// digits, spaces, and ASCII punctuation other than the reserved backslash and
// quote characters (", ', `). A comma is rejected too; in this file names come
// from parseTag, which has already split on commas.
//
// The accepted set matches encoding/json, which has allowed ';' since Go 1.16.
// The names resolved here therefore agree with the keys encoding/json uses.
func isValidTag(s string) bool {
	if s == "" {
		return false
	}
	for _, c := range s {
		switch {
		case strings.ContainsRune("!#$%&()*+-./:;<=>?@[]^_{|}~ ", c):
			// Backslash and quote chars are reserved, but
			// otherwise any punctuation chars are allowed
			// in a tag name.
		case !unicode.IsLetter(c) && !unicode.IsDigit(c):
			return false
		}
	}
	return true
}
| 334 | |
// Constants shared by the ASCII case-folding helpers.
const (
	// caseMask clears bit 0x20. For an ASCII letter this yields its
	// upper-case form, so both cases compare equal after masking.
	caseMask = ^byte(0x20)

	// kelvin is U+212A KELVIN SIGN, which case-folds together with 'k'/'K'.
	kelvin = '\u212a'

	// smallLongEss is U+017F LATIN SMALL LETTER LONG S, which case-folds
	// together with 's'/'S'.
	smallLongEss = '\u017f'
)
| 340 | |
| 341 | // foldFunc returns one of four different case folding equivalence |
| 342 | // functions, from most general (and slow) to fastest: |
| 343 | // |
| 344 | // 1) bytes.EqualFold, if the key s contains any non-ASCII UTF-8 |
| 345 | // 2) equalFoldRight, if s contains special folding ASCII ('k', 'K', 's', 'S') |
| 346 | // 3) asciiEqualFold, no special, but includes non-letters (including _) |
| 347 | // 4) simpleLetterEqualFold, no specials, no non-letters. |
| 348 | // |
| 349 | // The letters S and K are special because they map to 3 runes, not just 2: |
| 350 | // * S maps to s and to U+017F 'ſ' Latin small letter long s |
| 351 | // * k maps to K and to U+212A 'K' Kelvin sign |
| 352 | // See http://play.golang.org/p/tTxjOc0OGo |
| 353 | // |
| 354 | // The returned function is specialized for matching against s and |
| 355 | // should only be given s. It's not curried for performance reasons. |
| 356 | func foldFunc(s []byte) func(s, t []byte) bool { |
| 357 | nonLetter := false |
| 358 | special := false // special letter |
| 359 | for _, b := range s { |
| 360 | if b >= utf8.RuneSelf { |
| 361 | return bytes.EqualFold |
| 362 | } |
| 363 | upper := b & caseMask |
| 364 | if upper < 'A' || upper > 'Z' { |
| 365 | nonLetter = true |
| 366 | } else if upper == 'K' || upper == 'S' { |
| 367 | // See above for why these letters are special. |
| 368 | special = true |
| 369 | } |
| 370 | } |
| 371 | if special { |
| 372 | return equalFoldRight |
| 373 | } |
| 374 | if nonLetter { |
| 375 | return asciiEqualFold |
| 376 | } |
| 377 | return simpleLetterEqualFold |
| 378 | } |
| 379 | |
| 380 | // equalFoldRight is a specialization of bytes.EqualFold when s is |
| 381 | // known to be all ASCII (including punctuation), but contains an 's', |
| 382 | // 'S', 'k', or 'K', requiring a Unicode fold on the bytes in t. |
| 383 | // See comments on foldFunc. |
| 384 | func equalFoldRight(s, t []byte) bool { |
| 385 | for _, sb := range s { |
| 386 | if len(t) == 0 { |
| 387 | return false |
| 388 | } |
| 389 | tb := t[0] |
| 390 | if tb < utf8.RuneSelf { |
| 391 | if sb != tb { |
| 392 | sbUpper := sb & caseMask |
| 393 | if 'A' <= sbUpper && sbUpper <= 'Z' { |
| 394 | if sbUpper != tb&caseMask { |
| 395 | return false |
| 396 | } |
| 397 | } else { |
| 398 | return false |
| 399 | } |
| 400 | } |
| 401 | t = t[1:] |
| 402 | continue |
| 403 | } |
| 404 | // sb is ASCII and t is not. t must be either kelvin |
| 405 | // sign or long s; sb must be s, S, k, or K. |
| 406 | tr, size := utf8.DecodeRune(t) |
| 407 | switch sb { |
| 408 | case 's', 'S': |
| 409 | if tr != smallLongEss { |
| 410 | return false |
| 411 | } |
| 412 | case 'k', 'K': |
| 413 | if tr != kelvin { |
| 414 | return false |
| 415 | } |
| 416 | default: |
| 417 | return false |
| 418 | } |
| 419 | t = t[size:] |
| 420 | |
| 421 | } |
| 422 | if len(t) > 0 { |
| 423 | return false |
| 424 | } |
| 425 | return true |
| 426 | } |
| 427 | |
| 428 | // asciiEqualFold is a specialization of bytes.EqualFold for use when |
| 429 | // s is all ASCII (but may contain non-letters) and contains no |
| 430 | // special-folding letters. |
| 431 | // See comments on foldFunc. |
| 432 | func asciiEqualFold(s, t []byte) bool { |
| 433 | if len(s) != len(t) { |
| 434 | return false |
| 435 | } |
| 436 | for i, sb := range s { |
| 437 | tb := t[i] |
| 438 | if sb == tb { |
| 439 | continue |
| 440 | } |
| 441 | if ('a' <= sb && sb <= 'z') || ('A' <= sb && sb <= 'Z') { |
| 442 | if sb&caseMask != tb&caseMask { |
| 443 | return false |
| 444 | } |
| 445 | } else { |
| 446 | return false |
| 447 | } |
| 448 | } |
| 449 | return true |
| 450 | } |
| 451 | |
| 452 | // simpleLetterEqualFold is a specialization of bytes.EqualFold for |
| 453 | // use when s is all ASCII letters (no underscores, etc) and also |
| 454 | // doesn't contain 'k', 'K', 's', or 'S'. |
| 455 | // See comments on foldFunc. |
| 456 | func simpleLetterEqualFold(s, t []byte) bool { |
| 457 | if len(s) != len(t) { |
| 458 | return false |
| 459 | } |
| 460 | for i, b := range s { |
| 461 | if b&caseMask != t[i]&caseMask { |
| 462 | return false |
| 463 | } |
| 464 | } |
| 465 | return true |
| 466 | } |
| 467 | |
// tagOptions is the part of a struct field's "json" tag that follows the
// first comma, or the empty string when there is no comma. It does not
// include that leading comma.
type tagOptions string

// parseTag splits a struct field's json tag at its first comma into the name
// and the remaining comma-separated options.
func parseTag(tag string) (string, tagOptions) {
	idx := strings.Index(tag, ",")
	if idx < 0 {
		return tag, tagOptions("")
	}
	return tag[:idx], tagOptions(tag[idx+1:])
}

// Contains reports whether the comma-separated option list o contains
// optionName as one complete entry. A match must span a whole entry, bounded
// by the start/end of the list or by commas; substrings of a longer entry do
// not count.
func (o tagOptions) Contains(optionName string) bool {
	rest := string(o)
	for rest != "" {
		opt := rest
		if i := strings.Index(rest, ","); i >= 0 {
			opt, rest = rest[:i], rest[i+1:]
		} else {
			rest = ""
		}
		if opt == optionName {
			return true
		}
	}
	return false
}