package codec

import (
	"errors"
	"fmt"
	"io"
	"math"

	"github.com/golang/protobuf/proto"
	"github.com/golang/protobuf/protoc-gen-go/descriptor"

	"github.com/jhump/protoreflect/desc"
)

// ErrWireTypeEndGroup is returned from DecodeFieldValue if the tag and wire-type
// it reads indicate an end-group marker. In that case, the value returned
// alongside the error is the tag number read from the marker.
var ErrWireTypeEndGroup = errors.New("unexpected wire type: end group")

// MessageFactory is used to instantiate messages when DecodeFieldValue needs to
// decode a message value.
//
// Also see MessageFactory in "github.com/jhump/protoreflect/dynamic", which
// implements this interface.
type MessageFactory interface {
	// NewMessage returns a new message for the given message descriptor. The
	// decoding functions in this file pass the result to proto.Unmarshal to
	// populate it with the bytes of a nested message or group.
	NewMessage(md *desc.MessageDescriptor) proto.Message
}

// UnknownField represents a field that was parsed from the binary wire
// format for a message, but was not a recognized field number. Enough
// information is preserved so that re-serializing the message won't lose
// any of the unrecognized data.
type UnknownField struct {
	// The tag number for the unrecognized field.
	Tag int32

	// Encoding indicates how the unknown field was encoded on the wire (it is
	// the wire type that was read along with the tag). If it is
	// proto.WireBytes or proto.WireStartGroup then Contents will be set to
	// the raw bytes. If it is proto.WireFixed32 then the data is in the least
	// significant 32 bits of Value. Otherwise (proto.WireVarint or
	// proto.WireFixed64), the data is in all 64 bits of Value.
	Encoding int8
	Contents []byte
	Value    uint64
}

// DecodeFieldValue reads the next tag and field value from the buffer. It
// returns the field's descriptor together with the decoded value.
//
// The fieldFinder function maps a tag number to a field descriptor. The
// factory is used to instantiate messages whenever the value being decoded is
// (or contains) a message.
//
// Results:
//   - If the buffer is already exhausted, io.EOF is returned.
//   - If the tag read is an end-group marker, ErrWireTypeEndGroup is returned
//     and the returned value is the tag number (an int32) from that marker.
//   - If fieldFinder returns nil (expected for unrecognized tag numbers), the
//     returned descriptor is nil and the value is an UnknownField; this is not
//     an error.
//   - Otherwise the descriptor found by fieldFinder is returned along with the
//     decoded value.
//
// On error, the descriptor and value are typically nil. One exception is a
// failure while decoding the value of a recognized field: the descriptor is
// still returned, with a nil value.
func (cb *Buffer) DecodeFieldValue(fieldFinder func(int32) *desc.FieldDescriptor, fact MessageFactory) (*desc.FieldDescriptor, interface{}, error) {
	if cb.EOF() {
		return nil, nil, io.EOF
	}

	tag, wt, err := cb.DecodeTagAndWireType()
	switch {
	case err != nil:
		return nil, nil, err
	case wt == proto.WireEndGroup:
		// Hand the marker's tag back so the caller can match it to its group.
		return nil, tag, ErrWireTypeEndGroup
	}

	var decoded interface{}
	field := fieldFinder(tag)
	if field != nil {
		decoded, err = cb.decodeKnownField(field, wt, fact)
	} else {
		// Unrecognized tag: keep the raw data so it can be re-serialized.
		decoded, err = cb.decodeUnknownField(tag, wt)
	}
	return field, decoded, err
}

