blob: b82341e575cb334ed4381504b2b516509201ccac [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "sync"
9
10 "google.golang.org/protobuf/internal/errors"
11 pref "google.golang.org/protobuf/reflect/protoreflect"
12 piface "google.golang.org/protobuf/runtime/protoiface"
13)
14
15func (mi *MessageInfo) checkInitialized(in piface.CheckInitializedInput) (piface.CheckInitializedOutput, error) {
16 var p pointer
17 if ms, ok := in.Message.(*messageState); ok {
18 p = ms.pointer()
19 } else {
20 p = in.Message.(*messageReflectWrapper).pointer()
21 }
22 return piface.CheckInitializedOutput{}, mi.checkInitializedPointer(p)
23}
24
25func (mi *MessageInfo) checkInitializedPointer(p pointer) error {
26 mi.init()
27 if !mi.needsInitCheck {
28 return nil
29 }
30 if p.IsNil() {
31 for _, f := range mi.orderedCoderFields {
32 if f.isRequired {
33 return errors.RequiredNotSet(string(mi.Desc.Fields().ByNumber(f.num).FullName()))
34 }
35 }
36 return nil
37 }
38 if mi.extensionOffset.IsValid() {
39 e := p.Apply(mi.extensionOffset).Extensions()
40 if err := mi.isInitExtensions(e); err != nil {
41 return err
42 }
43 }
44 for _, f := range mi.orderedCoderFields {
45 if !f.isRequired && f.funcs.isInit == nil {
46 continue
47 }
48 fptr := p.Apply(f.offset)
49 if f.isPointer && fptr.Elem().IsNil() {
50 if f.isRequired {
51 return errors.RequiredNotSet(string(mi.Desc.Fields().ByNumber(f.num).FullName()))
52 }
53 continue
54 }
55 if f.funcs.isInit == nil {
56 continue
57 }
58 if err := f.funcs.isInit(fptr, f); err != nil {
59 return err
60 }
61 }
62 return nil
63}
64
65func (mi *MessageInfo) isInitExtensions(ext *map[int32]ExtensionField) error {
66 if ext == nil {
67 return nil
68 }
69 for _, x := range *ext {
70 ei := getExtensionFieldInfo(x.Type())
71 if ei.funcs.isInit == nil {
72 continue
73 }
74 v := x.Value()
75 if !v.IsValid() {
76 continue
77 }
78 if err := ei.funcs.isInit(v); err != nil {
79 return err
80 }
81 }
82 return nil
83}
84
// needsInitCheckMu serializes computation of needsInitCheck results; the
// graph walk in needsInitCheckLocked runs only while it is held.
var needsInitCheckMu sync.Mutex

// needsInitCheckMap memoizes needsInitCheck results keyed by message
// descriptor. It is consulted without holding needsInitCheckMu on the fast path.
var needsInitCheckMap sync.Map
89
90// needsInitCheck reports whether a message needs to be checked for partial initialization.
91//
92// It returns true if the message transitively includes any required or extension fields.
93func needsInitCheck(md pref.MessageDescriptor) bool {
94 if v, ok := needsInitCheckMap.Load(md); ok {
95 if has, ok := v.(bool); ok {
96 return has
97 }
98 }
99 needsInitCheckMu.Lock()
100 defer needsInitCheckMu.Unlock()
101 return needsInitCheckLocked(md)
102}
103
104func needsInitCheckLocked(md pref.MessageDescriptor) (has bool) {
105 if v, ok := needsInitCheckMap.Load(md); ok {
106 // If has is true, we've previously determined that this message
107 // needs init checks.
108 //
109 // If has is false, we've previously determined that it can never
110 // be uninitialized.
111 //
112 // If has is not a bool, we've just encountered a cycle in the
113 // message graph. In this case, it is safe to return false: If
114 // the message does have required fields, we'll detect them later
115 // in the graph traversal.
116 has, ok := v.(bool)
117 return ok && has
118 }
119 needsInitCheckMap.Store(md, struct{}{}) // avoid cycles while descending into this message
120 defer func() {
121 needsInitCheckMap.Store(md, has)
122 }()
123 if md.RequiredNumbers().Len() > 0 {
124 return true
125 }
126 if md.ExtensionRanges().Len() > 0 {
127 return true
128 }
129 for i := 0; i < md.Fields().Len(); i++ {
130 fd := md.Fields().Get(i)
131 // Map keys are never messages, so just consider the map value.
132 if fd.IsMap() {
133 fd = fd.MapValue()
134 }
135 fmd := fd.Message()
136 if fmd != nil && needsInitCheckLocked(fmd) {
137 return true
138 }
139 }
140 return false
141}