// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "math"
10 "math/bits"
11 "reflect"
12 "unicode/utf8"
13
14 "google.golang.org/protobuf/encoding/protowire"
15 "google.golang.org/protobuf/internal/encoding/messageset"
16 "google.golang.org/protobuf/internal/flags"
17 "google.golang.org/protobuf/internal/genid"
18 "google.golang.org/protobuf/internal/strs"
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +053019 "google.golang.org/protobuf/reflect/protoreflect"
20 "google.golang.org/protobuf/reflect/protoregistry"
21 "google.golang.org/protobuf/runtime/protoiface"
khenaidoo7d3c5582021-08-11 18:09:44 -040022)
23
// ValidationStatus is the result of validating the wire-format encoding of a message.
type ValidationStatus int

const (
	// ValidationUnknown indicates that unmarshaling the message might succeed or fail.
	// The validator was unable to render a judgement.
	//
	// The only causes of this status are an aberrant message type appearing somewhere
	// in the message or a failure in the extension resolver.
	ValidationUnknown ValidationStatus = iota + 1

	// ValidationInvalid indicates that unmarshaling the message will fail.
	ValidationInvalid

	// ValidationValid indicates that unmarshaling the message will succeed.
	ValidationValid

	// ValidationWrongWireType indicates that a validated field does not have
	// the expected wire type.
	ValidationWrongWireType
)

// String returns the name of the status constant. Values that do not
// correspond to a declared constant (including the zero value) are rendered
// as "ValidationStatus(N)".
func (v ValidationStatus) String() string {
	switch v {
	case ValidationUnknown:
		return "ValidationUnknown"
	case ValidationInvalid:
		return "ValidationInvalid"
	case ValidationValid:
		return "ValidationValid"
	case ValidationWrongWireType:
		// Keep this switch in sync with the const block above so every
		// declared status prints by name.
		return "ValidationWrongWireType"
	default:
		return fmt.Sprintf("ValidationStatus(%d)", int(v))
	}
}
58
59// Validate determines whether the contents of the buffer are a valid wire encoding
60// of the message type.
61//
62// This function is exposed for testing.
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +053063func Validate(mt protoreflect.MessageType, in protoiface.UnmarshalInput) (out protoiface.UnmarshalOutput, _ ValidationStatus) {
khenaidoo7d3c5582021-08-11 18:09:44 -040064 mi, ok := mt.(*MessageInfo)
65 if !ok {
66 return out, ValidationUnknown
67 }
68 if in.Resolver == nil {
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +053069 in.Resolver = protoregistry.GlobalTypes
khenaidoo7d3c5582021-08-11 18:09:44 -040070 }
71 o, st := mi.validate(in.Buf, 0, unmarshalOptions{
72 flags: in.Flags,
73 resolver: in.Resolver,
74 })
75 if o.initialized {
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +053076 out.Flags |= protoiface.UnmarshalInitialized
khenaidoo7d3c5582021-08-11 18:09:44 -040077 }
78 return out, st
79}
80
81type validationInfo struct {
82 mi *MessageInfo
83 typ validationType
84 keyType, valType validationType
85
86 // For non-required fields, requiredBit is 0.
87 //
88 // For required fields, requiredBit's nth bit is set, where n is a
89 // unique index in the range [0, MessageInfo.numRequiredFields).
90 //
91 // If there are more than 64 required fields, requiredBit is 0.
92 requiredBit uint64
93}
94
// validationType classifies a field by the kind of checking its encoded
// value needs during validation.
type validationType uint8