// DecodeScalarField converts the raw 64-bit value v into a properly-typed
// value for the given field. The returned value has the same Go type that
// generated structs use for the field's type: bool, int32, int64, uint32,
// uint64, float32, or float64. Enum fields yield an int32.
//
// ErrOverflow is returned when v does not fit the field's 32-bit type. If the
// field type uses length-delimited encoding (nested messages, bytes, and
// strings), an error is returned.
func DecodeScalarField(fd *desc.FieldDescriptor, v uint64) (interface{}, error) {
	typ := fd.GetType()
	switch typ {
	case descriptor.FieldDescriptorProto_TYPE_BOOL:
		return v != 0, nil

	case descriptor.FieldDescriptorProto_TYPE_UINT64,
		descriptor.FieldDescriptorProto_TYPE_FIXED64:
		return v, nil

	case descriptor.FieldDescriptorProto_TYPE_INT64,
		descriptor.FieldDescriptorProto_TYPE_SFIXED64:
		return int64(v), nil

	case descriptor.FieldDescriptorProto_TYPE_SINT64:
		return DecodeZigZag64(v), nil

	case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
		return math.Float64frombits(v), nil

	case descriptor.FieldDescriptorProto_TYPE_INT32,
		descriptor.FieldDescriptorProto_TYPE_ENUM:
		// The range check is done on the signed 64-bit interpretation of v, so
		// a negative value occupying all 64 bits is accepted as long as it fits
		// in an int32.
		signed := int64(v)
		if signed < math.MinInt32 || signed > math.MaxInt32 {
			return nil, ErrOverflow
		}
		return int32(signed), nil

	case descriptor.FieldDescriptorProto_TYPE_UINT32,
		descriptor.FieldDescriptorProto_TYPE_FIXED32,
		descriptor.FieldDescriptorProto_TYPE_SFIXED32,
		descriptor.FieldDescriptorProto_TYPE_SINT32,
		descriptor.FieldDescriptorProto_TYPE_FLOAT:
		// For all of these types, the upper 32 bits of v must be zero.
		if v > math.MaxUint32 {
			return nil, ErrOverflow
		}
		switch typ {
		case descriptor.FieldDescriptorProto_TYPE_SFIXED32:
			return int32(v), nil
		case descriptor.FieldDescriptorProto_TYPE_SINT32:
			return DecodeZigZag32(v), nil
		case descriptor.FieldDescriptorProto_TYPE_FLOAT:
			return math.Float32frombits(uint32(v)), nil
		default:
			// uint32 and fixed32
			return uint32(v), nil
		}

	default:
		// bytes, string, message, and group cannot be represented as a simple numeric value
		return nil, fmt.Errorf("bad input; field %s requires length-delimited wire type", fd.GetFullyQualifiedName())
	}
}

// DecodeLengthDelimitedField converts the payload of a length-delimited field
// into a properly-typed value. The result depends on the field's type:
//
//   - bytes fields return the given slice as is (no copy is made here);
//   - string fields return a string;
//   - message and group fields return a message created by the given factory,
//     populated by unmarshaling the payload;
//   - any other type is treated as packed scalar data (see below).
//
// Since repeated scalar fields can be length-delimited when they use packed
// encoding, the payload of a scalar field is always parsed as a sequence of
// packed values, regardless of whether the descriptor says the field is
// repeated or packed. For a repeated field, the result is an []interface{}
// where each element is a scalar value. For a non-repeated field, a single
// scalar is returned; if the payload holds multiple values, only the last one
// is returned. This supports the case where a field is changed from optional
// to repeated: new code may emit a packed repeated representation while old
// code still expects a single scalar value.
//
// Note that an empty payload for a scalar field yields a nil []interface{} if
// the field is repeated, and a nil value otherwise, with no error in either
// case.
func DecodeLengthDelimitedField(fd *desc.FieldDescriptor, bytes []byte, mf MessageFactory) (interface{}, error) {
	fieldType := fd.GetType()

	switch fieldType {
	case descriptor.FieldDescriptorProto_TYPE_BYTES:
		return bytes, nil

	case descriptor.FieldDescriptorProto_TYPE_STRING:
		return string(bytes), nil

	case descriptor.FieldDescriptorProto_TYPE_MESSAGE,
		descriptor.FieldDescriptorProto_TYPE_GROUP:
		msg := mf.NewMessage(fd.GetMessageType())
		if err := proto.Unmarshal(bytes, msg); err != nil {
			return nil, err
		}
		return msg, nil
	}

	// Even if the field is not repeated or not packed, we still parse it as such
	// for backwards compatibility (e.g. the message we are de-serializing could
	// have been both repeated and packed at the time of serialization).
	repeated := fd.IsRepeated()
	var (
		elems []interface{}
		last  interface{}
	)
	for packed := NewBuffer(bytes); !packed.EOF(); {
		var (
			raw uint64
			err error
		)
		switch {
		case varintTypes[fieldType]:
			raw, err = packed.DecodeVarint()
		case fixed32Types[fieldType]:
			raw, err = packed.DecodeFixed32()
		case fixed64Types[fieldType]:
			raw, err = packed.DecodeFixed64()
		default:
			return nil, fmt.Errorf("bad input; cannot parse length-delimited wire type for field %s", fd.GetFullyQualifiedName())
		}
		if err != nil {
			return nil, err
		}

		if last, err = DecodeScalarField(fd, raw); err != nil {
			return nil, err
		}
		if repeated {
			elems = append(elems, last)
		}
	}

	if repeated {
		return elems, nil
	}
	// if not a repeated field, last value wins
	return last, nil
}

