blob: 3e9a6a2f66c77cdcb18acb3bb56ffa36be9e5a18 [file] [log] [blame]
David K. Bainbridgebd6b2882021-08-26 13:31:02 +00001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8 "google.golang.org/protobuf/internal/errors"
9 "google.golang.org/protobuf/reflect/protoreflect"
10 "google.golang.org/protobuf/runtime/protoiface"
11)
12
13// CheckInitialized returns an error if any required fields in m are not set.
14func CheckInitialized(m Message) error {
15 // Treat a nil message interface as an "untyped" empty message,
16 // which we assume to have no required fields.
17 if m == nil {
18 return nil
19 }
20
21 return checkInitialized(m.ProtoReflect())
22}
23
24// CheckInitialized returns an error if any required fields in m are not set.
25func checkInitialized(m protoreflect.Message) error {
26 if methods := protoMethods(m); methods != nil && methods.CheckInitialized != nil {
27 _, err := methods.CheckInitialized(protoiface.CheckInitializedInput{
28 Message: m,
29 })
30 return err
31 }
32 return checkInitializedSlow(m)
33}
34
35func checkInitializedSlow(m protoreflect.Message) error {
36 md := m.Descriptor()
37 fds := md.Fields()
38 for i, nums := 0, md.RequiredNumbers(); i < nums.Len(); i++ {
39 fd := fds.ByNumber(nums.Get(i))
40 if !m.Has(fd) {
41 return errors.RequiredNotSet(string(fd.FullName()))
42 }
43 }
44 var err error
45 m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
46 switch {
47 case fd.IsList():
48 if fd.Message() == nil {
49 return true
50 }
51 for i, list := 0, v.List(); i < list.Len() && err == nil; i++ {
52 err = checkInitialized(list.Get(i).Message())
53 }
54 case fd.IsMap():
55 if fd.MapValue().Message() == nil {
56 return true
57 }
58 v.Map().Range(func(key protoreflect.MapKey, v protoreflect.Value) bool {
59 err = checkInitialized(v.Message())
60 return err == nil
61 })
62 default:
63 if fd.Message() == nil {
64 return true
65 }
66 err = checkInitialized(v.Message())
67 }
68 return err == nil
69 })
70 return err
71}