khenaidoo | 106c61a | 2021-08-11 18:05:46 -0400 | [diff] [blame^] | 1 | // Copyright 2020 The Go Authors. All rights reserved. |
| 2 | // Use of this source code is governed by a BSD-style |
| 3 | // license that can be found in the LICENSE file. |
| 4 | |
| 5 | // Package order provides ordered access to messages and maps. |
| 6 | package order |
| 7 | |
| 8 | import ( |
| 9 | "sort" |
| 10 | "sync" |
| 11 | |
| 12 | pref "google.golang.org/protobuf/reflect/protoreflect" |
| 13 | ) |
| 14 | |
| 15 | type messageField struct { |
| 16 | fd pref.FieldDescriptor |
| 17 | v pref.Value |
| 18 | } |
| 19 | |
| 20 | var messageFieldPool = sync.Pool{ |
| 21 | New: func() interface{} { return new([]messageField) }, |
| 22 | } |
| 23 | |
| 24 | type ( |
| 25 | // FieldRnger is an interface for visiting all fields in a message. |
| 26 | // The protoreflect.Message type implements this interface. |
| 27 | FieldRanger interface{ Range(VisitField) } |
| 28 | // VisitField is called everytime a message field is visited. |
| 29 | VisitField = func(pref.FieldDescriptor, pref.Value) bool |
| 30 | ) |
| 31 | |
| 32 | // RangeFields iterates over the fields of fs according to the specified order. |
| 33 | func RangeFields(fs FieldRanger, less FieldOrder, fn VisitField) { |
| 34 | if less == nil { |
| 35 | fs.Range(fn) |
| 36 | return |
| 37 | } |
| 38 | |
| 39 | // Obtain a pre-allocated scratch buffer. |
| 40 | p := messageFieldPool.Get().(*[]messageField) |
| 41 | fields := (*p)[:0] |
| 42 | defer func() { |
| 43 | if cap(fields) < 1024 { |
| 44 | *p = fields |
| 45 | messageFieldPool.Put(p) |
| 46 | } |
| 47 | }() |
| 48 | |
| 49 | // Collect all fields in the message and sort them. |
| 50 | fs.Range(func(fd pref.FieldDescriptor, v pref.Value) bool { |
| 51 | fields = append(fields, messageField{fd, v}) |
| 52 | return true |
| 53 | }) |
| 54 | sort.Slice(fields, func(i, j int) bool { |
| 55 | return less(fields[i].fd, fields[j].fd) |
| 56 | }) |
| 57 | |
| 58 | // Visit the fields in the specified ordering. |
| 59 | for _, f := range fields { |
| 60 | if !fn(f.fd, f.v) { |
| 61 | return |
| 62 | } |
| 63 | } |
| 64 | } |
| 65 | |
| 66 | type mapEntry struct { |
| 67 | k pref.MapKey |
| 68 | v pref.Value |
| 69 | } |
| 70 | |
| 71 | var mapEntryPool = sync.Pool{ |
| 72 | New: func() interface{} { return new([]mapEntry) }, |
| 73 | } |
| 74 | |
| 75 | type ( |
| 76 | // EntryRanger is an interface for visiting all fields in a message. |
| 77 | // The protoreflect.Map type implements this interface. |
| 78 | EntryRanger interface{ Range(VisitEntry) } |
| 79 | // VisitEntry is called everytime a map entry is visited. |
| 80 | VisitEntry = func(pref.MapKey, pref.Value) bool |
| 81 | ) |
| 82 | |
| 83 | // RangeEntries iterates over the entries of es according to the specified order. |
| 84 | func RangeEntries(es EntryRanger, less KeyOrder, fn VisitEntry) { |
| 85 | if less == nil { |
| 86 | es.Range(fn) |
| 87 | return |
| 88 | } |
| 89 | |
| 90 | // Obtain a pre-allocated scratch buffer. |
| 91 | p := mapEntryPool.Get().(*[]mapEntry) |
| 92 | entries := (*p)[:0] |
| 93 | defer func() { |
| 94 | if cap(entries) < 1024 { |
| 95 | *p = entries |
| 96 | mapEntryPool.Put(p) |
| 97 | } |
| 98 | }() |
| 99 | |
| 100 | // Collect all entries in the map and sort them. |
| 101 | es.Range(func(k pref.MapKey, v pref.Value) bool { |
| 102 | entries = append(entries, mapEntry{k, v}) |
| 103 | return true |
| 104 | }) |
| 105 | sort.Slice(entries, func(i, j int) bool { |
| 106 | return less(entries[i].k, entries[j].k) |
| 107 | }) |
| 108 | |
| 109 | // Visit the entries in the specified ordering. |
| 110 | for _, e := range entries { |
| 111 | if !fn(e.k, e.v) { |
| 112 | return |
| 113 | } |
| 114 | } |
| 115 | } |