package codec

import (
	"fmt"
	"math"
	"reflect"
	"sort"

	"github.com/golang/protobuf/proto"
	"github.com/golang/protobuf/protoc-gen-go/descriptor"

	"github.com/jhump/protoreflect/desc"
)

// EncodeFieldValue writes the given field, tag(s) included, to the buffer.
//
// The dynamic type of val must match the shape of the field:
//   - map fields:      map[interface{}]interface{}
//   - repeated fields: []interface{}
//   - all other fields: a single value of the Go type that encodeFieldValue
//     expects for the field's type
//
// A value of the wrong dynamic type makes the corresponding type assertion
// panic; it is not reported as an error.
//
// Map fields are written as one length-delimited record per entry, each record
// holding the key (entry field number 1) followed by the value (entry field
// number 2). When the buffer is in deterministic mode the entries are written
// in ascending key order; otherwise they follow Go's map iteration order.
//
// Repeated fields whose descriptor is packed and whose wire type is varint,
// fixed32, or fixed64 are written as a single length-delimited record as soon
// as there is at least one element. All other repeated fields are written one
// element at a time, each with its own tag. An empty slice writes nothing.
//
// The first error returned by an underlying encode call is returned as-is; any
// bytes written to the buffer before that point are left in place.
func (cb *Buffer) EncodeFieldValue(fd *desc.FieldDescriptor, val interface{}) error {
	switch {
	case fd.IsMap():
		mp := val.(map[interface{}]interface{})
		entryType := fd.GetMessageType()
		keyType := entryType.FindFieldByNumber(1)
		valType := entryType.FindFieldByNumber(2)

		// A single scratch buffer is reused for every entry.
		var entryBuffer Buffer

		// encodeEntry serializes one key/value pair into the scratch buffer and
		// then appends it to cb as a length-delimited record for this field.
		encodeEntry := func(k, v interface{}) error {
			entryBuffer.Reset()
			// Carry the deterministic setting over to the scratch buffer so that
			// anything encoded through it sees the same mode as cb. It is set
			// after Reset because Reset's effect on the flag is not visible here.
			// NOTE(review): presumably the message encoders consult this flag
			// for nested message values -- confirm.
			entryBuffer.deterministic = cb.deterministic
			if err := entryBuffer.encodeFieldElement(keyType, k); err != nil {
				return err
			}
			if err := entryBuffer.encodeFieldElement(valType, v); err != nil {
				return err
			}
			if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil {
				return err
			}
			return cb.EncodeRawBytes(entryBuffer.Bytes())
		}

		if !cb.deterministic {
			for k, v := range mp {
				if err := encodeEntry(k, v); err != nil {
					return err
				}
			}
			return nil
		}

		// Deterministic output: visit the keys in sorted order.
		keys := make([]interface{}, 0, len(mp))
		for k := range mp {
			keys = append(keys, k)
		}
		sort.Sort(sortable(keys))
		for _, k := range keys {
			if err := encodeEntry(k, mp[k]); err != nil {
				return err
			}
		}
		return nil

	case fd.IsRepeated():
		sl := val.([]interface{})
		wt, err := getWireType(fd.GetType())
		if err != nil {
			return err
		}
		// Only scalar numeric wire types can be packed. A single element is
		// packed as well, so the output honors the field's packed setting; an
		// empty slice falls through to the loop below and writes nothing.
		packable := wt == proto.WireVarint || wt == proto.WireFixed32 || wt == proto.WireFixed64
		if isPacked(fd) && len(sl) > 0 && packable {
			// packed repeated field: raw values back to back, no per-element tag
			var packedBuffer Buffer
			for _, v := range sl {
				if err := packedBuffer.encodeFieldValue(fd, v); err != nil {
					return err
				}
			}
			if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil {
				return err
			}
			return cb.EncodeRawBytes(packedBuffer.Bytes())
		}
		// non-packed repeated field: every element carries its own tag
		for _, v := range sl {
			if err := cb.encodeFieldElement(fd, v); err != nil {
				return err
			}
		}
		return nil

	default:
		return cb.encodeFieldElement(fd, val)
	}
}

