blob: ee0207e461e0ea9b879e7025a6fa1931ad3a678d [file] [log] [blame]
Matteo Scandoloa6a3aee2019-11-26 13:30:14 -07001package runtime
2
3import (
4 "encoding/base64"
5 "fmt"
6 "net/url"
7 "reflect"
8 "regexp"
9 "strconv"
10 "strings"
11 "time"
12
13 "github.com/golang/protobuf/proto"
14 "github.com/grpc-ecosystem/grpc-gateway/utilities"
15 "google.golang.org/grpc/grpclog"
16)
17
// valuesKeyRegexp matches a key of the form "name[sub]". The first capture
// group is everything before the final bracket pair and the second is the
// text inside that pair.
var valuesKeyRegexp = regexp.MustCompile(`^(.*)\[(.*)\]$`)
19
20// PopulateQueryParameters populates "values" into "msg".
21// A value is ignored if its key starts with one of the elements in "filter".
22func PopulateQueryParameters(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
23 for key, values := range values {
24 match := valuesKeyRegexp.FindStringSubmatch(key)
25 if len(match) == 3 {
26 key = match[1]
27 values = append([]string{match[2]}, values...)
28 }
29 fieldPath := strings.Split(key, ".")
30 if filter.HasCommonPrefix(fieldPath) {
31 continue
32 }
33 if err := populateFieldValueFromPath(msg, fieldPath, values); err != nil {
34 return err
35 }
36 }
37 return nil
38}
39
40// PopulateFieldFromPath sets a value in a nested Protobuf structure.
41// It instantiates missing protobuf fields as it goes.
42func PopulateFieldFromPath(msg proto.Message, fieldPathString string, value string) error {
43 fieldPath := strings.Split(fieldPathString, ".")
44 return populateFieldValueFromPath(msg, fieldPath, []string{value})
45}
46
47func populateFieldValueFromPath(msg proto.Message, fieldPath []string, values []string) error {
48 m := reflect.ValueOf(msg)
49 if m.Kind() != reflect.Ptr {
50 return fmt.Errorf("unexpected type %T: %v", msg, msg)
51 }
52 var props *proto.Properties
53 m = m.Elem()
54 for i, fieldName := range fieldPath {
55 isLast := i == len(fieldPath)-1
56 if !isLast && m.Kind() != reflect.Struct {
57 return fmt.Errorf("non-aggregate type in the mid of path: %s", strings.Join(fieldPath, "."))
58 }
59 var f reflect.Value
60 var err error
61 f, props, err = fieldByProtoName(m, fieldName)
62 if err != nil {
63 return err
64 } else if !f.IsValid() {
65 grpclog.Infof("field not found in %T: %s", msg, strings.Join(fieldPath, "."))
66 return nil
67 }
68
69 switch f.Kind() {
70 case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64, reflect.String, reflect.Uint32, reflect.Uint64:
71 if !isLast {
72 return fmt.Errorf("unexpected nested field %s in %s", fieldPath[i+1], strings.Join(fieldPath[:i+1], "."))
73 }
74 m = f
75 case reflect.Slice:
76 if !isLast {
77 return fmt.Errorf("unexpected repeated field in %s", strings.Join(fieldPath, "."))
78 }
79 // Handle []byte
80 if f.Type().Elem().Kind() == reflect.Uint8 {
81 m = f
82 break
83 }
84 return populateRepeatedField(f, values, props)
85 case reflect.Ptr:
86 if f.IsNil() {
87 m = reflect.New(f.Type().Elem())
88 f.Set(m.Convert(f.Type()))
89 }
90 m = f.Elem()
91 continue
92 case reflect.Struct:
93 m = f
94 continue
95 case reflect.Map:
96 if !isLast {
97 return fmt.Errorf("unexpected nested field %s in %s", fieldPath[i+1], strings.Join(fieldPath[:i+1], "."))
98 }
99 return populateMapField(f, values, props)
100 default:
101 return fmt.Errorf("unexpected type %s in %T", f.Type(), msg)
102 }
103 }
104 switch len(values) {
105 case 0:
106 return fmt.Errorf("no value of field: %s", strings.Join(fieldPath, "."))
107 case 1:
108 default:
109 grpclog.Infof("too many field values: %s", strings.Join(fieldPath, "."))
110 }
111 return populateField(m, values[0], props)
112}
113
114// fieldByProtoName looks up a field whose corresponding protobuf field name is "name".
115// "m" must be a struct value. It returns zero reflect.Value if no such field found.
116func fieldByProtoName(m reflect.Value, name string) (reflect.Value, *proto.Properties, error) {
117 props := proto.GetProperties(m.Type())
118
119 // look up field name in oneof map
120 if op, ok := props.OneofTypes[name]; ok {
121 v := reflect.New(op.Type.Elem())
122 field := m.Field(op.Field)
123 if !field.IsNil() {
124 return reflect.Value{}, nil, fmt.Errorf("field already set for %s oneof", props.Prop[op.Field].OrigName)
125 }
126 field.Set(v)
127 return v.Elem().Field(0), op.Prop, nil
128 }
129
130 for _, p := range props.Prop {
131 if p.OrigName == name {
132 return m.FieldByName(p.Name), p, nil
133 }
134 if p.JSONName == name {
135 return m.FieldByName(p.Name), p, nil
136 }
137 }
138 return reflect.Value{}, nil, nil
139}
140
141func populateMapField(f reflect.Value, values []string, props *proto.Properties) error {
142 if len(values) != 2 {
143 return fmt.Errorf("more than one value provided for key %s in map %s", values[0], props.Name)
144 }
145
146 key, value := values[0], values[1]
147 keyType := f.Type().Key()
148 valueType := f.Type().Elem()
149 if f.IsNil() {
150 f.Set(reflect.MakeMap(f.Type()))
151 }
152
153 keyConv, ok := convFromType[keyType.Kind()]
154 if !ok {
155 return fmt.Errorf("unsupported key type %s in map %s", keyType, props.Name)
156 }
157 valueConv, ok := convFromType[valueType.Kind()]
158 if !ok {
159 return fmt.Errorf("unsupported value type %s in map %s", valueType, props.Name)
160 }
161
162 keyV := keyConv.Call([]reflect.Value{reflect.ValueOf(key)})
163 if err := keyV[1].Interface(); err != nil {
164 return err.(error)
165 }
166 valueV := valueConv.Call([]reflect.Value{reflect.ValueOf(value)})
167 if err := valueV[1].Interface(); err != nil {
168 return err.(error)
169 }
170
171 f.SetMapIndex(keyV[0].Convert(keyType), valueV[0].Convert(valueType))
172
173 return nil
174}
175
176func populateRepeatedField(f reflect.Value, values []string, props *proto.Properties) error {
177 elemType := f.Type().Elem()
178
179 // is the destination field a slice of an enumeration type?
180 if enumValMap := proto.EnumValueMap(props.Enum); enumValMap != nil {
181 return populateFieldEnumRepeated(f, values, enumValMap)
182 }
183
184 conv, ok := convFromType[elemType.Kind()]
185 if !ok {
186 return fmt.Errorf("unsupported field type %s", elemType)
187 }
188 f.Set(reflect.MakeSlice(f.Type(), len(values), len(values)).Convert(f.Type()))
189 for i, v := range values {
190 result := conv.Call([]reflect.Value{reflect.ValueOf(v)})
191 if err := result[1].Interface(); err != nil {
192 return err.(error)
193 }
194 f.Index(i).Set(result[0].Convert(f.Index(i).Type()))
195 }
196 return nil
197}
198
199func populateField(f reflect.Value, value string, props *proto.Properties) error {
200 i := f.Addr().Interface()
201
202 // Handle protobuf well known types
203 var name string
204 switch m := i.(type) {
205 case interface{ XXX_WellKnownType() string }:
206 name = m.XXX_WellKnownType()
207 case proto.Message:
208 const wktPrefix = "google.protobuf."
209 if fullName := proto.MessageName(m); strings.HasPrefix(fullName, wktPrefix) {
210 name = fullName[len(wktPrefix):]
211 }
212 }
213 switch name {
214 case "Timestamp":
215 if value == "null" {
216 f.FieldByName("Seconds").SetInt(0)
217 f.FieldByName("Nanos").SetInt(0)
218 return nil
219 }
220
221 t, err := time.Parse(time.RFC3339Nano, value)
222 if err != nil {
223 return fmt.Errorf("bad Timestamp: %v", err)
224 }
225 f.FieldByName("Seconds").SetInt(int64(t.Unix()))
226 f.FieldByName("Nanos").SetInt(int64(t.Nanosecond()))
227 return nil
228 case "Duration":
229 if value == "null" {
230 f.FieldByName("Seconds").SetInt(0)
231 f.FieldByName("Nanos").SetInt(0)
232 return nil
233 }
234 d, err := time.ParseDuration(value)
235 if err != nil {
236 return fmt.Errorf("bad Duration: %v", err)
237 }
238
239 ns := d.Nanoseconds()
240 s := ns / 1e9
241 ns %= 1e9
242 f.FieldByName("Seconds").SetInt(s)
243 f.FieldByName("Nanos").SetInt(ns)
244 return nil
245 case "DoubleValue":
246 fallthrough
247 case "FloatValue":
248 float64Val, err := strconv.ParseFloat(value, 64)
249 if err != nil {
250 return fmt.Errorf("bad DoubleValue: %s", value)
251 }
252 f.FieldByName("Value").SetFloat(float64Val)
253 return nil
254 case "Int64Value":
255 fallthrough
256 case "Int32Value":
257 int64Val, err := strconv.ParseInt(value, 10, 64)
258 if err != nil {
259 return fmt.Errorf("bad DoubleValue: %s", value)
260 }
261 f.FieldByName("Value").SetInt(int64Val)
262 return nil
263 case "UInt64Value":
264 fallthrough
265 case "UInt32Value":
266 uint64Val, err := strconv.ParseUint(value, 10, 64)
267 if err != nil {
268 return fmt.Errorf("bad DoubleValue: %s", value)
269 }
270 f.FieldByName("Value").SetUint(uint64Val)
271 return nil
272 case "BoolValue":
273 if value == "true" {
274 f.FieldByName("Value").SetBool(true)
275 } else if value == "false" {
276 f.FieldByName("Value").SetBool(false)
277 } else {
278 return fmt.Errorf("bad BoolValue: %s", value)
279 }
280 return nil
281 case "StringValue":
282 f.FieldByName("Value").SetString(value)
283 return nil
284 case "BytesValue":
285 bytesVal, err := base64.StdEncoding.DecodeString(value)
286 if err != nil {
287 return fmt.Errorf("bad BytesValue: %s", value)
288 }
289 f.FieldByName("Value").SetBytes(bytesVal)
290 return nil
291 case "FieldMask":
292 p := f.FieldByName("Paths")
293 for _, v := range strings.Split(value, ",") {
294 if v != "" {
295 p.Set(reflect.Append(p, reflect.ValueOf(v)))
296 }
297 }
298 return nil
299 }
300
301 // Handle Time and Duration stdlib types
302 switch t := i.(type) {
303 case *time.Time:
304 pt, err := time.Parse(time.RFC3339Nano, value)
305 if err != nil {
306 return fmt.Errorf("bad Timestamp: %v", err)
307 }
308 *t = pt
309 return nil
310 case *time.Duration:
311 d, err := time.ParseDuration(value)
312 if err != nil {
313 return fmt.Errorf("bad Duration: %v", err)
314 }
315 *t = d
316 return nil
317 }
318
319 // is the destination field an enumeration type?
320 if enumValMap := proto.EnumValueMap(props.Enum); enumValMap != nil {
321 return populateFieldEnum(f, value, enumValMap)
322 }
323
324 conv, ok := convFromType[f.Kind()]
325 if !ok {
326 return fmt.Errorf("field type %T is not supported in query parameters", i)
327 }
328 result := conv.Call([]reflect.Value{reflect.ValueOf(value)})
329 if err := result[1].Interface(); err != nil {
330 return err.(error)
331 }
332 f.Set(result[0].Convert(f.Type()))
333 return nil
334}
335
// convertEnum converts "value" into a reflect.Value of enum type "t".
//
// "value" may be either an enum name present in "enumValMap" or the base-10
// number of one of the values in "enumValMap". The number is parsed as a
// 32-bit integer, so input that does not fit in an int32 is rejected instead
// of being truncated onto some other enum value. An error is returned when
// "value" is neither a known name nor a known number.
func convertEnum(value string, t reflect.Type, enumValMap map[string]int32) (reflect.Value, error) {
	// see if it's an enumeration string
	if enumVal, ok := enumValMap[value]; ok {
		return reflect.ValueOf(enumVal).Convert(t), nil
	}

	// check for an integer that matches an enumeration value
	parsed, err := strconv.ParseInt(value, 10, 32)
	if err != nil {
		return reflect.Value{}, fmt.Errorf("%s is not a valid %s", value, t)
	}
	number := int32(parsed)
	for _, v := range enumValMap {
		if v == number {
			return reflect.ValueOf(number).Convert(t), nil
		}
	}
	return reflect.Value{}, fmt.Errorf("%s is not a valid %s", value, t)
}
354
355func populateFieldEnum(f reflect.Value, value string, enumValMap map[string]int32) error {
356 cval, err := convertEnum(value, f.Type(), enumValMap)
357 if err != nil {
358 return err
359 }
360 f.Set(cval)
361 return nil
362}
363
364func populateFieldEnumRepeated(f reflect.Value, values []string, enumValMap map[string]int32) error {
365 elemType := f.Type().Elem()
366 f.Set(reflect.MakeSlice(f.Type(), len(values), len(values)).Convert(f.Type()))
367 for i, v := range values {
368 result, err := convertEnum(v, elemType, enumValMap)
369 if err != nil {
370 return err
371 }
372 f.Index(i).Set(result)
373 }
374 return nil
375}
376
377var (
378 convFromType = map[reflect.Kind]reflect.Value{
379 reflect.String: reflect.ValueOf(String),
380 reflect.Bool: reflect.ValueOf(Bool),
381 reflect.Float64: reflect.ValueOf(Float64),
382 reflect.Float32: reflect.ValueOf(Float32),
383 reflect.Int64: reflect.ValueOf(Int64),
384 reflect.Int32: reflect.ValueOf(Int32),
385 reflect.Uint64: reflect.ValueOf(Uint64),
386 reflect.Uint32: reflect.ValueOf(Uint32),
387 reflect.Slice: reflect.ValueOf(Bytes),
388 }
389)