blob: c903d4b24da94fb495c124143e057fb0ba0d02fa [file] [log] [blame]
Joey Armstrong903c69d2024-02-01 19:46:39 -05001package internal
2
3import (
4 "reflect"
5
6 "github.com/golang/protobuf/proto"
7)
8
// typeOfBytes is the reflect.Type of []byte. It is compared against the type
// of a legacy XXX_unrecognized field, and it is the target type when converting
// values returned by an APIv2-style GetUnknown method back into plain bytes.
var typeOfBytes = reflect.TypeOf((*[]byte)(nil)).Elem()
10
11// GetUnrecognized fetches the bytes of unrecognized fields for the given message.
12func GetUnrecognized(msg proto.Message) []byte {
13 val := reflect.Indirect(reflect.ValueOf(msg))
14 u := val.FieldByName("XXX_unrecognized")
15 if u.IsValid() && u.Type() == typeOfBytes {
16 return u.Interface().([]byte)
17 }
18
19 // Fallback to reflection for API v2 messages
20 get, _, _, ok := unrecognizedGetSetMethods(val)
21 if !ok {
22 return nil
23 }
24
25 return get.Call([]reflect.Value(nil))[0].Convert(typeOfBytes).Interface().([]byte)
26}
27
28// SetUnrecognized adds the given bytes to the unrecognized fields for the given message.
29func SetUnrecognized(msg proto.Message, data []byte) {
30 val := reflect.Indirect(reflect.ValueOf(msg))
31 u := val.FieldByName("XXX_unrecognized")
32 if u.IsValid() && u.Type() == typeOfBytes {
33 // Just store the bytes in the unrecognized field
34 ub := u.Interface().([]byte)
35 ub = append(ub, data...)
36 u.Set(reflect.ValueOf(ub))
37 return
38 }
39
40 // Fallback to reflection for API v2 messages
41 get, set, argType, ok := unrecognizedGetSetMethods(val)
42 if !ok {
43 return
44 }
45
46 existing := get.Call([]reflect.Value(nil))[0].Convert(typeOfBytes).Interface().([]byte)
47 if len(existing) > 0 {
48 data = append(existing, data...)
49 }
50 set.Call([]reflect.Value{reflect.ValueOf(data).Convert(argType)})
51}
52
53func unrecognizedGetSetMethods(val reflect.Value) (get reflect.Value, set reflect.Value, argType reflect.Type, ok bool) {
54 // val could be an APIv2 message. We use reflection to interact with
55 // this message so that we don't have a hard dependency on the new
56 // version of the protobuf package.
57 refMethod := val.MethodByName("ProtoReflect")
58 if !refMethod.IsValid() {
59 if val.CanAddr() {
60 refMethod = val.Addr().MethodByName("ProtoReflect")
61 }
62 if !refMethod.IsValid() {
63 return
64 }
65 }
66 refType := refMethod.Type()
67 if refType.NumIn() != 0 || refType.NumOut() != 1 {
68 return
69 }
70 ref := refMethod.Call([]reflect.Value(nil))
71 getMethod, setMethod := ref[0].MethodByName("GetUnknown"), ref[0].MethodByName("SetUnknown")
72 if !getMethod.IsValid() || !setMethod.IsValid() {
73 return
74 }
75 getType := getMethod.Type()
76 setType := setMethod.Type()
77 if getType.NumIn() != 0 || getType.NumOut() != 1 || setType.NumIn() != 1 || setType.NumOut() != 0 {
78 return
79 }
80 arg := setType.In(0)
81 if !arg.ConvertibleTo(typeOfBytes) || getType.Out(0) != arg {
82 return
83 }
84
85 return getMethod, setMethod, arg, true
86}