blob: 02f8a321fd90f179ece406e14369d0e1b8b8076f [file] [log] [blame]
package codec
import (
"errors"
"fmt"
"io"
"math"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/protoc-gen-go/descriptor"
"github.com/jhump/protoreflect/desc"
)
// varintTypes, fixed32Types, and fixed64Types classify the scalar field types
// by the wire encoding used for each element when the field's data is decoded
// as packed data (see DecodeLengthDelimitedField): varint, fixed 32-bit, and
// fixed 64-bit respectively. A type missing from a map reads back as false.
//
// They are declared as composite literals rather than filled in by init() so
// that the tables are fully populated as soon as package variables exist.
var (
	varintTypes = map[descriptor.FieldDescriptorProto_Type]bool{
		descriptor.FieldDescriptorProto_TYPE_BOOL:   true,
		descriptor.FieldDescriptorProto_TYPE_INT32:  true,
		descriptor.FieldDescriptorProto_TYPE_INT64:  true,
		descriptor.FieldDescriptorProto_TYPE_UINT32: true,
		descriptor.FieldDescriptorProto_TYPE_UINT64: true,
		descriptor.FieldDescriptorProto_TYPE_SINT32: true,
		descriptor.FieldDescriptorProto_TYPE_SINT64: true,
		descriptor.FieldDescriptorProto_TYPE_ENUM:   true,
	}
	fixed32Types = map[descriptor.FieldDescriptorProto_Type]bool{
		descriptor.FieldDescriptorProto_TYPE_FIXED32:  true,
		descriptor.FieldDescriptorProto_TYPE_SFIXED32: true,
		descriptor.FieldDescriptorProto_TYPE_FLOAT:    true,
	}
	fixed64Types = map[descriptor.FieldDescriptorProto_Type]bool{
		descriptor.FieldDescriptorProto_TYPE_FIXED64:  true,
		descriptor.FieldDescriptorProto_TYPE_SFIXED64: true,
		descriptor.FieldDescriptorProto_TYPE_DOUBLE:   true,
	}
)
var (
	// ErrWireTypeEndGroup is the sentinel error DecodeFieldValue reports when
	// the tag it just read carries the end-group wire type. In that case the
	// value DecodeFieldValue returns alongside the error is the tag number
	// found in the end-group marker.
	ErrWireTypeEndGroup = errors.New("unexpected wire type: end group")
)
// MessageFactory supplies the message instances that decoding needs:
// DecodeFieldValue consults it whenever the value being decoded is, or
// contains, a message.
//
// The MessageFactory type in "github.com/jhump/protoreflect/dynamic" is one
// implementation of this interface.
type MessageFactory interface {
	// NewMessage returns a message for the given descriptor; the decoder then
	// unmarshals the field's encoded bytes into it.
	NewMessage(msgDesc *desc.MessageDescriptor) proto.Message
}
// UnknownField represents a field that was parsed from the binary wire
// format for a message, but whose tag number was not recognized. Enough
// information is preserved so that re-serializing the message won't lose
// any of the unrecognized data.
type UnknownField struct {
	// Tag is the tag (field) number of the unrecognized field.
	Tag int32
	// Encoding is the wire type the field was encoded with (proto.WireVarint,
	// proto.WireFixed64, proto.WireBytes, proto.WireStartGroup, or
	// proto.WireFixed32). It determines whether the data lives in Contents
	// or in Value.
	Encoding int8
	// Contents holds the raw bytes of the value when Encoding is
	// proto.WireBytes or proto.WireStartGroup; it is unused otherwise.
	Contents []byte
	// Value holds the value when Encoding is proto.WireVarint or
	// proto.WireFixed64 (all 64 bits are significant), or proto.WireFixed32
	// (only the least significant 32 bits are significant); it is unused
	// otherwise.
	Value uint64
}
// DecodeZigZag32 undoes the ZigZag mapping used for sint32 values: even
// inputs become non-negative (2n -> n) and odd inputs become negative
// (2n+1 -> -n-1). Only the low 32 bits of v take part in the result.
func DecodeZigZag32(v uint64) int32 {
	magnitude := int32(uint32(v) >> 1)
	sign := -int32(v & 1) // 0 for even v, -1 (all bits set) for odd v
	return magnitude ^ sign
}
// DecodeZigZag64 undoes the ZigZag mapping used for sint64 values: even
// inputs become non-negative (2n -> n) and odd inputs become negative
// (2n+1 -> -n-1).
func DecodeZigZag64(v uint64) int64 {
	magnitude := int64(v >> 1)
	sign := -int64(v & 1) // 0 for even v, -1 (all bits set) for odd v
	return magnitude ^ sign
}
// DecodeFieldValue reads the next field, both its tag and its value, from the
// buffer. fieldFinder maps the decoded tag number to a field descriptor, and
// fact creates any message instances needed while decoding the value.
//
// The possible outcomes are:
//   - The buffer is already exhausted: io.EOF with a nil descriptor and value.
//   - The tag is an end-group marker: ErrWireTypeEndGroup, a nil descriptor,
//     and the marker's tag number as the value.
//   - fieldFinder returns nil, which is expected for unrecognized tag numbers:
//     a nil descriptor and, unless decoding fails, an UnknownField value.
//   - Otherwise: the descriptor from fieldFinder and its decoded value. If
//     decoding that value fails, the descriptor is still returned together
//     with the error.
func (cb *Buffer) DecodeFieldValue(fieldFinder func(int32) *desc.FieldDescriptor, fact MessageFactory) (*desc.FieldDescriptor, interface{}, error) {
	if cb.EOF() {
		return nil, nil, io.EOF
	}
	tag, wt, err := cb.DecodeTagAndWireType()
	switch {
	case err != nil:
		return nil, nil, err
	case wt == proto.WireEndGroup:
		return nil, tag, ErrWireTypeEndGroup
	}
	if fd := fieldFinder(tag); fd != nil {
		known, decErr := cb.decodeKnownField(fd, wt, fact)
		return fd, known, decErr
	}
	unknown, decErr := cb.decodeUnknownField(tag, wt)
	return nil, unknown, decErr
}
// DecodeScalarField converts v, a number read from a varint, fixed32, or
// fixed64 wire encoding, into the Go type that generated structs use for fd's
// type:
//
//   - bool for TYPE_BOOL
//   - uint32 for TYPE_UINT32 and TYPE_FIXED32
//   - int32 for TYPE_INT32, TYPE_SFIXED32, TYPE_SINT32, and TYPE_ENUM
//   - uint64 for TYPE_UINT64 and TYPE_FIXED64
//   - int64 for TYPE_INT64, TYPE_SFIXED64, and TYPE_SINT64
//   - float32 for TYPE_FLOAT and float64 for TYPE_DOUBLE
//
// ErrOverflow is returned when v is out of range for a 32-bit target type.
// Types that use length-delimited encoding (bytes, string, message, group)
// cannot be produced from a single number, so they yield an error.
func DecodeScalarField(fd *desc.FieldDescriptor, v uint64) (interface{}, error) {
	typ := fd.GetType()
	switch typ {
	case descriptor.FieldDescriptorProto_TYPE_BOOL:
		return v != 0, nil

	case descriptor.FieldDescriptorProto_TYPE_INT32,
		descriptor.FieldDescriptorProto_TYPE_ENUM:
		// Negative int32/enum values are written as sign-extended 64-bit
		// varints, so the range check uses the signed interpretation of v.
		if n := int64(v); n >= math.MinInt32 && n <= math.MaxInt32 {
			return int32(n), nil
		}
		return nil, ErrOverflow

	case descriptor.FieldDescriptorProto_TYPE_UINT32,
		descriptor.FieldDescriptorProto_TYPE_FIXED32,
		descriptor.FieldDescriptorProto_TYPE_SFIXED32,
		descriptor.FieldDescriptorProto_TYPE_SINT32,
		descriptor.FieldDescriptorProto_TYPE_FLOAT:
		// These types carry exactly 32 bits of data; anything wider overflows.
		if v > math.MaxUint32 {
			return nil, ErrOverflow
		}
		low := uint32(v)
		switch typ {
		case descriptor.FieldDescriptorProto_TYPE_SFIXED32:
			return int32(low), nil
		case descriptor.FieldDescriptorProto_TYPE_SINT32:
			return DecodeZigZag32(v), nil
		case descriptor.FieldDescriptorProto_TYPE_FLOAT:
			return math.Float32frombits(low), nil
		default: // TYPE_UINT32, TYPE_FIXED32
			return low, nil
		}

	case descriptor.FieldDescriptorProto_TYPE_UINT64,
		descriptor.FieldDescriptorProto_TYPE_FIXED64:
		return v, nil
	case descriptor.FieldDescriptorProto_TYPE_INT64,
		descriptor.FieldDescriptorProto_TYPE_SFIXED64:
		return int64(v), nil
	case descriptor.FieldDescriptorProto_TYPE_SINT64:
		return DecodeZigZag64(v), nil
	case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
		return math.Float64frombits(v), nil
	}
	// bytes, string, message, and group cannot be represented as a simple numeric value
	return nil, fmt.Errorf("bad input; field %s requires length-delimited wire type", fd.GetFullyQualifiedName())
}
// DecodeLengthDelimitedField extracts a properly-typed value from bytes, the
// payload of a length-delimited field. The type of the returned value depends
// on fd:
//
//   - TYPE_BYTES: bytes itself, returned without copying.
//   - TYPE_STRING: bytes converted to a string.
//   - TYPE_MESSAGE / TYPE_GROUP: the message returned by mf.NewMessage,
//     populated with proto.Unmarshal.
//   - Any scalar type: bytes is parsed as packed repeated data, since
//     repeated scalar fields can be length-delimited when they use packed
//     encoding. A repeated field yields an []interface{} with one scalar per
//     decoded element (nil if bytes is empty). A non-repeated field yields a
//     single scalar, not in a slice: the last element decoded, or nil if
//     bytes is empty.
//
// The non-repeated case supports fields changed from optional to repeated:
// new code may emit a packed repeated representation that old code still
// reads as a single scalar value.
//
// An error is returned if a message or group field is decoded with a nil mf,
// if fd's scalar type is not recognized, or if any packed element fails to
// decode (see DecodeScalarField).
func DecodeLengthDelimitedField(fd *desc.FieldDescriptor, bytes []byte, mf MessageFactory) (interface{}, error) {
	switch fd.GetType() {
	case descriptor.FieldDescriptorProto_TYPE_BYTES:
		return bytes, nil
	case descriptor.FieldDescriptorProto_TYPE_STRING:
		return string(bytes), nil
	case descriptor.FieldDescriptorProto_TYPE_MESSAGE,
		descriptor.FieldDescriptorProto_TYPE_GROUP:
		if mf == nil {
			// Report the missing factory instead of panicking on a nil interface.
			return nil, fmt.Errorf("cannot decode field %s: nil MessageFactory", fd.GetFullyQualifiedName())
		}
		msg := mf.NewMessage(fd.GetMessageType())
		if err := proto.Unmarshal(bytes, msg); err != nil {
			return nil, err
		}
		return msg, nil
	}

	// Every remaining type is a scalar. Even if the field is not repeated or
	// not packed, we still parse it as such for backwards compatibility (e.g.
	// the message being de-serialized could have been both repeated and packed
	// at the time of serialization).
	packedBuf := NewBuffer(bytes)

	// The element encoding depends only on the field type, so choose the
	// decoder once instead of consulting the type tables for every element.
	var next func() (uint64, error)
	switch t := fd.GetType(); {
	case varintTypes[t]:
		next = packedBuf.DecodeVarint
	case fixed32Types[t]:
		next = packedBuf.DecodeFixed32
	case fixed64Types[t]:
		next = packedBuf.DecodeFixed64
	default:
		return nil, fmt.Errorf("bad input; cannot parse length-delimited wire type for field %s", fd.GetFullyQualifiedName())
	}

	repeated := fd.IsRepeated()
	var slice []interface{}
	var val interface{}
	for !packedBuf.EOF() {
		v, err := next()
		if err != nil {
			return nil, err
		}
		if val, err = DecodeScalarField(fd, v); err != nil {
			return nil, err
		}
		if repeated {
			slice = append(slice, val)
		}
	}
	if repeated {
		return slice, nil
	}
	// Not a repeated field: the last value wins.
	return val, nil
}
// decodeKnownField reads the value of the recognized field fd from b. encoding
// is the wire type already read from the field's tag. fact creates the message
// instances needed for message- and group-typed values.
//
// Varint, fixed32, and fixed64 values are converted with DecodeScalarField.
// Length-delimited payloads are converted with DecodeLengthDelimitedField.
// Groups are read with ReadGroup and unmarshalled into a new message from
// fact; this is allowed only for fields that have a message type. Any other
// wire type yields ErrBadWireType. The returned value is always nil when an
// error is returned.
func (b *Buffer) decodeKnownField(fd *desc.FieldDescriptor, encoding int8, fact MessageFactory) (interface{}, error) {
	var num uint64
	var err error
	switch encoding {
	case proto.WireFixed32:
		num, err = b.DecodeFixed32()
	case proto.WireFixed64:
		num, err = b.DecodeFixed64()
	case proto.WireVarint:
		num, err = b.DecodeVarint()

	case proto.WireBytes:
		// NOTE(review): alloc presumably asks DecodeRawBytes for a private copy
		// rather than a view of b's storage. Only TYPE_BYTES hands the raw slice
		// back to the caller unchanged, so only it requests one; confirm
		// against DecodeRawBytes.
		alloc := fd.GetType() == descriptor.FieldDescriptorProto_TYPE_BYTES
		var raw []byte
		if raw, err = b.DecodeRawBytes(alloc); err != nil {
			return nil, err
		}
		val, decErr := DecodeLengthDelimitedField(fd, raw, fact)
		if decErr != nil {
			return nil, decErr
		}
		return val, nil

	case proto.WireStartGroup:
		md := fd.GetMessageType()
		if md == nil {
			return nil, fmt.Errorf("cannot parse field %s from group-encoded wire type", fd.GetFullyQualifiedName())
		}
		if fact == nil {
			// Report the missing factory instead of panicking on a nil interface.
			return nil, fmt.Errorf("cannot decode field %s: nil MessageFactory", fd.GetFullyQualifiedName())
		}
		var data []byte
		if data, err = b.ReadGroup(false); err != nil {
			return nil, err
		}
		// Create the message only once the group data was read successfully.
		msg := fact.NewMessage(md)
		if err = proto.Unmarshal(data, msg); err != nil {
			return nil, err
		}
		return msg, nil

	default:
		return nil, ErrBadWireType
	}

	// Only the three numeric wire types reach this point.
	if err != nil {
		return nil, err
	}
	val, err := DecodeScalarField(fd, num)
	if err != nil {
		return nil, err
	}
	return val, nil
}
// decodeUnknownField reads the value of a field whose tag number has no known
// descriptor. It returns that value, its tag number, and its wire type as an
// UnknownField so that the data can be preserved. Varint, fixed64, and
// fixed32 payloads are stored in Value; length-delimited and group payloads
// are stored in Contents. ErrBadWireType is returned for any other wire type.
func (b *Buffer) decodeUnknownField(tagNumber int32, encoding int8) (interface{}, error) {
	field := UnknownField{Tag: tagNumber, Encoding: encoding}
	var err error
	switch encoding {
	case proto.WireVarint:
		field.Value, err = b.DecodeVarint()
	case proto.WireFixed64:
		field.Value, err = b.DecodeFixed64()
	case proto.WireFixed32:
		field.Value, err = b.DecodeFixed32()
	case proto.WireBytes:
		field.Contents, err = b.DecodeRawBytes(true)
	case proto.WireStartGroup:
		field.Contents, err = b.ReadGroup(true)
	default:
		return nil, ErrBadWireType
	}
	if err != nil {
		return nil, err
	}
	return field, nil
}