blob: a31134eeb3b7dc49285ea1215f6e643d72eccb5d [file] [log] [blame]
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8 "bytes"
9 "encoding"
10 "fmt"
11 "io"
12 "math"
13 "sort"
14 "strings"
15
16 "google.golang.org/protobuf/encoding/prototext"
17 "google.golang.org/protobuf/encoding/protowire"
18 "google.golang.org/protobuf/proto"
19 "google.golang.org/protobuf/reflect/protoreflect"
20 "google.golang.org/protobuf/reflect/protoregistry"
21)
22
// wrapTextMarshalV2 selects the implementation used by TextMarshaler.marshal.
// When true, marshaling is delegated to prototext.MarshalOptions; when false
// (the current setting), the textWriter-based encoder in this file is used.
const wrapTextMarshalV2 = false
24
// TextMarshaler is a configurable text format marshaler.
//
// The zero value marshals in the multi-line format without expanding
// google.protobuf.Any messages.
type TextMarshaler struct {
	Compact   bool // use compact text format (one line)
	ExpandAny bool // expand google.protobuf.Any messages of known types
}
30
31// Marshal writes the proto text format of m to w.
32func (tm *TextMarshaler) Marshal(w io.Writer, m Message) error {
33 b, err := tm.marshal(m)
34 if len(b) > 0 {
35 if _, err := w.Write(b); err != nil {
36 return err
37 }
38 }
39 return err
40}
41
42// Text returns a proto text formatted string of m.
43func (tm *TextMarshaler) Text(m Message) string {
44 b, _ := tm.marshal(m)
45 return string(b)
46}
47
48func (tm *TextMarshaler) marshal(m Message) ([]byte, error) {
49 mr := MessageReflect(m)
50 if mr == nil || !mr.IsValid() {
51 return []byte("<nil>"), nil
52 }
53
54 if wrapTextMarshalV2 {
55 if m, ok := m.(encoding.TextMarshaler); ok {
56 return m.MarshalText()
57 }
58
59 opts := prototext.MarshalOptions{
60 AllowPartial: true,
61 EmitUnknown: true,
62 }
63 if !tm.Compact {
64 opts.Indent = " "
65 }
66 if !tm.ExpandAny {
67 opts.Resolver = (*protoregistry.Types)(nil)
68 }
69 return opts.Marshal(mr.Interface())
70 } else {
71 w := &textWriter{
72 compact: tm.Compact,
73 expandAny: tm.ExpandAny,
74 complete: true,
75 }
76
77 if m, ok := m.(encoding.TextMarshaler); ok {
78 b, err := m.MarshalText()
79 if err != nil {
80 return nil, err
81 }
82 w.Write(b)
83 return w.buf, nil
84 }
85
86 err := w.writeMessage(mr)
87 return w.buf, err
88 }
89}
90
// Shared marshalers backing the package-level helper functions:
// defaultTextMarshaler uses the zero-value configuration, and
// compactTextMarshaler has Compact enabled.
var (
	defaultTextMarshaler = TextMarshaler{}
	compactTextMarshaler = TextMarshaler{Compact: true}
)
95
96// MarshalText writes the proto text format of m to w.
97func MarshalText(w io.Writer, m Message) error { return defaultTextMarshaler.Marshal(w, m) }
98
99// MarshalTextString returns a proto text formatted string of m.
100func MarshalTextString(m Message) string { return defaultTextMarshaler.Text(m) }
101
102// CompactText writes the compact proto text format of m to w.
103func CompactText(w io.Writer, m Message) error { return compactTextMarshaler.Marshal(w, m) }
104
105// CompactTextString returns a compact proto text formatted string of m.
106func CompactTextString(m Message) string { return compactTextMarshaler.Text(m) }
107
// Byte-slice literals used by the text encoder: the line separator, the
// closing line of an unknown group, and the spellings of the special
// floating-point values.
var (
	newline         = []byte("\n")
	endBraceNewline = []byte("}\n")
	posInf          = []byte("inf")
	negInf          = []byte("-inf")
	nan             = []byte("nan")
)
115
// textWriter is an io.Writer that tracks its indentation level.
//
// Output accumulates in buf. In non-compact mode, the first write on a fresh
// line is preceded by two spaces per indentation level; in compact mode
// newlines are replaced by spaces and no indentation is written.
type textWriter struct {
	compact   bool // same as TextMarshaler.Compact
	expandAny bool // same as TextMarshaler.ExpandAny
	complete  bool // whether the current position is a complete line
	indent    int  // indentation level; never negative
	buf       []byte // encoded output accumulated so far
}
124
125func (w *textWriter) Write(p []byte) (n int, _ error) {
126 newlines := bytes.Count(p, newline)
127 if newlines == 0 {
128 if !w.compact && w.complete {
129 w.writeIndent()
130 }
131 w.buf = append(w.buf, p...)
132 w.complete = false
133 return len(p), nil
134 }
135
136 frags := bytes.SplitN(p, newline, newlines+1)
137 if w.compact {
138 for i, frag := range frags {
139 if i > 0 {
140 w.buf = append(w.buf, ' ')
141 n++
142 }
143 w.buf = append(w.buf, frag...)
144 n += len(frag)
145 }
146 return n, nil
147 }
148
149 for i, frag := range frags {
150 if w.complete {
151 w.writeIndent()
152 }
153 w.buf = append(w.buf, frag...)
154 n += len(frag)
155 if i+1 < len(frags) {
156 w.buf = append(w.buf, '\n')
157 n++
158 }
159 }
160 w.complete = len(frags[len(frags)-1]) == 0
161 return n, nil
162}
163
164func (w *textWriter) WriteByte(c byte) error {
165 if w.compact && c == '\n' {
166 c = ' '
167 }
168 if !w.compact && w.complete {
169 w.writeIndent()
170 }
171 w.buf = append(w.buf, c)
172 w.complete = c == '\n'
173 return nil
174}
175
176func (w *textWriter) writeName(fd protoreflect.FieldDescriptor) {
177 if !w.compact && w.complete {
178 w.writeIndent()
179 }
180 w.complete = false
181
182 if fd.Kind() != protoreflect.GroupKind {
183 w.buf = append(w.buf, fd.Name()...)
184 w.WriteByte(':')
185 } else {
186 // Use message type name for group field name.
187 w.buf = append(w.buf, fd.Message().Name()...)
188 }
189
190 if !w.compact {
191 w.WriteByte(' ')
192 }
193}
194
// requiresQuotes reports whether the type URL u must be written as a quoted
// string. It returns true as soon as u contains a character outside the set
// [0-9A-Za-z./_]; an empty string needs no quotes.
func requiresQuotes(u string) bool {
	for _, r := range u {
		isDigit := '0' <= r && r <= '9'
		isUpper := 'A' <= r && r <= 'Z'
		isLower := 'a' <= r && r <= 'z'
		isPunct := r == '.' || r == '/' || r == '_'
		if !(isDigit || isUpper || isLower || isPunct) {
			return true
		}
	}
	return false
}
213
214// writeProto3Any writes an expanded google.protobuf.Any message.
215//
216// It returns (false, nil) if sv value can't be unmarshaled (e.g. because
217// required messages are not linked in).
218//
219// It returns (true, error) when sv was written in expanded format or an error
220// was encountered.
221func (w *textWriter) writeProto3Any(m protoreflect.Message) (bool, error) {
222 md := m.Descriptor()
223 fdURL := md.Fields().ByName("type_url")
224 fdVal := md.Fields().ByName("value")
225
226 url := m.Get(fdURL).String()
227 mt, err := protoregistry.GlobalTypes.FindMessageByURL(url)
228 if err != nil {
229 return false, nil
230 }
231
232 b := m.Get(fdVal).Bytes()
233 m2 := mt.New()
234 if err := proto.Unmarshal(b, m2.Interface()); err != nil {
235 return false, nil
236 }
237 w.Write([]byte("["))
238 if requiresQuotes(url) {
239 w.writeQuotedString(url)
240 } else {
241 w.Write([]byte(url))
242 }
243 if w.compact {
244 w.Write([]byte("]:<"))
245 } else {
246 w.Write([]byte("]: <\n"))
247 w.indent++
248 }
249 if err := w.writeMessage(m2); err != nil {
250 return true, err
251 }
252 if w.compact {
253 w.Write([]byte("> "))
254 } else {
255 w.indent--
256 w.Write([]byte(">\n"))
257 }
258 return true, nil
259}
260
261func (w *textWriter) writeMessage(m protoreflect.Message) error {
262 md := m.Descriptor()
263 if w.expandAny && md.FullName() == "google.protobuf.Any" {
264 if canExpand, err := w.writeProto3Any(m); canExpand {
265 return err
266 }
267 }
268
269 fds := md.Fields()
270 for i := 0; i < fds.Len(); {
271 fd := fds.Get(i)
272 if od := fd.ContainingOneof(); od != nil {
273 fd = m.WhichOneof(od)
274 i += od.Fields().Len()
275 } else {
276 i++
277 }
278 if fd == nil || !m.Has(fd) {
279 continue
280 }
281
282 switch {
283 case fd.IsList():
284 lv := m.Get(fd).List()
285 for j := 0; j < lv.Len(); j++ {
286 w.writeName(fd)
287 v := lv.Get(j)
288 if err := w.writeSingularValue(v, fd); err != nil {
289 return err
290 }
291 w.WriteByte('\n')
292 }
293 case fd.IsMap():
294 kfd := fd.MapKey()
295 vfd := fd.MapValue()
296 mv := m.Get(fd).Map()
297
298 type entry struct{ key, val protoreflect.Value }
299 var entries []entry
300 mv.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
301 entries = append(entries, entry{k.Value(), v})
302 return true
303 })
304 sort.Slice(entries, func(i, j int) bool {
305 switch kfd.Kind() {
306 case protoreflect.BoolKind:
307 return !entries[i].key.Bool() && entries[j].key.Bool()
308 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind, protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
309 return entries[i].key.Int() < entries[j].key.Int()
310 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind, protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
311 return entries[i].key.Uint() < entries[j].key.Uint()
312 case protoreflect.StringKind:
313 return entries[i].key.String() < entries[j].key.String()
314 default:
315 panic("invalid kind")
316 }
317 })
318 for _, entry := range entries {
319 w.writeName(fd)
320 w.WriteByte('<')
321 if !w.compact {
322 w.WriteByte('\n')
323 }
324 w.indent++
325 w.writeName(kfd)
326 if err := w.writeSingularValue(entry.key, kfd); err != nil {
327 return err
328 }
329 w.WriteByte('\n')
330 w.writeName(vfd)
331 if err := w.writeSingularValue(entry.val, vfd); err != nil {
332 return err
333 }
334 w.WriteByte('\n')
335 w.indent--
336 w.WriteByte('>')
337 w.WriteByte('\n')
338 }
339 default:
340 w.writeName(fd)
341 if err := w.writeSingularValue(m.Get(fd), fd); err != nil {
342 return err
343 }
344 w.WriteByte('\n')
345 }
346 }
347
348 if b := m.GetUnknown(); len(b) > 0 {
349 w.writeUnknownFields(b)
350 }
351 return w.writeExtensions(m)
352}
353
354func (w *textWriter) writeSingularValue(v protoreflect.Value, fd protoreflect.FieldDescriptor) error {
355 switch fd.Kind() {
356 case protoreflect.FloatKind, protoreflect.DoubleKind:
357 switch vf := v.Float(); {
358 case math.IsInf(vf, +1):
359 w.Write(posInf)
360 case math.IsInf(vf, -1):
361 w.Write(negInf)
362 case math.IsNaN(vf):
363 w.Write(nan)
364 default:
365 fmt.Fprint(w, v.Interface())
366 }
367 case protoreflect.StringKind:
368 // NOTE: This does not validate UTF-8 for historical reasons.
369 w.writeQuotedString(string(v.String()))
370 case protoreflect.BytesKind:
371 w.writeQuotedString(string(v.Bytes()))
372 case protoreflect.MessageKind, protoreflect.GroupKind:
373 var bra, ket byte = '<', '>'
374 if fd.Kind() == protoreflect.GroupKind {
375 bra, ket = '{', '}'
376 }
377 w.WriteByte(bra)
378 if !w.compact {
379 w.WriteByte('\n')
380 }
381 w.indent++
382 m := v.Message()
383 if m2, ok := m.Interface().(encoding.TextMarshaler); ok {
384 b, err := m2.MarshalText()
385 if err != nil {
386 return err
387 }
388 w.Write(b)
389 } else {
390 w.writeMessage(m)
391 }
392 w.indent--
393 w.WriteByte(ket)
394 case protoreflect.EnumKind:
395 if ev := fd.Enum().Values().ByNumber(v.Enum()); ev != nil {
396 fmt.Fprint(w, ev.Name())
397 } else {
398 fmt.Fprint(w, v.Enum())
399 }
400 default:
401 fmt.Fprint(w, v.Interface())
402 }
403 return nil
404}
405
406// writeQuotedString writes a quoted string in the protocol buffer text format.
407func (w *textWriter) writeQuotedString(s string) {
408 w.WriteByte('"')
409 for i := 0; i < len(s); i++ {
410 switch c := s[i]; c {
411 case '\n':
412 w.buf = append(w.buf, `\n`...)
413 case '\r':
414 w.buf = append(w.buf, `\r`...)
415 case '\t':
416 w.buf = append(w.buf, `\t`...)
417 case '"':
418 w.buf = append(w.buf, `\"`...)
419 case '\\':
420 w.buf = append(w.buf, `\\`...)
421 default:
422 if isPrint := c >= 0x20 && c < 0x7f; isPrint {
423 w.buf = append(w.buf, c)
424 } else {
425 w.buf = append(w.buf, fmt.Sprintf(`\%03o`, c)...)
426 }
427 }
428 }
429 w.WriteByte('"')
430}
431
432func (w *textWriter) writeUnknownFields(b []byte) {
433 if !w.compact {
434 fmt.Fprintf(w, "/* %d unknown bytes */\n", len(b))
435 }
436
437 for len(b) > 0 {
438 num, wtyp, n := protowire.ConsumeTag(b)
439 if n < 0 {
440 return
441 }
442 b = b[n:]
443
444 if wtyp == protowire.EndGroupType {
445 w.indent--
446 w.Write(endBraceNewline)
447 continue
448 }
449 fmt.Fprint(w, num)
450 if wtyp != protowire.StartGroupType {
451 w.WriteByte(':')
452 }
453 if !w.compact || wtyp == protowire.StartGroupType {
454 w.WriteByte(' ')
455 }
456 switch wtyp {
457 case protowire.VarintType:
458 v, n := protowire.ConsumeVarint(b)
459 if n < 0 {
460 return
461 }
462 b = b[n:]
463 fmt.Fprint(w, v)
464 case protowire.Fixed32Type:
465 v, n := protowire.ConsumeFixed32(b)
466 if n < 0 {
467 return
468 }
469 b = b[n:]
470 fmt.Fprint(w, v)
471 case protowire.Fixed64Type:
472 v, n := protowire.ConsumeFixed64(b)
473 if n < 0 {
474 return
475 }
476 b = b[n:]
477 fmt.Fprint(w, v)
478 case protowire.BytesType:
479 v, n := protowire.ConsumeBytes(b)
480 if n < 0 {
481 return
482 }
483 b = b[n:]
484 fmt.Fprintf(w, "%q", v)
485 case protowire.StartGroupType:
486 w.WriteByte('{')
487 w.indent++
488 default:
489 fmt.Fprintf(w, "/* unknown wire type %d */", wtyp)
490 }
491 w.WriteByte('\n')
492 }
493}
494
495// writeExtensions writes all the extensions in m.
496func (w *textWriter) writeExtensions(m protoreflect.Message) error {
497 md := m.Descriptor()
498 if md.ExtensionRanges().Len() == 0 {
499 return nil
500 }
501
502 type ext struct {
503 desc protoreflect.FieldDescriptor
504 val protoreflect.Value
505 }
506 var exts []ext
507 m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
508 if fd.IsExtension() {
509 exts = append(exts, ext{fd, v})
510 }
511 return true
512 })
513 sort.Slice(exts, func(i, j int) bool {
514 return exts[i].desc.Number() < exts[j].desc.Number()
515 })
516
517 for _, ext := range exts {
518 // For message set, use the name of the message as the extension name.
519 name := string(ext.desc.FullName())
520 if isMessageSet(ext.desc.ContainingMessage()) {
521 name = strings.TrimSuffix(name, ".message_set_extension")
522 }
523
524 if !ext.desc.IsList() {
525 if err := w.writeSingularExtension(name, ext.val, ext.desc); err != nil {
526 return err
527 }
528 } else {
529 lv := ext.val.List()
530 for i := 0; i < lv.Len(); i++ {
531 if err := w.writeSingularExtension(name, lv.Get(i), ext.desc); err != nil {
532 return err
533 }
534 }
535 }
536 }
537 return nil
538}
539
540func (w *textWriter) writeSingularExtension(name string, v protoreflect.Value, fd protoreflect.FieldDescriptor) error {
541 fmt.Fprintf(w, "[%s]:", name)
542 if !w.compact {
543 w.WriteByte(' ')
544 }
545 if err := w.writeSingularValue(v, fd); err != nil {
546 return err
547 }
548 w.WriteByte('\n')
549 return nil
550}
551
552func (w *textWriter) writeIndent() {
553 if !w.complete {
554 return
555 }
556 for i := 0; i < w.indent*2; i++ {
557 w.buf = append(w.buf, ' ')
558 }
559 w.complete = false
560}