// isPacked reports whether the given field is declared as packed. An explicit
// "packed" option on the field always wins; when the option is absent the
// answer is whatever the field's file reports for IsProto3 (false for proto2
// files, true for proto3 files).
func isPacked(fd *desc.FieldDescriptor) bool {
	if options := fd.AsFieldDescriptorProto().GetOptions(); options != nil {
		if options.Packed != nil {
			// explicitly configured on the field
			return options.GetPacked()
		}
	}
	// not configured: fall back to the syntax-level default
	return fd.GetFile().IsProto3()
}

// sortable adapts a slice of map keys to the Len/Less/Swap method set used for
// sorting. Every element is expected to have the same dynamic type, which must
// be one of int32, int64, uint32, uint64, bool, or string.
type sortable []interface{}

// Len returns the number of keys.
func (keys sortable) Len() int {
	return len(keys)
}

// Less reports whether the key at index i orders before the key at index j.
// Numbers and strings use their natural ordering; for bools, false orders
// before true. The comparison is chosen from the kind of the key at index i.
// It panics if that kind is not supported, or if a type assertion on either
// key fails.
func (keys sortable) Less(i, j int) bool {
	lhs, rhs := keys[i], keys[j]
	switch reflect.TypeOf(lhs).Kind() {
	case reflect.String:
		return lhs.(string) < rhs.(string)
	case reflect.Bool:
		// only "false < true" holds
		if lhs.(bool) {
			return false
		}
		return rhs.(bool)
	case reflect.Uint64:
		return lhs.(uint64) < rhs.(uint64)
	case reflect.Uint32:
		return lhs.(uint32) < rhs.(uint32)
	case reflect.Int64:
		return lhs.(int64) < rhs.(int64)
	case reflect.Int32:
		return lhs.(int32) < rhs.(int32)
	}
	panic(fmt.Sprintf("cannot compare keys of type %v", reflect.TypeOf(lhs)))
}

// Swap exchanges the keys at indexes i and j.
func (keys sortable) Swap(i, j int) {
	keys[j], keys[i] = keys[i], keys[j]
}

// encodeFieldElement writes a single value of the given field to the buffer:
// first the tag carrying the field number and the wire type derived from the
// field's type, then the encoded value. For group fields a matching end-group
// tag is written after the value. The first error encountered is returned and
// nothing further is written.
func (b *Buffer) encodeFieldElement(fd *desc.FieldDescriptor, val interface{}) error {
	wireType, err := getWireType(fd.GetType())
	if err != nil {
		return err
	}

	fieldNumber := fd.GetNumber()
	err = b.EncodeTagAndWireType(fieldNumber, wireType)
	if err == nil {
		err = b.encodeFieldValue(fd, val)
	}
	if err != nil || wireType != proto.WireStartGroup {
		return err
	}

	// the start-group tag was written above, so close the group here
	return b.EncodeTagAndWireType(fieldNumber, proto.WireEndGroup)
}