const (
	// validationTypeOther is the zero value: the field gets no
	// type-specific checks beyond being well-formed on the wire.
	validationTypeOther validationType = iota

	// validationTypeMessage is a length-delimited nested message.
	validationTypeMessage

	// validationTypeGroup is a nested message delimited by start-group and
	// end-group tags.
	validationTypeGroup

	// validationTypeMap is a map field; its entries are checked using the
	// key and value validation types.
	validationTypeMap

	// validationTypeRepeatedVarint is a repeated field whose elements use
	// the varint wire type.
	validationTypeRepeatedVarint

	// validationTypeRepeatedFixed32 is a repeated field whose elements use
	// the fixed32 wire type.
	validationTypeRepeatedFixed32

	// validationTypeRepeatedFixed64 is a repeated field whose elements use
	// the fixed64 wire type.
	validationTypeRepeatedFixed64

	// validationTypeVarint is a singular field using the varint wire type.
	validationTypeVarint

	// validationTypeFixed32 is a singular field using the fixed32 wire type.
	validationTypeFixed32

	// validationTypeFixed64 is a singular field using the fixed64 wire type.
	validationTypeFixed64

	// validationTypeBytes is a length-delimited field whose contents are
	// not inspected.
	validationTypeBytes

	// validationTypeUTF8String is a length-delimited field whose contents
	// must be valid UTF-8.
	validationTypeUTF8String

	// validationTypeMessageSetItem is the item group of a MessageSet.
	validationTypeMessageSetItem
)
112
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +0530113func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
khenaidoo7d3c5582021-08-11 18:09:44 -0400114 var vi validationInfo
115 switch {
116 case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
117 switch fd.Kind() {
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +0530118 case protoreflect.MessageKind:
khenaidoo7d3c5582021-08-11 18:09:44 -0400119 vi.typ = validationTypeMessage
120 if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
121 vi.mi = getMessageInfo(ot.Field(0).Type)
122 }
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +0530123 case protoreflect.GroupKind:
khenaidoo7d3c5582021-08-11 18:09:44 -0400124 vi.typ = validationTypeGroup
125 if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
126 vi.mi = getMessageInfo(ot.Field(0).Type)
127 }
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +0530128 case protoreflect.StringKind:
khenaidoo7d3c5582021-08-11 18:09:44 -0400129 if strs.EnforceUTF8(fd) {
130 vi.typ = validationTypeUTF8String
131 }
132 }
133 default:
134 vi = newValidationInfo(fd, ft)
135 }
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +0530136 if fd.Cardinality() == protoreflect.Required {
khenaidoo7d3c5582021-08-11 18:09:44 -0400137 // Avoid overflow. The required field check is done with a 64-bit mask, with
138 // any message containing more than 64 required fields always reported as
139 // potentially uninitialized, so it is not important to get a precise count
140 // of the required fields past 64.
141 if mi.numRequiredFields < math.MaxUint8 {
142 mi.numRequiredFields++
143 vi.requiredBit = 1 << (mi.numRequiredFields - 1)
144 }
145 }
146 return vi
147}
148
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +0530149func newValidationInfo(fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
khenaidoo7d3c5582021-08-11 18:09:44 -0400150 var vi validationInfo
151 switch {
152 case fd.IsList():
153 switch fd.Kind() {
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +0530154 case protoreflect.MessageKind:
khenaidoo7d3c5582021-08-11 18:09:44 -0400155 vi.typ = validationTypeMessage
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +0530156
157 if ft.Kind() == reflect.Ptr {
158 // Repeated opaque message fields are *[]*T.
159 ft = ft.Elem()
160 }
161
khenaidoo7d3c5582021-08-11 18:09:44 -0400162 if ft.Kind() == reflect.Slice {
163 vi.mi = getMessageInfo(ft.Elem())
164 }
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +0530165 case protoreflect.GroupKind:
khenaidoo7d3c5582021-08-11 18:09:44 -0400166 vi.typ = validationTypeGroup
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +0530167
168 if ft.Kind() == reflect.Ptr {
169 // Repeated opaque message fields are *[]*T.
170 ft = ft.Elem()
171 }
172
khenaidoo7d3c5582021-08-11 18:09:44 -0400173 if ft.Kind() == reflect.Slice {
174 vi.mi = getMessageInfo(ft.Elem())
175 }
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +0530176 case protoreflect.StringKind:
khenaidoo7d3c5582021-08-11 18:09:44 -0400177 vi.typ = validationTypeBytes
178 if strs.EnforceUTF8(fd) {
179 vi.typ = validationTypeUTF8String
180 }
181 default:
182 switch wireTypes[fd.Kind()] {
183 case protowire.VarintType:
184 vi.typ = validationTypeRepeatedVarint
185 case protowire.Fixed32Type:
186 vi.typ = validationTypeRepeatedFixed32
187 case protowire.Fixed64Type:
188 vi.typ = validationTypeRepeatedFixed64
189 }
190 }
191 case fd.IsMap():
192 vi.typ = validationTypeMap
193 switch fd.MapKey().Kind() {
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +0530194 case protoreflect.StringKind:
khenaidoo7d3c5582021-08-11 18:09:44 -0400195 if strs.EnforceUTF8(fd) {
196 vi.keyType = validationTypeUTF8String
197 }
198 }
199 switch fd.MapValue().Kind() {
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +0530200 case protoreflect.MessageKind:
khenaidoo7d3c5582021-08-11 18:09:44 -0400201 vi.valType = validationTypeMessage
202 if ft.Kind() == reflect.Map {
203 vi.mi = getMessageInfo(ft.Elem())
204 }
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +0530205 case protoreflect.StringKind:
khenaidoo7d3c5582021-08-11 18:09:44 -0400206 if strs.EnforceUTF8(fd) {
207 vi.valType = validationTypeUTF8String
208 }
209 }
210 default:
211 switch fd.Kind() {
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +0530212 case protoreflect.MessageKind:
khenaidoo7d3c5582021-08-11 18:09:44 -0400213 vi.typ = validationTypeMessage
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +0530214 vi.mi = getMessageInfo(ft)
215 case protoreflect.GroupKind:
khenaidoo7d3c5582021-08-11 18:09:44 -0400216 vi.typ = validationTypeGroup
217 vi.mi = getMessageInfo(ft)
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +0530218 case protoreflect.StringKind:
khenaidoo7d3c5582021-08-11 18:09:44 -0400219 vi.typ = validationTypeBytes
220 if strs.EnforceUTF8(fd) {
221 vi.typ = validationTypeUTF8String
222 }
223 default:
224 switch wireTypes[fd.Kind()] {
225 case protowire.VarintType:
226 vi.typ = validationTypeVarint
227 case protowire.Fixed32Type:
228 vi.typ = validationTypeFixed32
229 case protowire.Fixed64Type:
230 vi.typ = validationTypeFixed64
231 case protowire.BytesType:
232 vi.typ = validationTypeBytes
233 }
234 }
235 }
236 return vi
237}
238
239func (mi *MessageInfo) validate(b []byte, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
240 mi.init()
241 type validationState struct {
242 typ validationType
243 keyType, valType validationType
244 endGroup protowire.Number
245 mi *MessageInfo
246 tail []byte
247 requiredMask uint64
248 }
249
250 // Pre-allocate some slots to avoid repeated slice reallocation.
251 states := make([]validationState, 0, 16)
252 states = append(states, validationState{
253 typ: validationTypeMessage,
254 mi: mi,
255 })
256 if groupTag > 0 {
257 states[0].typ = validationTypeGroup
258 states[0].endGroup = groupTag
259 }
260 initialized := true
261 start := len(b)
262State:
263 for len(states) > 0 {
264 st := &states[len(states)-1]
265 for len(b) > 0 {
266 // Parse the tag (field number and wire type).
267 var tag uint64
268 if b[0] < 0x80 {
269 tag = uint64(b[0])
270 b = b[1:]
271 } else if len(b) >= 2 && b[1] < 128 {
272 tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
273 b = b[2:]
274 } else {
275 var n int
276 tag, n = protowire.ConsumeVarint(b)
277 if n < 0 {
278 return out, ValidationInvalid
279 }
280 b = b[n:]
281 }
282 var num protowire.Number
283 if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
284 return out, ValidationInvalid
285 } else {
286 num = protowire.Number(n)
287 }
288 wtyp := protowire.Type(tag & 7)
289
290 if wtyp == protowire.EndGroupType {
291 if st.endGroup == num {
292 goto PopState
293 }
294 return out, ValidationInvalid
295 }
296 var vi validationInfo
297 switch {
298 case st.typ == validationTypeMap:
299 switch num {
300 case genid.MapEntry_Key_field_number:
301 vi.typ = st.keyType
302 case genid.MapEntry_Value_field_number:
303 vi.typ = st.valType
304 vi.mi = st.mi
305 vi.requiredBit = 1
306 }
307 case flags.ProtoLegacy && st.mi.isMessageSet:
308 switch num {
309 case messageset.FieldItem:
310 vi.typ = validationTypeMessageSetItem
311 }
312 default:
313 var f *coderFieldInfo
314 if int(num) < len(st.mi.denseCoderFields) {
315 f = st.mi.denseCoderFields[num]
316 } else {
317 f = st.mi.coderFields[num]
318 }
319 if f != nil {
320 vi = f.validation
khenaidoo7d3c5582021-08-11 18:09:44 -0400321 break
322 }
323 // Possible extension field.
324 //
325 // TODO: We should return ValidationUnknown when:
326 // 1. The resolver is not frozen. (More extensions may be added to it.)
327 // 2. The resolver returns preg.NotFound.
328 // In this case, a type added to the resolver in the future could cause
329 // unmarshaling to begin failing. Supporting this requires some way to
330 // determine if the resolver is frozen.
331 xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +0530332 if err != nil && err != protoregistry.NotFound {
khenaidoo7d3c5582021-08-11 18:09:44 -0400333 return out, ValidationUnknown
334 }
335 if err == nil {
336 vi = getExtensionFieldInfo(xt).validation
337 }
338 }
339 if vi.requiredBit != 0 {
340 // Check that the field has a compatible wire type.
341 // We only need to consider non-repeated field types,
342 // since repeated fields (and maps) can never be required.
343 ok := false
344 switch vi.typ {
345 case validationTypeVarint:
346 ok = wtyp == protowire.VarintType
347 case validationTypeFixed32:
348 ok = wtyp == protowire.Fixed32Type
349 case validationTypeFixed64:
350 ok = wtyp == protowire.Fixed64Type
351 case validationTypeBytes, validationTypeUTF8String, validationTypeMessage:
352 ok = wtyp == protowire.BytesType
353 case validationTypeGroup:
354 ok = wtyp == protowire.StartGroupType
355 }
356 if ok {
357 st.requiredMask |= vi.requiredBit
358 }
359 }
360
361 switch wtyp {
362 case protowire.VarintType:
363 if len(b) >= 10 {
364 switch {
365 case b[0] < 0x80:
366 b = b[1:]
367 case b[1] < 0x80:
368 b = b[2:]
369 case b[2] < 0x80:
370 b = b[3:]
371 case b[3] < 0x80:
372 b = b[4:]
373 case b[4] < 0x80:
374 b = b[5:]
375 case b[5] < 0x80:
376 b = b[6:]
377 case b[6] < 0x80:
378 b = b[7:]
379 case b[7] < 0x80:
380 b = b[8:]
381 case b[8] < 0x80:
382 b = b[9:]
383 case b[9] < 0x80 && b[9] < 2:
384 b = b[10:]
385 default:
386 return out, ValidationInvalid
387 }
388 } else {
389 switch {
390 case len(b) > 0 && b[0] < 0x80:
391 b = b[1:]
392 case len(b) > 1 && b[1] < 0x80:
393 b = b[2:]
394 case len(b) > 2 && b[2] < 0x80:
395 b = b[3:]
396 case len(b) > 3 && b[3] < 0x80:
397 b = b[4:]
398 case len(b) > 4 && b[4] < 0x80:
399 b = b[5:]
400 case len(b) > 5 && b[5] < 0x80:
401 b = b[6:]
402 case len(b) > 6 && b[6] < 0x80:
403 b = b[7:]
404 case len(b) > 7 && b[7] < 0x80:
405 b = b[8:]
406 case len(b) > 8 && b[8] < 0x80:
407 b = b[9:]
408 case len(b) > 9 && b[9] < 2:
409 b = b[10:]
410 default:
411 return out, ValidationInvalid
412 }
413 }
414 continue State
415 case protowire.BytesType:
416 var size uint64
417 if len(b) >= 1 && b[0] < 0x80 {
418 size = uint64(b[0])
419 b = b[1:]
420 } else if len(b) >= 2 && b[1] < 128 {
421 size = uint64(b[0]&0x7f) + uint64(b[1])<<7
422 b = b[2:]
423 } else {
424 var n int
425 size, n = protowire.ConsumeVarint(b)
426 if n < 0 {
427 return out, ValidationInvalid
428 }
429 b = b[n:]
430 }
431 if size > uint64(len(b)) {
432 return out, ValidationInvalid
433 }
434 v := b[:size]
435 b = b[size:]
436 switch vi.typ {
437 case validationTypeMessage:
438 if vi.mi == nil {
439 return out, ValidationUnknown
440 }
441 vi.mi.init()
442 fallthrough
443 case validationTypeMap:
444 if vi.mi != nil {
445 vi.mi.init()
446 }
447 states = append(states, validationState{
448 typ: vi.typ,
449 keyType: vi.keyType,
450 valType: vi.valType,
451 mi: vi.mi,
452 tail: b,
453 })
454 b = v
455 continue State
456 case validationTypeRepeatedVarint:
457 // Packed field.
458 for len(v) > 0 {
459 _, n := protowire.ConsumeVarint(v)
460 if n < 0 {
461 return out, ValidationInvalid
462 }
463 v = v[n:]
464 }
465 case validationTypeRepeatedFixed32:
466 // Packed field.
467 if len(v)%4 != 0 {
468 return out, ValidationInvalid
469 }
470 case validationTypeRepeatedFixed64:
471 // Packed field.
472 if len(v)%8 != 0 {
473 return out, ValidationInvalid
474 }
475 case validationTypeUTF8String:
476 if !utf8.Valid(v) {
477 return out, ValidationInvalid
478 }
479 }
480 case protowire.Fixed32Type:
481 if len(b) < 4 {
482 return out, ValidationInvalid
483 }
484 b = b[4:]
485 case protowire.Fixed64Type:
486 if len(b) < 8 {
487 return out, ValidationInvalid
488 }
489 b = b[8:]
490 case protowire.StartGroupType:
491 switch {
492 case vi.typ == validationTypeGroup:
493 if vi.mi == nil {
494 return out, ValidationUnknown
495 }
496 vi.mi.init()
497 states = append(states, validationState{
498 typ: validationTypeGroup,
499 mi: vi.mi,
500 endGroup: num,
501 })
502 continue State
503 case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem:
504 typeid, v, n, err := messageset.ConsumeFieldValue(b, false)
505 if err != nil {
506 return out, ValidationInvalid
507 }
508 xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
509 switch {
Akash Reddy Kankanala92dfdf82025-03-23 22:07:09 +0530510 case err == protoregistry.NotFound:
khenaidoo7d3c5582021-08-11 18:09:44 -0400511 b = b[n:]
512 case err != nil:
513 return out, ValidationUnknown
514 default:
515 xvi := getExtensionFieldInfo(xt).validation
516 if xvi.mi != nil {
517 xvi.mi.init()
518 }
519 states = append(states, validationState{
520 typ: xvi.typ,
521 mi: xvi.mi,
522 tail: b[n:],
523 })
524 b = v
525 continue State
526 }
527 default:
528 n := protowire.ConsumeFieldValue(num, wtyp, b)
529 if n < 0 {
530 return out, ValidationInvalid
531 }
532 b = b[n:]
533 }
534 default:
535 return out, ValidationInvalid
536 }
537 }
538 if st.endGroup != 0 {
539 return out, ValidationInvalid
540 }
541 if len(b) != 0 {
542 return out, ValidationInvalid
543 }
544 b = st.tail
545 PopState:
546 numRequiredFields := 0
547 switch st.typ {
548 case validationTypeMessage, validationTypeGroup:
549 numRequiredFields = int(st.mi.numRequiredFields)
550 case validationTypeMap:
551 // If this is a map field with a message value that contains
552 // required fields, require that the value be present.
553 if st.mi != nil && st.mi.numRequiredFields > 0 {
554 numRequiredFields = 1
555 }
556 }
557 // If there are more than 64 required fields, this check will
558 // always fail and we will report that the message is potentially
559 // uninitialized.
560 if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
561 initialized = false
562 }
563 states = states[:len(states)-1]
564 }
565 out.n = start - len(b)
566 if initialized {
567 out.initialized = true
568 }
569 return out, ValidationValid
570}