blob: 5fbba5e8e8b5f5d5a69c7b92b84ea9b9357823c0 [file] [log] [blame]
khenaidoo59ce9dd2019-11-11 13:05:32 -05001package runtime
2
3import (
4 "encoding/base64"
5 "fmt"
6 "net/url"
7 "reflect"
8 "regexp"
9 "strconv"
10 "strings"
11 "time"
12
13 "github.com/golang/protobuf/proto"
14 "github.com/grpc-ecosystem/grpc-gateway/utilities"
15 "google.golang.org/grpc/grpclog"
16)
17
18// PopulateQueryParameters populates "values" into "msg".
19// A value is ignored if its key starts with one of the elements in "filter".
20func PopulateQueryParameters(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
21 for key, values := range values {
22 re, err := regexp.Compile("^(.*)\\[(.*)\\]$")
23 if err != nil {
24 return err
25 }
26 match := re.FindStringSubmatch(key)
27 if len(match) == 3 {
28 key = match[1]
29 values = append([]string{match[2]}, values...)
30 }
31 fieldPath := strings.Split(key, ".")
32 if filter.HasCommonPrefix(fieldPath) {
33 continue
34 }
35 if err := populateFieldValueFromPath(msg, fieldPath, values); err != nil {
36 return err
37 }
38 }
39 return nil
40}
41
42// PopulateFieldFromPath sets a value in a nested Protobuf structure.
43// It instantiates missing protobuf fields as it goes.
44func PopulateFieldFromPath(msg proto.Message, fieldPathString string, value string) error {
45 fieldPath := strings.Split(fieldPathString, ".")
46 return populateFieldValueFromPath(msg, fieldPath, []string{value})
47}
48
49func populateFieldValueFromPath(msg proto.Message, fieldPath []string, values []string) error {
50 m := reflect.ValueOf(msg)
51 if m.Kind() != reflect.Ptr {
52 return fmt.Errorf("unexpected type %T: %v", msg, msg)
53 }
54 var props *proto.Properties
55 m = m.Elem()
56 for i, fieldName := range fieldPath {
57 isLast := i == len(fieldPath)-1
58 if !isLast && m.Kind() != reflect.Struct {
59 return fmt.Errorf("non-aggregate type in the mid of path: %s", strings.Join(fieldPath, "."))
60 }
61 var f reflect.Value
62 var err error
63 f, props, err = fieldByProtoName(m, fieldName)
64 if err != nil {
65 return err
66 } else if !f.IsValid() {
67 grpclog.Infof("field not found in %T: %s", msg, strings.Join(fieldPath, "."))
68 return nil
69 }
70
71 switch f.Kind() {
72 case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64, reflect.String, reflect.Uint32, reflect.Uint64:
73 if !isLast {
74 return fmt.Errorf("unexpected nested field %s in %s", fieldPath[i+1], strings.Join(fieldPath[:i+1], "."))
75 }
76 m = f
77 case reflect.Slice:
78 if !isLast {
79 return fmt.Errorf("unexpected repeated field in %s", strings.Join(fieldPath, "."))
80 }
81 // Handle []byte
82 if f.Type().Elem().Kind() == reflect.Uint8 {
83 m = f
84 break
85 }
86 return populateRepeatedField(f, values, props)
87 case reflect.Ptr:
88 if f.IsNil() {
89 m = reflect.New(f.Type().Elem())
90 f.Set(m.Convert(f.Type()))
91 }
92 m = f.Elem()
93 continue
94 case reflect.Struct:
95 m = f
96 continue
97 case reflect.Map:
98 if !isLast {
99 return fmt.Errorf("unexpected nested field %s in %s", fieldPath[i+1], strings.Join(fieldPath[:i+1], "."))
100 }
101 return populateMapField(f, values, props)
102 default:
103 return fmt.Errorf("unexpected type %s in %T", f.Type(), msg)
104 }
105 }
106 switch len(values) {
107 case 0:
108 return fmt.Errorf("no value of field: %s", strings.Join(fieldPath, "."))
109 case 1:
110 default:
111 grpclog.Infof("too many field values: %s", strings.Join(fieldPath, "."))
112 }
113 return populateField(m, values[0], props)
114}
115
116// fieldByProtoName looks up a field whose corresponding protobuf field name is "name".
117// "m" must be a struct value. It returns zero reflect.Value if no such field found.
118func fieldByProtoName(m reflect.Value, name string) (reflect.Value, *proto.Properties, error) {
119 props := proto.GetProperties(m.Type())
120
121 // look up field name in oneof map
122 if op, ok := props.OneofTypes[name]; ok {
123 v := reflect.New(op.Type.Elem())
124 field := m.Field(op.Field)
125 if !field.IsNil() {
126 return reflect.Value{}, nil, fmt.Errorf("field already set for %s oneof", props.Prop[op.Field].OrigName)
127 }
128 field.Set(v)
129 return v.Elem().Field(0), op.Prop, nil
130 }
131
132 for _, p := range props.Prop {
133 if p.OrigName == name {
134 return m.FieldByName(p.Name), p, nil
135 }
136 if p.JSONName == name {
137 return m.FieldByName(p.Name), p, nil
138 }
139 }
140 return reflect.Value{}, nil, nil
141}
142
143func populateMapField(f reflect.Value, values []string, props *proto.Properties) error {
144 if len(values) != 2 {
145 return fmt.Errorf("more than one value provided for key %s in map %s", values[0], props.Name)
146 }
147
148 key, value := values[0], values[1]
149 keyType := f.Type().Key()
150 valueType := f.Type().Elem()
151 if f.IsNil() {
152 f.Set(reflect.MakeMap(f.Type()))
153 }
154
155 keyConv, ok := convFromType[keyType.Kind()]
156 if !ok {
157 return fmt.Errorf("unsupported key type %s in map %s", keyType, props.Name)
158 }
159 valueConv, ok := convFromType[valueType.Kind()]
160 if !ok {
161 return fmt.Errorf("unsupported value type %s in map %s", valueType, props.Name)
162 }
163
164 keyV := keyConv.Call([]reflect.Value{reflect.ValueOf(key)})
165 if err := keyV[1].Interface(); err != nil {
166 return err.(error)
167 }
168 valueV := valueConv.Call([]reflect.Value{reflect.ValueOf(value)})
169 if err := valueV[1].Interface(); err != nil {
170 return err.(error)
171 }
172
173 f.SetMapIndex(keyV[0].Convert(keyType), valueV[0].Convert(valueType))
174
175 return nil
176}
177
178func populateRepeatedField(f reflect.Value, values []string, props *proto.Properties) error {
179 elemType := f.Type().Elem()
180
181 // is the destination field a slice of an enumeration type?
182 if enumValMap := proto.EnumValueMap(props.Enum); enumValMap != nil {
183 return populateFieldEnumRepeated(f, values, enumValMap)
184 }
185
186 conv, ok := convFromType[elemType.Kind()]
187 if !ok {
188 return fmt.Errorf("unsupported field type %s", elemType)
189 }
190 f.Set(reflect.MakeSlice(f.Type(), len(values), len(values)).Convert(f.Type()))
191 for i, v := range values {
192 result := conv.Call([]reflect.Value{reflect.ValueOf(v)})
193 if err := result[1].Interface(); err != nil {
194 return err.(error)
195 }
196 f.Index(i).Set(result[0].Convert(f.Index(i).Type()))
197 }
198 return nil
199}
200
201func populateField(f reflect.Value, value string, props *proto.Properties) error {
202 i := f.Addr().Interface()
203
204 // Handle protobuf well known types
205 var name string
206 switch m := i.(type) {
207 case interface{ XXX_WellKnownType() string }:
208 name = m.XXX_WellKnownType()
209 case proto.Message:
210 const wktPrefix = "google.protobuf."
211 if fullName := proto.MessageName(m); strings.HasPrefix(fullName, wktPrefix) {
212 name = fullName[len(wktPrefix):]
213 }
214 }
215 switch name {
216 case "Timestamp":
217 if value == "null" {
218 f.FieldByName("Seconds").SetInt(0)
219 f.FieldByName("Nanos").SetInt(0)
220 return nil
221 }
222
223 t, err := time.Parse(time.RFC3339Nano, value)
224 if err != nil {
225 return fmt.Errorf("bad Timestamp: %v", err)
226 }
227 f.FieldByName("Seconds").SetInt(int64(t.Unix()))
228 f.FieldByName("Nanos").SetInt(int64(t.Nanosecond()))
229 return nil
230 case "Duration":
231 if value == "null" {
232 f.FieldByName("Seconds").SetInt(0)
233 f.FieldByName("Nanos").SetInt(0)
234 return nil
235 }
236 d, err := time.ParseDuration(value)
237 if err != nil {
238 return fmt.Errorf("bad Duration: %v", err)
239 }
240
241 ns := d.Nanoseconds()
242 s := ns / 1e9
243 ns %= 1e9
244 f.FieldByName("Seconds").SetInt(s)
245 f.FieldByName("Nanos").SetInt(ns)
246 return nil
247 case "DoubleValue":
248 fallthrough
249 case "FloatValue":
250 float64Val, err := strconv.ParseFloat(value, 64)
251 if err != nil {
252 return fmt.Errorf("bad DoubleValue: %s", value)
253 }
254 f.FieldByName("Value").SetFloat(float64Val)
255 return nil
256 case "Int64Value":
257 fallthrough
258 case "Int32Value":
259 int64Val, err := strconv.ParseInt(value, 10, 64)
260 if err != nil {
261 return fmt.Errorf("bad DoubleValue: %s", value)
262 }
263 f.FieldByName("Value").SetInt(int64Val)
264 return nil
265 case "UInt64Value":
266 fallthrough
267 case "UInt32Value":
268 uint64Val, err := strconv.ParseUint(value, 10, 64)
269 if err != nil {
270 return fmt.Errorf("bad DoubleValue: %s", value)
271 }
272 f.FieldByName("Value").SetUint(uint64Val)
273 return nil
274 case "BoolValue":
275 if value == "true" {
276 f.FieldByName("Value").SetBool(true)
277 } else if value == "false" {
278 f.FieldByName("Value").SetBool(false)
279 } else {
280 return fmt.Errorf("bad BoolValue: %s", value)
281 }
282 return nil
283 case "StringValue":
284 f.FieldByName("Value").SetString(value)
285 return nil
286 case "BytesValue":
287 bytesVal, err := base64.StdEncoding.DecodeString(value)
288 if err != nil {
289 return fmt.Errorf("bad BytesValue: %s", value)
290 }
291 f.FieldByName("Value").SetBytes(bytesVal)
292 return nil
293 case "FieldMask":
294 p := f.FieldByName("Paths")
295 for _, v := range strings.Split(value, ",") {
296 if v != "" {
297 p.Set(reflect.Append(p, reflect.ValueOf(v)))
298 }
299 }
300 return nil
301 }
302
303 // Handle Time and Duration stdlib types
304 switch t := i.(type) {
305 case *time.Time:
306 pt, err := time.Parse(time.RFC3339Nano, value)
307 if err != nil {
308 return fmt.Errorf("bad Timestamp: %v", err)
309 }
310 *t = pt
311 return nil
312 case *time.Duration:
313 d, err := time.ParseDuration(value)
314 if err != nil {
315 return fmt.Errorf("bad Duration: %v", err)
316 }
317 *t = d
318 return nil
319 }
320
321 // is the destination field an enumeration type?
322 if enumValMap := proto.EnumValueMap(props.Enum); enumValMap != nil {
323 return populateFieldEnum(f, value, enumValMap)
324 }
325
326 conv, ok := convFromType[f.Kind()]
327 if !ok {
328 return fmt.Errorf("field type %T is not supported in query parameters", i)
329 }
330 result := conv.Call([]reflect.Value{reflect.ValueOf(value)})
331 if err := result[1].Interface(); err != nil {
332 return err.(error)
333 }
334 f.Set(result[0].Convert(f.Type()))
335 return nil
336}
337
// convertEnum converts value into a reflect.Value of the enum type t.
//
// value may be the symbolic name of an enum constant (a key of enumValMap) or
// the decimal form of one of its numeric values. Numbers are parsed as 32-bit
// integers to match the int32 values in enumValMap. An out-of-range number
// such as 2^32+1 is therefore rejected rather than truncated onto a valid
// enum value.
func convertEnum(value string, t reflect.Type, enumValMap map[string]int32) (reflect.Value, error) {
	// see if it's an enumeration string
	if enumVal, ok := enumValMap[value]; ok {
		return reflect.ValueOf(enumVal).Convert(t), nil
	}

	// check for an integer that matches an enumeration value
	parsed, err := strconv.ParseInt(value, 10, 32)
	if err != nil {
		return reflect.Value{}, fmt.Errorf("%s is not a valid %s", value, t)
	}
	num := int32(parsed)
	for _, v := range enumValMap {
		if v == num {
			return reflect.ValueOf(num).Convert(t), nil
		}
	}
	return reflect.Value{}, fmt.Errorf("%s is not a valid %s", value, t)
}
356
357func populateFieldEnum(f reflect.Value, value string, enumValMap map[string]int32) error {
358 cval, err := convertEnum(value, f.Type(), enumValMap)
359 if err != nil {
360 return err
361 }
362 f.Set(cval)
363 return nil
364}
365
366func populateFieldEnumRepeated(f reflect.Value, values []string, enumValMap map[string]int32) error {
367 elemType := f.Type().Elem()
368 f.Set(reflect.MakeSlice(f.Type(), len(values), len(values)).Convert(f.Type()))
369 for i, v := range values {
370 result, err := convertEnum(v, elemType, enumValMap)
371 if err != nil {
372 return err
373 }
374 f.Index(i).Set(result)
375 }
376 return nil
377}
378
379var (
380 convFromType = map[reflect.Kind]reflect.Value{
381 reflect.String: reflect.ValueOf(String),
382 reflect.Bool: reflect.ValueOf(Bool),
383 reflect.Float64: reflect.ValueOf(Float64),
384 reflect.Float32: reflect.ValueOf(Float32),
385 reflect.Int64: reflect.ValueOf(Int64),
386 reflect.Int32: reflect.ValueOf(Int32),
387 reflect.Uint64: reflect.ValueOf(Uint64),
388 reflect.Uint32: reflect.ValueOf(Uint32),
389 reflect.Slice: reflect.ValueOf(Bytes),
390 }
391)