// encodeFieldValue writes just the value of a single field element to the
// buffer, without any tag. The dynamic type of val must be the Go type that
// corresponds to the field's declared type:
//
//	bool                             -> bool
//	enum, int32, sint32, sfixed32    -> int32
//	uint32, fixed32                  -> uint32
//	int64, sint64, sfixed64          -> int64
//	uint64, fixed64                  -> uint64
//	float                            -> float32
//	double                           -> float64
//	bytes                            -> []byte
//	string                           -> string
//	message, group                   -> proto.Message
//
// A mismatched dynamic type makes the type assertion panic. An error is
// returned for a field type that is not listed above, or when the underlying
// encode call fails.
func (b *Buffer) encodeFieldValue(fd *desc.FieldDescriptor, val interface{}) error {
	switch fieldType := fd.GetType(); fieldType {
	// ---- varint-encoded scalars ----
	case descriptor.FieldDescriptorProto_TYPE_BOOL:
		var bit uint64
		if val.(bool) {
			bit = 1
		}
		return b.EncodeVarint(bit)

	case descriptor.FieldDescriptorProto_TYPE_ENUM,
		descriptor.FieldDescriptorProto_TYPE_INT32:
		return b.EncodeVarint(uint64(val.(int32)))

	case descriptor.FieldDescriptorProto_TYPE_UINT32:
		return b.EncodeVarint(uint64(val.(uint32)))

	case descriptor.FieldDescriptorProto_TYPE_INT64:
		return b.EncodeVarint(uint64(val.(int64)))

	case descriptor.FieldDescriptorProto_TYPE_UINT64:
		return b.EncodeVarint(val.(uint64))

	// ---- zig-zag varints ----
	case descriptor.FieldDescriptorProto_TYPE_SINT32:
		return b.EncodeVarint(EncodeZigZag32(val.(int32)))

	case descriptor.FieldDescriptorProto_TYPE_SINT64:
		return b.EncodeVarint(EncodeZigZag64(val.(int64)))

	// ---- fixed 32-bit values ----
	case descriptor.FieldDescriptorProto_TYPE_SFIXED32:
		return b.EncodeFixed32(uint64(val.(int32)))

	case descriptor.FieldDescriptorProto_TYPE_FIXED32:
		return b.EncodeFixed32(uint64(val.(uint32)))

	case descriptor.FieldDescriptorProto_TYPE_FLOAT:
		return b.EncodeFixed32(uint64(math.Float32bits(val.(float32))))

	// ---- fixed 64-bit values ----
	case descriptor.FieldDescriptorProto_TYPE_SFIXED64:
		return b.EncodeFixed64(uint64(val.(int64)))

	case descriptor.FieldDescriptorProto_TYPE_FIXED64:
		return b.EncodeFixed64(val.(uint64))

	case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
		return b.EncodeFixed64(math.Float64bits(val.(float64)))

	// ---- raw byte payloads ----
	case descriptor.FieldDescriptorProto_TYPE_BYTES:
		return b.EncodeRawBytes(val.([]byte))

	case descriptor.FieldDescriptorProto_TYPE_STRING:
		return b.EncodeRawBytes(([]byte)(val.(string)))

	// ---- nested messages ----
	case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
		return b.EncodeDelimitedMessage(val.(proto.Message))

	case descriptor.FieldDescriptorProto_TYPE_GROUP:
		// The nested message is appended directly to this buffer. Whoever wrote
		// the start-group tag (e.g. the caller) must also write the end-group tag.
		return b.EncodeMessage(val.(proto.Message))

	default:
		return fmt.Errorf("unrecognized field type: %v", fieldType)
	}
}

// getWireType maps a field type from a descriptor to the wire type used when
// encoding values of that type. It returns ErrBadWireType for any field type
// that is not handled below.
func getWireType(t descriptor.FieldDescriptorProto_Type) (int8, error) {
	switch t {
	case descriptor.FieldDescriptorProto_TYPE_GROUP:
		return proto.WireStartGroup, nil

	// length-delimited payloads
	case descriptor.FieldDescriptorProto_TYPE_STRING,
		descriptor.FieldDescriptorProto_TYPE_BYTES,
		descriptor.FieldDescriptorProto_TYPE_MESSAGE:
		return proto.WireBytes, nil

	// eight-byte values
	case descriptor.FieldDescriptorProto_TYPE_DOUBLE,
		descriptor.FieldDescriptorProto_TYPE_FIXED64,
		descriptor.FieldDescriptorProto_TYPE_SFIXED64:
		return proto.WireFixed64, nil

	// four-byte values
	case descriptor.FieldDescriptorProto_TYPE_FLOAT,
		descriptor.FieldDescriptorProto_TYPE_FIXED32,
		descriptor.FieldDescriptorProto_TYPE_SFIXED32:
		return proto.WireFixed32, nil

	// variable-length integers
	case descriptor.FieldDescriptorProto_TYPE_BOOL,
		descriptor.FieldDescriptorProto_TYPE_ENUM,
		descriptor.FieldDescriptorProto_TYPE_INT32,
		descriptor.FieldDescriptorProto_TYPE_UINT32,
		descriptor.FieldDescriptorProto_TYPE_SINT32,
		descriptor.FieldDescriptorProto_TYPE_INT64,
		descriptor.FieldDescriptorProto_TYPE_UINT64,
		descriptor.FieldDescriptorProto_TYPE_SINT64:
		return proto.WireVarint, nil
	}
	return 0, ErrBadWireType
}
