khenaidoo | ab1f7bd | 2019-11-14 14:00:27 -0500 | [diff] [blame] | 1 | package runtime |
| 2 | |
| 3 | import ( |
| 4 | "encoding/base64" |
| 5 | "fmt" |
| 6 | "net/url" |
| 7 | "reflect" |
| 8 | "regexp" |
| 9 | "strconv" |
| 10 | "strings" |
| 11 | "time" |
| 12 | |
| 13 | "github.com/golang/protobuf/proto" |
| 14 | "github.com/grpc-ecosystem/grpc-gateway/utilities" |
| 15 | "google.golang.org/grpc/grpclog" |
| 16 | ) |
| 17 | |
// valuesKeyRegexp splits a query key of the form "name[sub]" into "name" and
// "sub". Parse uses it so that a key like "labels[env]=prod" addresses field
// "labels" with "env" prepended to its values (the map key). The first group is
// greedy, so only the last bracketed suffix is split off.
var valuesKeyRegexp = regexp.MustCompile(`^(.*)\[(.*)\]$`)
| 19 | |
| 20 | var currentQueryParser QueryParameterParser = &defaultQueryParser{} |
| 21 | |
| 22 | // QueryParameterParser defines interface for all query parameter parsers |
| 23 | type QueryParameterParser interface { |
| 24 | Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error |
| 25 | } |
| 26 | |
| 27 | // PopulateQueryParameters parses query parameters |
| 28 | // into "msg" using current query parser |
khenaidoo | ab1f7bd | 2019-11-14 14:00:27 -0500 | [diff] [blame] | 29 | func PopulateQueryParameters(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error { |
khenaidoo | d948f77 | 2021-08-11 17:49:24 -0400 | [diff] [blame] | 30 | return currentQueryParser.Parse(msg, values, filter) |
| 31 | } |
| 32 | |
| 33 | type defaultQueryParser struct{} |
| 34 | |
| 35 | // Parse populates "values" into "msg". |
| 36 | // A value is ignored if its key starts with one of the elements in "filter". |
| 37 | func (*defaultQueryParser) Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error { |
khenaidoo | ab1f7bd | 2019-11-14 14:00:27 -0500 | [diff] [blame] | 38 | for key, values := range values { |
khenaidoo | d948f77 | 2021-08-11 17:49:24 -0400 | [diff] [blame] | 39 | match := valuesKeyRegexp.FindStringSubmatch(key) |
khenaidoo | ab1f7bd | 2019-11-14 14:00:27 -0500 | [diff] [blame] | 40 | if len(match) == 3 { |
| 41 | key = match[1] |
| 42 | values = append([]string{match[2]}, values...) |
| 43 | } |
| 44 | fieldPath := strings.Split(key, ".") |
| 45 | if filter.HasCommonPrefix(fieldPath) { |
| 46 | continue |
| 47 | } |
| 48 | if err := populateFieldValueFromPath(msg, fieldPath, values); err != nil { |
| 49 | return err |
| 50 | } |
| 51 | } |
| 52 | return nil |
| 53 | } |
| 54 | |
| 55 | // PopulateFieldFromPath sets a value in a nested Protobuf structure. |
| 56 | // It instantiates missing protobuf fields as it goes. |
| 57 | func PopulateFieldFromPath(msg proto.Message, fieldPathString string, value string) error { |
| 58 | fieldPath := strings.Split(fieldPathString, ".") |
| 59 | return populateFieldValueFromPath(msg, fieldPath, []string{value}) |
| 60 | } |
| 61 | |
| 62 | func populateFieldValueFromPath(msg proto.Message, fieldPath []string, values []string) error { |
| 63 | m := reflect.ValueOf(msg) |
| 64 | if m.Kind() != reflect.Ptr { |
| 65 | return fmt.Errorf("unexpected type %T: %v", msg, msg) |
| 66 | } |
| 67 | var props *proto.Properties |
| 68 | m = m.Elem() |
| 69 | for i, fieldName := range fieldPath { |
| 70 | isLast := i == len(fieldPath)-1 |
| 71 | if !isLast && m.Kind() != reflect.Struct { |
| 72 | return fmt.Errorf("non-aggregate type in the mid of path: %s", strings.Join(fieldPath, ".")) |
| 73 | } |
| 74 | var f reflect.Value |
| 75 | var err error |
| 76 | f, props, err = fieldByProtoName(m, fieldName) |
| 77 | if err != nil { |
| 78 | return err |
| 79 | } else if !f.IsValid() { |
| 80 | grpclog.Infof("field not found in %T: %s", msg, strings.Join(fieldPath, ".")) |
| 81 | return nil |
| 82 | } |
| 83 | |
| 84 | switch f.Kind() { |
| 85 | case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64, reflect.String, reflect.Uint32, reflect.Uint64: |
| 86 | if !isLast { |
| 87 | return fmt.Errorf("unexpected nested field %s in %s", fieldPath[i+1], strings.Join(fieldPath[:i+1], ".")) |
| 88 | } |
| 89 | m = f |
| 90 | case reflect.Slice: |
| 91 | if !isLast { |
| 92 | return fmt.Errorf("unexpected repeated field in %s", strings.Join(fieldPath, ".")) |
| 93 | } |
| 94 | // Handle []byte |
| 95 | if f.Type().Elem().Kind() == reflect.Uint8 { |
| 96 | m = f |
| 97 | break |
| 98 | } |
| 99 | return populateRepeatedField(f, values, props) |
| 100 | case reflect.Ptr: |
| 101 | if f.IsNil() { |
| 102 | m = reflect.New(f.Type().Elem()) |
| 103 | f.Set(m.Convert(f.Type())) |
| 104 | } |
| 105 | m = f.Elem() |
| 106 | continue |
| 107 | case reflect.Struct: |
| 108 | m = f |
| 109 | continue |
| 110 | case reflect.Map: |
| 111 | if !isLast { |
| 112 | return fmt.Errorf("unexpected nested field %s in %s", fieldPath[i+1], strings.Join(fieldPath[:i+1], ".")) |
| 113 | } |
| 114 | return populateMapField(f, values, props) |
| 115 | default: |
| 116 | return fmt.Errorf("unexpected type %s in %T", f.Type(), msg) |
| 117 | } |
| 118 | } |
| 119 | switch len(values) { |
| 120 | case 0: |
| 121 | return fmt.Errorf("no value of field: %s", strings.Join(fieldPath, ".")) |
| 122 | case 1: |
| 123 | default: |
| 124 | grpclog.Infof("too many field values: %s", strings.Join(fieldPath, ".")) |
| 125 | } |
| 126 | return populateField(m, values[0], props) |
| 127 | } |
| 128 | |
| 129 | // fieldByProtoName looks up a field whose corresponding protobuf field name is "name". |
| 130 | // "m" must be a struct value. It returns zero reflect.Value if no such field found. |
| 131 | func fieldByProtoName(m reflect.Value, name string) (reflect.Value, *proto.Properties, error) { |
| 132 | props := proto.GetProperties(m.Type()) |
| 133 | |
| 134 | // look up field name in oneof map |
khenaidoo | d948f77 | 2021-08-11 17:49:24 -0400 | [diff] [blame] | 135 | for _, op := range props.OneofTypes { |
| 136 | if name == op.Prop.OrigName || name == op.Prop.JSONName { |
| 137 | v := reflect.New(op.Type.Elem()) |
| 138 | field := m.Field(op.Field) |
| 139 | if !field.IsNil() { |
| 140 | return reflect.Value{}, nil, fmt.Errorf("field already set for %s oneof", props.Prop[op.Field].OrigName) |
| 141 | } |
| 142 | field.Set(v) |
| 143 | return v.Elem().Field(0), op.Prop, nil |
khenaidoo | ab1f7bd | 2019-11-14 14:00:27 -0500 | [diff] [blame] | 144 | } |
khenaidoo | ab1f7bd | 2019-11-14 14:00:27 -0500 | [diff] [blame] | 145 | } |
| 146 | |
| 147 | for _, p := range props.Prop { |
| 148 | if p.OrigName == name { |
| 149 | return m.FieldByName(p.Name), p, nil |
| 150 | } |
| 151 | if p.JSONName == name { |
| 152 | return m.FieldByName(p.Name), p, nil |
| 153 | } |
| 154 | } |
| 155 | return reflect.Value{}, nil, nil |
| 156 | } |
| 157 | |
| 158 | func populateMapField(f reflect.Value, values []string, props *proto.Properties) error { |
| 159 | if len(values) != 2 { |
| 160 | return fmt.Errorf("more than one value provided for key %s in map %s", values[0], props.Name) |
| 161 | } |
| 162 | |
| 163 | key, value := values[0], values[1] |
| 164 | keyType := f.Type().Key() |
| 165 | valueType := f.Type().Elem() |
| 166 | if f.IsNil() { |
| 167 | f.Set(reflect.MakeMap(f.Type())) |
| 168 | } |
| 169 | |
| 170 | keyConv, ok := convFromType[keyType.Kind()] |
| 171 | if !ok { |
| 172 | return fmt.Errorf("unsupported key type %s in map %s", keyType, props.Name) |
| 173 | } |
| 174 | valueConv, ok := convFromType[valueType.Kind()] |
| 175 | if !ok { |
| 176 | return fmt.Errorf("unsupported value type %s in map %s", valueType, props.Name) |
| 177 | } |
| 178 | |
| 179 | keyV := keyConv.Call([]reflect.Value{reflect.ValueOf(key)}) |
| 180 | if err := keyV[1].Interface(); err != nil { |
| 181 | return err.(error) |
| 182 | } |
| 183 | valueV := valueConv.Call([]reflect.Value{reflect.ValueOf(value)}) |
| 184 | if err := valueV[1].Interface(); err != nil { |
| 185 | return err.(error) |
| 186 | } |
| 187 | |
| 188 | f.SetMapIndex(keyV[0].Convert(keyType), valueV[0].Convert(valueType)) |
| 189 | |
| 190 | return nil |
| 191 | } |
| 192 | |
| 193 | func populateRepeatedField(f reflect.Value, values []string, props *proto.Properties) error { |
| 194 | elemType := f.Type().Elem() |
| 195 | |
| 196 | // is the destination field a slice of an enumeration type? |
| 197 | if enumValMap := proto.EnumValueMap(props.Enum); enumValMap != nil { |
| 198 | return populateFieldEnumRepeated(f, values, enumValMap) |
| 199 | } |
| 200 | |
| 201 | conv, ok := convFromType[elemType.Kind()] |
| 202 | if !ok { |
| 203 | return fmt.Errorf("unsupported field type %s", elemType) |
| 204 | } |
| 205 | f.Set(reflect.MakeSlice(f.Type(), len(values), len(values)).Convert(f.Type())) |
| 206 | for i, v := range values { |
| 207 | result := conv.Call([]reflect.Value{reflect.ValueOf(v)}) |
| 208 | if err := result[1].Interface(); err != nil { |
| 209 | return err.(error) |
| 210 | } |
| 211 | f.Index(i).Set(result[0].Convert(f.Index(i).Type())) |
| 212 | } |
| 213 | return nil |
| 214 | } |
| 215 | |
| 216 | func populateField(f reflect.Value, value string, props *proto.Properties) error { |
| 217 | i := f.Addr().Interface() |
| 218 | |
| 219 | // Handle protobuf well known types |
| 220 | var name string |
| 221 | switch m := i.(type) { |
| 222 | case interface{ XXX_WellKnownType() string }: |
| 223 | name = m.XXX_WellKnownType() |
| 224 | case proto.Message: |
| 225 | const wktPrefix = "google.protobuf." |
| 226 | if fullName := proto.MessageName(m); strings.HasPrefix(fullName, wktPrefix) { |
| 227 | name = fullName[len(wktPrefix):] |
| 228 | } |
| 229 | } |
| 230 | switch name { |
| 231 | case "Timestamp": |
| 232 | if value == "null" { |
| 233 | f.FieldByName("Seconds").SetInt(0) |
| 234 | f.FieldByName("Nanos").SetInt(0) |
| 235 | return nil |
| 236 | } |
| 237 | |
| 238 | t, err := time.Parse(time.RFC3339Nano, value) |
| 239 | if err != nil { |
| 240 | return fmt.Errorf("bad Timestamp: %v", err) |
| 241 | } |
| 242 | f.FieldByName("Seconds").SetInt(int64(t.Unix())) |
| 243 | f.FieldByName("Nanos").SetInt(int64(t.Nanosecond())) |
| 244 | return nil |
| 245 | case "Duration": |
| 246 | if value == "null" { |
| 247 | f.FieldByName("Seconds").SetInt(0) |
| 248 | f.FieldByName("Nanos").SetInt(0) |
| 249 | return nil |
| 250 | } |
| 251 | d, err := time.ParseDuration(value) |
| 252 | if err != nil { |
| 253 | return fmt.Errorf("bad Duration: %v", err) |
| 254 | } |
| 255 | |
| 256 | ns := d.Nanoseconds() |
| 257 | s := ns / 1e9 |
| 258 | ns %= 1e9 |
| 259 | f.FieldByName("Seconds").SetInt(s) |
| 260 | f.FieldByName("Nanos").SetInt(ns) |
| 261 | return nil |
| 262 | case "DoubleValue": |
| 263 | fallthrough |
| 264 | case "FloatValue": |
| 265 | float64Val, err := strconv.ParseFloat(value, 64) |
| 266 | if err != nil { |
| 267 | return fmt.Errorf("bad DoubleValue: %s", value) |
| 268 | } |
| 269 | f.FieldByName("Value").SetFloat(float64Val) |
| 270 | return nil |
| 271 | case "Int64Value": |
| 272 | fallthrough |
| 273 | case "Int32Value": |
| 274 | int64Val, err := strconv.ParseInt(value, 10, 64) |
| 275 | if err != nil { |
| 276 | return fmt.Errorf("bad DoubleValue: %s", value) |
| 277 | } |
| 278 | f.FieldByName("Value").SetInt(int64Val) |
| 279 | return nil |
| 280 | case "UInt64Value": |
| 281 | fallthrough |
| 282 | case "UInt32Value": |
| 283 | uint64Val, err := strconv.ParseUint(value, 10, 64) |
| 284 | if err != nil { |
| 285 | return fmt.Errorf("bad DoubleValue: %s", value) |
| 286 | } |
| 287 | f.FieldByName("Value").SetUint(uint64Val) |
| 288 | return nil |
| 289 | case "BoolValue": |
| 290 | if value == "true" { |
| 291 | f.FieldByName("Value").SetBool(true) |
| 292 | } else if value == "false" { |
| 293 | f.FieldByName("Value").SetBool(false) |
| 294 | } else { |
| 295 | return fmt.Errorf("bad BoolValue: %s", value) |
| 296 | } |
| 297 | return nil |
| 298 | case "StringValue": |
| 299 | f.FieldByName("Value").SetString(value) |
| 300 | return nil |
| 301 | case "BytesValue": |
| 302 | bytesVal, err := base64.StdEncoding.DecodeString(value) |
| 303 | if err != nil { |
| 304 | return fmt.Errorf("bad BytesValue: %s", value) |
| 305 | } |
| 306 | f.FieldByName("Value").SetBytes(bytesVal) |
| 307 | return nil |
| 308 | case "FieldMask": |
| 309 | p := f.FieldByName("Paths") |
| 310 | for _, v := range strings.Split(value, ",") { |
| 311 | if v != "" { |
| 312 | p.Set(reflect.Append(p, reflect.ValueOf(v))) |
| 313 | } |
| 314 | } |
| 315 | return nil |
| 316 | } |
| 317 | |
| 318 | // Handle Time and Duration stdlib types |
| 319 | switch t := i.(type) { |
| 320 | case *time.Time: |
| 321 | pt, err := time.Parse(time.RFC3339Nano, value) |
| 322 | if err != nil { |
| 323 | return fmt.Errorf("bad Timestamp: %v", err) |
| 324 | } |
| 325 | *t = pt |
| 326 | return nil |
| 327 | case *time.Duration: |
| 328 | d, err := time.ParseDuration(value) |
| 329 | if err != nil { |
| 330 | return fmt.Errorf("bad Duration: %v", err) |
| 331 | } |
| 332 | *t = d |
| 333 | return nil |
| 334 | } |
| 335 | |
| 336 | // is the destination field an enumeration type? |
| 337 | if enumValMap := proto.EnumValueMap(props.Enum); enumValMap != nil { |
| 338 | return populateFieldEnum(f, value, enumValMap) |
| 339 | } |
| 340 | |
| 341 | conv, ok := convFromType[f.Kind()] |
| 342 | if !ok { |
| 343 | return fmt.Errorf("field type %T is not supported in query parameters", i) |
| 344 | } |
| 345 | result := conv.Call([]reflect.Value{reflect.ValueOf(value)}) |
| 346 | if err := result[1].Interface(); err != nil { |
| 347 | return err.(error) |
| 348 | } |
| 349 | f.Set(result[0].Convert(f.Type())) |
| 350 | return nil |
| 351 | } |
| 352 | |
// convertEnum converts value to a reflect.Value of the enum type t.
//
// value may be one of the enum's names (a key of enumValMap) or the decimal
// form of one of its numeric values. Numbers are parsed as 32-bit integers, so
// an out-of-range number such as 4294967297 is rejected. Parsing with Atoi and
// then casting to int32 would truncate it into a valid enum value.
// An error is returned when value matches neither a name nor a value.
func convertEnum(value string, t reflect.Type, enumValMap map[string]int32) (reflect.Value, error) {
	// see if it's an enumeration string
	if enumVal, ok := enumValMap[value]; ok {
		return reflect.ValueOf(enumVal).Convert(t), nil
	}

	// check for an integer that matches an enumeration value
	parsed, err := strconv.ParseInt(value, 10, 32)
	if err != nil {
		return reflect.Value{}, fmt.Errorf("%s is not a valid %s", value, t)
	}
	num := int32(parsed)
	for _, v := range enumValMap {
		if v == num {
			return reflect.ValueOf(num).Convert(t), nil
		}
	}
	return reflect.Value{}, fmt.Errorf("%s is not a valid %s", value, t)
}
| 371 | |
| 372 | func populateFieldEnum(f reflect.Value, value string, enumValMap map[string]int32) error { |
| 373 | cval, err := convertEnum(value, f.Type(), enumValMap) |
| 374 | if err != nil { |
| 375 | return err |
| 376 | } |
| 377 | f.Set(cval) |
| 378 | return nil |
| 379 | } |
| 380 | |
| 381 | func populateFieldEnumRepeated(f reflect.Value, values []string, enumValMap map[string]int32) error { |
| 382 | elemType := f.Type().Elem() |
| 383 | f.Set(reflect.MakeSlice(f.Type(), len(values), len(values)).Convert(f.Type())) |
| 384 | for i, v := range values { |
| 385 | result, err := convertEnum(v, elemType, enumValMap) |
| 386 | if err != nil { |
| 387 | return err |
| 388 | } |
| 389 | f.Index(i).Set(result) |
| 390 | } |
| 391 | return nil |
| 392 | } |
| 393 | |
| 394 | var ( |
| 395 | convFromType = map[reflect.Kind]reflect.Value{ |
| 396 | reflect.String: reflect.ValueOf(String), |
| 397 | reflect.Bool: reflect.ValueOf(Bool), |
| 398 | reflect.Float64: reflect.ValueOf(Float64), |
| 399 | reflect.Float32: reflect.ValueOf(Float32), |
| 400 | reflect.Int64: reflect.ValueOf(Int64), |
| 401 | reflect.Int32: reflect.ValueOf(Int32), |
| 402 | reflect.Uint64: reflect.ValueOf(Uint64), |
| 403 | reflect.Uint32: reflect.ValueOf(Uint32), |
| 404 | reflect.Slice: reflect.ValueOf(Bytes), |
| 405 | } |
| 406 | ) |