// decodeKnownField reads one field value from the buffer, where the value was
// written using the given wire type (encoding), and converts it to the Go
// type appropriate for the given field descriptor. The tag and wire type must
// have already been consumed from the buffer.
//
// Numeric wire types (varint, fixed32, fixed64) are converted by
// DecodeScalarField. Length-delimited values are converted by
// DecodeLengthDelimitedField. Group-encoded values are unmarshaled into a
// message created by the given factory, which requires the field to have a
// message type. Any other wire type results in ErrBadWireType.
//
// On error, the returned value is always nil.
func (b *Buffer) decodeKnownField(fd *desc.FieldDescriptor, encoding int8, fact MessageFactory) (interface{}, error) {
	switch encoding {
	case proto.WireVarint:
		raw, err := b.DecodeVarint()
		if err != nil {
			return nil, err
		}
		result, err := DecodeScalarField(fd, raw)
		if err != nil {
			return nil, err
		}
		return result, nil

	case proto.WireFixed32:
		raw, err := b.DecodeFixed32()
		if err != nil {
			return nil, err
		}
		result, err := DecodeScalarField(fd, raw)
		if err != nil {
			return nil, err
		}
		return result, nil

	case proto.WireFixed64:
		raw, err := b.DecodeFixed64()
		if err != nil {
			return nil, err
		}
		result, err := DecodeScalarField(fd, raw)
		if err != nil {
			return nil, err
		}
		return result, nil

	case proto.WireBytes:
		// A bytes field returns the decoded slice directly to the caller, so
		// only in that case is DecodeRawBytes asked to allocate.
		wantCopy := fd.GetType() == descriptor.FieldDescriptorProto_TYPE_BYTES
		payload, err := b.DecodeRawBytes(wantCopy)
		if err != nil {
			return nil, err
		}
		result, err := DecodeLengthDelimitedField(fd, payload, fact)
		if err != nil {
			return nil, err
		}
		return result, nil

	case proto.WireStartGroup:
		msgType := fd.GetMessageType()
		if msgType == nil {
			return nil, fmt.Errorf("cannot parse field %s from group-encoded wire type", fd.GetFullyQualifiedName())
		}
		msg := fact.NewMessage(fd.GetMessageType())
		groupData, err := b.ReadGroup(false)
		if err != nil {
			return nil, err
		}
		if err := proto.Unmarshal(groupData, msg); err != nil {
			return nil, err
		}
		return msg, nil

	default:
		return nil, ErrBadWireType
	}
}

// decodeUnknownField reads one field value from the buffer for a tag number
// that has no known field descriptor. The value is captured, undecoded, in an
// UnknownField (returned by value, not as a pointer) that records the given
// tag number and wire type.
//
// Varint, fixed32, and fixed64 data is stored in the Value field. The payload
// of length-delimited and group-encoded data is stored in the Contents field.
// Any other wire type results in ErrBadWireType. On error, the returned value
// is nil.
func (b *Buffer) decodeUnknownField(tagNumber int32, encoding int8) (interface{}, error) {
	var (
		numeric uint64
		payload []byte
		err     error
	)

	switch encoding {
	case proto.WireVarint:
		numeric, err = b.DecodeVarint()
	case proto.WireFixed32:
		numeric, err = b.DecodeFixed32()
	case proto.WireFixed64:
		numeric, err = b.DecodeFixed64()
	case proto.WireBytes:
		payload, err = b.DecodeRawBytes(true)
	case proto.WireStartGroup:
		payload, err = b.ReadGroup(true)
	default:
		return nil, ErrBadWireType
	}
	if err != nil {
		return nil, err
	}

	return UnknownField{
		Tag:      tagNumber,
		Encoding: encoding,
		Contents: payload,
		Value:    numeric,
	}, nil
}
