// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
| 4 | |
| 5 | package yaml |
| 6 | |
| 7 | import ( |
| 8 | "bytes" |
| 9 | "encoding" |
| 10 | "encoding/json" |
| 11 | "reflect" |
| 12 | "sort" |
| 13 | "strings" |
| 14 | "sync" |
| 15 | "unicode" |
| 16 | "unicode/utf8" |
| 17 | ) |
| 18 | |
// indirect dereferences v until it reaches a non-pointer value, allocating
// a new element for every nil pointer it walks through. A nil pointer that
// cannot be set is replaced by a freshly allocated one for the remainder of
// the walk.
//
// If a pointer met along the way implements json.Unmarshaler or
// encoding.TextUnmarshaler, the walk stops and that implementation is
// returned together with the zero reflect.Value. Otherwise both interface
// results are nil and the value the walk ended on is returned.
//
// If decodingNull is true, indirect stops at the last settable pointer
// instead of following it, so that the caller can set it to nil.
func indirect(v reflect.Value, decodingNull bool) (json.Unmarshaler, encoding.TextUnmarshaler, reflect.Value) {
	// If v is a named type and is addressable, start with its address, so
	// that methods declared on the pointer receiver are found.
	if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() {
		v = v.Addr()
	}
	for {
		// Load the value out of an interface, but only if the result will
		// be usefully addressable.
		if v.Kind() == reflect.Interface && !v.IsNil() {
			e := v.Elem()
			if e.Kind() == reflect.Ptr && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Ptr) {
				v = e
				continue
			}
		}

		if v.Kind() != reflect.Ptr {
			break
		}

		if v.Elem().Kind() != reflect.Ptr && decodingNull && v.CanSet() {
			break
		}
		if v.IsNil() {
			if v.CanSet() {
				v.Set(reflect.New(v.Type().Elem()))
			} else {
				v = reflect.New(v.Type().Elem())
			}
		}
		// Value.Interface panics on values obtained through unexported
		// struct fields, so only probe for unmarshalers when the value may
		// legally be converted to an interface.
		if v.Type().NumMethod() > 0 && v.CanInterface() {
			if u, ok := v.Interface().(json.Unmarshaler); ok {
				return u, nil, reflect.Value{}
			}
			if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
				return nil, u, reflect.Value{}
			}
		}
		v = v.Elem()
	}
	return nil, nil, v
}
| 67 | |
// field describes one struct field discovered by typeFields.
type field struct {
	// name is the key used for the field: the name from its json tag when
	// that is valid, otherwise the Go field name.
	name string
	// nameBytes is []byte(name), filled in by fillField.
	nameBytes []byte
	// equalFold is the case-folding comparison chosen for name by foldFunc
	// (bytes.EqualFold or an equivalent specialization).
	equalFold func(s, t []byte) bool

	// tag reports whether name came from the json tag.
	tag bool
	// index is the sequence of field indexes leading from the top-level
	// struct to this field.
	index []int
	// typ is the field's type, with an unnamed pointer type replaced by
	// its element type.
	typ reflect.Type
	// omitEmpty reports whether the tag carries the "omitempty" option.
	omitEmpty bool
	// quoted reports whether the tag carries the "string" option.
	quoted bool
}
| 80 | |
| 81 | func fillField(f field) field { |
| 82 | f.nameBytes = []byte(f.name) |
| 83 | f.equalFold = foldFunc(f.nameBytes) |
| 84 | return f |
| 85 | } |
| 86 | |
| 87 | // byName sorts field by name, breaking ties with depth, |
| 88 | // then breaking ties with "name came from json tag", then |
| 89 | // breaking ties with index sequence. |
| 90 | type byName []field |
| 91 | |
| 92 | func (x byName) Len() int { return len(x) } |
| 93 | |
| 94 | func (x byName) Swap(i, j int) { x[i], x[j] = x[j], x[i] } |
| 95 | |
| 96 | func (x byName) Less(i, j int) bool { |
| 97 | if x[i].name != x[j].name { |
| 98 | return x[i].name < x[j].name |
| 99 | } |
| 100 | if len(x[i].index) != len(x[j].index) { |
| 101 | return len(x[i].index) < len(x[j].index) |
| 102 | } |
| 103 | if x[i].tag != x[j].tag { |
| 104 | return x[i].tag |
| 105 | } |
| 106 | return byIndex(x).Less(i, j) |
| 107 | } |
| 108 | |
| 109 | // byIndex sorts field by index sequence. |
| 110 | type byIndex []field |
| 111 | |
| 112 | func (x byIndex) Len() int { return len(x) } |
| 113 | |
| 114 | func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] } |
| 115 | |
| 116 | func (x byIndex) Less(i, j int) bool { |
| 117 | for k, xik := range x[i].index { |
| 118 | if k >= len(x[j].index) { |
| 119 | return false |
| 120 | } |
| 121 | if xik != x[j].index[k] { |
| 122 | return xik < x[j].index[k] |
| 123 | } |
| 124 | } |
| 125 | return len(x[i].index) < len(x[j].index) |
| 126 | } |
| 127 | |
| 128 | // typeFields returns a list of fields that JSON should recognize for the given type. |
| 129 | // The algorithm is breadth-first search over the set of structs to include - the top struct |
| 130 | // and then any reachable anonymous structs. |
| 131 | func typeFields(t reflect.Type) []field { |
| 132 | // Anonymous fields to explore at the current level and the next. |
| 133 | current := []field{} |
| 134 | next := []field{{typ: t}} |
| 135 | |
| 136 | // Count of queued names for current level and the next. |
| 137 | count := map[reflect.Type]int{} |
| 138 | nextCount := map[reflect.Type]int{} |
| 139 | |
| 140 | // Types already visited at an earlier level. |
| 141 | visited := map[reflect.Type]bool{} |
| 142 | |
| 143 | // Fields found. |
| 144 | var fields []field |
| 145 | |
| 146 | for len(next) > 0 { |
| 147 | current, next = next, current[:0] |
| 148 | count, nextCount = nextCount, map[reflect.Type]int{} |
| 149 | |
| 150 | for _, f := range current { |
| 151 | if visited[f.typ] { |
| 152 | continue |
| 153 | } |
| 154 | visited[f.typ] = true |
| 155 | |
| 156 | // Scan f.typ for fields to include. |
| 157 | for i := 0; i < f.typ.NumField(); i++ { |
| 158 | sf := f.typ.Field(i) |
| 159 | if sf.PkgPath != "" { // unexported |
| 160 | continue |
| 161 | } |
| 162 | tag := sf.Tag.Get("json") |
| 163 | if tag == "-" { |
| 164 | continue |
| 165 | } |
| 166 | name, opts := parseTag(tag) |
| 167 | if !isValidTag(name) { |
| 168 | name = "" |
| 169 | } |
| 170 | index := make([]int, len(f.index)+1) |
| 171 | copy(index, f.index) |
| 172 | index[len(f.index)] = i |
| 173 | |
| 174 | ft := sf.Type |
| 175 | if ft.Name() == "" && ft.Kind() == reflect.Ptr { |
| 176 | // Follow pointer. |
| 177 | ft = ft.Elem() |
| 178 | } |
| 179 | |
| 180 | // Record found field and index sequence. |
| 181 | if name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct { |
| 182 | tagged := name != "" |
| 183 | if name == "" { |
| 184 | name = sf.Name |
| 185 | } |
| 186 | fields = append(fields, fillField(field{ |
| 187 | name: name, |
| 188 | tag: tagged, |
| 189 | index: index, |
| 190 | typ: ft, |
| 191 | omitEmpty: opts.Contains("omitempty"), |
| 192 | quoted: opts.Contains("string"), |
| 193 | })) |
| 194 | if count[f.typ] > 1 { |
| 195 | // If there were multiple instances, add a second, |
| 196 | // so that the annihilation code will see a duplicate. |
| 197 | // It only cares about the distinction between 1 or 2, |
| 198 | // so don't bother generating any more copies. |
| 199 | fields = append(fields, fields[len(fields)-1]) |
| 200 | } |
| 201 | continue |
| 202 | } |
| 203 | |
| 204 | // Record new anonymous struct to explore in next round. |
| 205 | nextCount[ft]++ |
| 206 | if nextCount[ft] == 1 { |
| 207 | next = append(next, fillField(field{name: ft.Name(), index: index, typ: ft})) |
| 208 | } |
| 209 | } |
| 210 | } |
| 211 | } |
| 212 | |
| 213 | sort.Sort(byName(fields)) |
| 214 | |
| 215 | // Delete all fields that are hidden by the Go rules for embedded fields, |
| 216 | // except that fields with JSON tags are promoted. |
| 217 | |
| 218 | // The fields are sorted in primary order of name, secondary order |
| 219 | // of field index length. Loop over names; for each name, delete |
| 220 | // hidden fields by choosing the one dominant field that survives. |
| 221 | out := fields[:0] |
| 222 | for advance, i := 0, 0; i < len(fields); i += advance { |
| 223 | // One iteration per name. |
| 224 | // Find the sequence of fields with the name of this first field. |
| 225 | fi := fields[i] |
| 226 | name := fi.name |
| 227 | for advance = 1; i+advance < len(fields); advance++ { |
| 228 | fj := fields[i+advance] |
| 229 | if fj.name != name { |
| 230 | break |
| 231 | } |
| 232 | } |
| 233 | if advance == 1 { // Only one field with this name |
| 234 | out = append(out, fi) |
| 235 | continue |
| 236 | } |
| 237 | dominant, ok := dominantField(fields[i : i+advance]) |
| 238 | if ok { |
| 239 | out = append(out, dominant) |
| 240 | } |
| 241 | } |
| 242 | |
| 243 | fields = out |
| 244 | sort.Sort(byIndex(fields)) |
| 245 | |
| 246 | return fields |
| 247 | } |
| 248 | |
| 249 | // dominantField looks through the fields, all of which are known to |
| 250 | // have the same name, to find the single field that dominates the |
| 251 | // others using Go's embedding rules, modified by the presence of |
| 252 | // JSON tags. If there are multiple top-level fields, the boolean |
| 253 | // will be false: This condition is an error in Go and we skip all |
| 254 | // the fields. |
| 255 | func dominantField(fields []field) (field, bool) { |
| 256 | // The fields are sorted in increasing index-length order. The winner |
| 257 | // must therefore be one with the shortest index length. Drop all |
| 258 | // longer entries, which is easy: just truncate the slice. |
| 259 | length := len(fields[0].index) |
| 260 | tagged := -1 // Index of first tagged field. |
| 261 | for i, f := range fields { |
| 262 | if len(f.index) > length { |
| 263 | fields = fields[:i] |
| 264 | break |
| 265 | } |
| 266 | if f.tag { |
| 267 | if tagged >= 0 { |
| 268 | // Multiple tagged fields at the same level: conflict. |
| 269 | // Return no field. |
| 270 | return field{}, false |
| 271 | } |
| 272 | tagged = i |
| 273 | } |
| 274 | } |
| 275 | if tagged >= 0 { |
| 276 | return fields[tagged], true |
| 277 | } |
| 278 | // All remaining fields have the same length. If there's more than one, |
| 279 | // we have a conflict (two fields named "X" at the same level) and we |
| 280 | // return no field. |
| 281 | if len(fields) > 1 { |
| 282 | return field{}, false |
| 283 | } |
| 284 | return fields[0], true |
| 285 | } |
| 286 | |
| 287 | var fieldCache struct { |
| 288 | sync.RWMutex |
| 289 | m map[reflect.Type][]field |
| 290 | } |
| 291 | |
| 292 | // cachedTypeFields is like typeFields but uses a cache to avoid repeated work. |
| 293 | func cachedTypeFields(t reflect.Type) []field { |
| 294 | fieldCache.RLock() |
| 295 | f := fieldCache.m[t] |
| 296 | fieldCache.RUnlock() |
| 297 | if f != nil { |
| 298 | return f |
| 299 | } |
| 300 | |
| 301 | // Compute fields without lock. |
| 302 | // Might duplicate effort but won't hold other computations back. |
| 303 | f = typeFields(t) |
| 304 | if f == nil { |
| 305 | f = []field{} |
| 306 | } |
| 307 | |
| 308 | fieldCache.Lock() |
| 309 | if fieldCache.m == nil { |
| 310 | fieldCache.m = map[reflect.Type][]field{} |
| 311 | } |
| 312 | fieldCache.m[t] = f |
| 313 | fieldCache.Unlock() |
| 314 | return f |
| 315 | } |
| 316 | |
// isValidTag reports whether s may be used as a field name taken from a
// struct tag: it must be non-empty and consist only of letters, digits and
// the punctuation listed below. Backslash and quote characters are
// reserved and therefore rejected.
func isValidTag(s string) bool {
	if len(s) == 0 {
		return false
	}
	const allowedPunct = "!#$%&()*+-./:<=>?@[]^_{|}~ "
	for _, r := range s {
		if strings.ContainsRune(allowedPunct, r) {
			continue
		}
		if unicode.IsLetter(r) || unicode.IsDigit(r) {
			continue
		}
		return false
	}
	return true
}
| 335 | |
const (
	// caseMask has every bit set except 0x20, the bit in which ASCII
	// upper- and lower-case letters differ; and-ing a byte with it ignores
	// case in ASCII.
	caseMask = byte(0xDF)
	// kelvin is U+212A, the Kelvin sign.
	kelvin = '\u212A'
	// smallLongEss is U+017F, the Latin small letter long s.
	smallLongEss = '\u017F'
)
| 341 | |
| 342 | // foldFunc returns one of four different case folding equivalence |
| 343 | // functions, from most general (and slow) to fastest: |
| 344 | // |
| 345 | // 1) bytes.EqualFold, if the key s contains any non-ASCII UTF-8 |
| 346 | // 2) equalFoldRight, if s contains special folding ASCII ('k', 'K', 's', 'S') |
| 347 | // 3) asciiEqualFold, no special, but includes non-letters (including _) |
| 348 | // 4) simpleLetterEqualFold, no specials, no non-letters. |
| 349 | // |
| 350 | // The letters S and K are special because they map to 3 runes, not just 2: |
| 351 | // * S maps to s and to U+017F 'ſ' Latin small letter long s |
| 352 | // * k maps to K and to U+212A 'K' Kelvin sign |
| 353 | // See http://play.golang.org/p/tTxjOc0OGo |
| 354 | // |
| 355 | // The returned function is specialized for matching against s and |
| 356 | // should only be given s. It's not curried for performance reasons. |
| 357 | func foldFunc(s []byte) func(s, t []byte) bool { |
| 358 | nonLetter := false |
| 359 | special := false // special letter |
| 360 | for _, b := range s { |
| 361 | if b >= utf8.RuneSelf { |
| 362 | return bytes.EqualFold |
| 363 | } |
| 364 | upper := b & caseMask |
| 365 | if upper < 'A' || upper > 'Z' { |
| 366 | nonLetter = true |
| 367 | } else if upper == 'K' || upper == 'S' { |
| 368 | // See above for why these letters are special. |
| 369 | special = true |
| 370 | } |
| 371 | } |
| 372 | if special { |
| 373 | return equalFoldRight |
| 374 | } |
| 375 | if nonLetter { |
| 376 | return asciiEqualFold |
| 377 | } |
| 378 | return simpleLetterEqualFold |
| 379 | } |
| 380 | |
| 381 | // equalFoldRight is a specialization of bytes.EqualFold when s is |
| 382 | // known to be all ASCII (including punctuation), but contains an 's', |
| 383 | // 'S', 'k', or 'K', requiring a Unicode fold on the bytes in t. |
| 384 | // See comments on foldFunc. |
| 385 | func equalFoldRight(s, t []byte) bool { |
| 386 | for _, sb := range s { |
| 387 | if len(t) == 0 { |
| 388 | return false |
| 389 | } |
| 390 | tb := t[0] |
| 391 | if tb < utf8.RuneSelf { |
| 392 | if sb != tb { |
| 393 | sbUpper := sb & caseMask |
| 394 | if 'A' <= sbUpper && sbUpper <= 'Z' { |
| 395 | if sbUpper != tb&caseMask { |
| 396 | return false |
| 397 | } |
| 398 | } else { |
| 399 | return false |
| 400 | } |
| 401 | } |
| 402 | t = t[1:] |
| 403 | continue |
| 404 | } |
| 405 | // sb is ASCII and t is not. t must be either kelvin |
| 406 | // sign or long s; sb must be s, S, k, or K. |
| 407 | tr, size := utf8.DecodeRune(t) |
| 408 | switch sb { |
| 409 | case 's', 'S': |
| 410 | if tr != smallLongEss { |
| 411 | return false |
| 412 | } |
| 413 | case 'k', 'K': |
| 414 | if tr != kelvin { |
| 415 | return false |
| 416 | } |
| 417 | default: |
| 418 | return false |
| 419 | } |
| 420 | t = t[size:] |
| 421 | |
| 422 | } |
| 423 | if len(t) > 0 { |
| 424 | return false |
| 425 | } |
| 426 | return true |
| 427 | } |
| 428 | |
| 429 | // asciiEqualFold is a specialization of bytes.EqualFold for use when |
| 430 | // s is all ASCII (but may contain non-letters) and contains no |
| 431 | // special-folding letters. |
| 432 | // See comments on foldFunc. |
| 433 | func asciiEqualFold(s, t []byte) bool { |
| 434 | if len(s) != len(t) { |
| 435 | return false |
| 436 | } |
| 437 | for i, sb := range s { |
| 438 | tb := t[i] |
| 439 | if sb == tb { |
| 440 | continue |
| 441 | } |
| 442 | if ('a' <= sb && sb <= 'z') || ('A' <= sb && sb <= 'Z') { |
| 443 | if sb&caseMask != tb&caseMask { |
| 444 | return false |
| 445 | } |
| 446 | } else { |
| 447 | return false |
| 448 | } |
| 449 | } |
| 450 | return true |
| 451 | } |
| 452 | |
| 453 | // simpleLetterEqualFold is a specialization of bytes.EqualFold for |
| 454 | // use when s is all ASCII letters (no underscores, etc) and also |
| 455 | // doesn't contain 'k', 'K', 's', or 'S'. |
| 456 | // See comments on foldFunc. |
| 457 | func simpleLetterEqualFold(s, t []byte) bool { |
| 458 | if len(s) != len(t) { |
| 459 | return false |
| 460 | } |
| 461 | for i, b := range s { |
| 462 | if b&caseMask != t[i]&caseMask { |
| 463 | return false |
| 464 | } |
| 465 | } |
| 466 | return true |
| 467 | } |
| 468 | |
// tagOptions holds whatever follows the first comma in a struct field's
// "json" tag, without that comma. It is empty when the tag has no options.
type tagOptions string

// parseTag splits a struct field's json tag at its first comma into the
// name and the comma-separated options. A tag without a comma is returned
// whole as the name, with empty options.
func parseTag(tag string) (string, tagOptions) {
	comma := strings.Index(tag, ",")
	if comma == -1 {
		return tag, tagOptions("")
	}
	return tag[:comma], tagOptions(tag[comma+1:])
}

// Contains reports whether optionName is one of the comma-separated
// options in o. An option matches only as a whole: it must be delimited by
// commas or by the start or end of the string.
func (o tagOptions) Contains(optionName string) bool {
	rest := string(o)
	for len(rest) > 0 {
		head, tail := rest, ""
		if comma := strings.Index(rest, ","); comma >= 0 {
			head, tail = rest[:comma], rest[comma+1:]
		}
		if head == optionName {
			return true
		}
		rest = tail
	}
	return false
}