// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bsonx

import (
	"errors"
	"fmt"
	"reflect"

	"github.com/mongodb/mongo-go-driver/bson/bsoncodec"
	"github.com/mongodb/mongo-go-driver/bson/bsonrw"
	"github.com/mongodb/mongo-go-driver/bson/bsontype"
)

// primitiveCodecs is a shared zero-value PrimitiveCodecs instance.
// NOTE(review): not referenced in this part of the file; presumably used elsewhere in the package.
var primitiveCodecs PrimitiveCodecs

// Cached reflect.Type values for the bsonx types handled by PrimitiveCodecs. The codecs compare
// incoming reflect.Values against these, and RegisterPrimitiveCodecs uses them as registry keys.
// tMDoc is not used in this part of the file.
var (
	tDocument     = reflect.TypeOf(Doc(nil))
	tMDoc         = reflect.TypeOf(MDoc(nil))
	tArray        = reflect.TypeOf(Arr(nil))
	tValue        = reflect.TypeOf(Val{})
	tElementSlice = reflect.TypeOf([]Elem(nil))
)

// PrimitiveCodecs groups the default bsoncodec encoders and decoders for the primitive types
// defined in this package (Doc, Arr, Val and []Elem). It carries no state; all of its behavior
// lives in its methods.
type PrimitiveCodecs struct{}

// RegisterPrimitiveCodecs registers the encode and decode methods attached to PrimitiveCodecs
// for Doc, Arr, Val and []Elem with the provided RegistryBuilder.
//
// rb must not be nil: a nil builder causes a panic rather than a new builder being created.
func (pc PrimitiveCodecs) RegisterPrimitiveCodecs(rb *bsoncodec.RegistryBuilder) {
	if rb == nil {
		panic(errors.New("argument to RegisterPrimitiveCodecs must not be nil"))
	}

	rb.
		RegisterEncoder(tDocument, bsoncodec.ValueEncoderFunc(pc.DocumentEncodeValue)).
		RegisterEncoder(tArray, bsoncodec.ValueEncoderFunc(pc.ArrayEncodeValue)).
		RegisterEncoder(tValue, bsoncodec.ValueEncoderFunc(pc.ValueEncodeValue)).
		RegisterEncoder(tElementSlice, bsoncodec.ValueEncoderFunc(pc.ElementSliceEncodeValue)).
		RegisterDecoder(tDocument, bsoncodec.ValueDecoderFunc(pc.DocumentDecodeValue)).
		RegisterDecoder(tArray, bsoncodec.ValueDecoderFunc(pc.ArrayDecodeValue)).
		RegisterDecoder(tValue, bsoncodec.ValueDecoderFunc(pc.ValueDecodeValue)).
		RegisterDecoder(tElementSlice, bsoncodec.ValueDecoderFunc(pc.ElementSliceDecodeValue))
}

// DocumentEncodeValue is the ValueEncoderFunc for Doc.
//
// val must be a valid reflect.Value of type Doc; otherwise a bsoncodec.ValueEncoderError is
// returned. A nil Doc is written as BSON null; any other Doc is written as an embedded
// document, one element at a time, via encodeDocument.
func (pc PrimitiveCodecs) DocumentEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Type() != tDocument {
		return bsoncodec.ValueEncoderError{Name: "DocumentEncodeValue", Types: []reflect.Type{tDocument}, Received: val}
	}

	if val.IsNil() {
		return vw.WriteNull()
	}

	doc := val.Interface().(Doc)

	dw, err := vw.WriteDocument()
	if err != nil {
		return err
	}

	return pc.encodeDocument(ec, dw, doc)
}

// DocumentDecodeValue is the ValueDecoderFunc for Doc.
//
// val must be a settable reflect.Value of type Doc; otherwise a bsoncodec.ValueDecoderError is
// returned. The embedded document read from vr replaces val's contents.
func (pc PrimitiveCodecs) DocumentDecodeValue(dctx bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
	if !val.CanSet() || val.Type() != tDocument {
		return bsoncodec.ValueDecoderError{Name: "DocumentDecodeValue", Types: []reflect.Type{tDocument}, Received: val}
	}

	// CanSet implies the value is addressable, so taking its address is safe here.
	return pc.documentDecodeValue(dctx, vr, val.Addr().Interface().(*Doc))
}

// documentDecodeValue opens the embedded document that vr is positioned on and decodes its
// elements into *doc through decodeDocument. Any error from opening the document is returned
// unchanged.
func (pc PrimitiveCodecs) documentDecodeValue(dctx bsoncodec.DecodeContext, vr bsonrw.ValueReader, doc *Doc) error {
	reader, err := vr.ReadDocument()
	if err != nil {
		return err
	}
	return pc.decodeDocument(dctx, reader, doc)
}

// ArrayEncodeValue is the ValueEncoderFunc for Arr.
//
// val must be a valid reflect.Value of type Arr; otherwise a bsoncodec.ValueEncoderError is
// returned. A nil Arr is written as BSON null; otherwise each Val is written, in order, as an
// array element via encodeValue.
func (pc PrimitiveCodecs) ArrayEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Type() != tArray {
		return bsoncodec.ValueEncoderError{Name: "ArrayEncodeValue", Types: []reflect.Type{tArray}, Received: val}
	}

	if val.IsNil() {
		return vw.WriteNull()
	}

	arr := val.Interface().(Arr)

	aw, err := vw.WriteArray()
	if err != nil {
		return err
	}

	// Use a distinct name for each element so the reflect.Value parameter is not shadowed.
	for _, elem := range arr {
		evw, err := aw.WriteArrayElement()
		if err != nil {
			return err
		}

		if err = pc.encodeValue(ec, evw, elem); err != nil {
			return err
		}
	}

	return aw.WriteArrayEnd()
}

// ArrayDecodeValue is the ValueDecoderFunc for Arr.
//
// val must be a settable reflect.Value of type Arr; otherwise a bsoncodec.ValueDecoderError is
// returned. The array is opened before val is touched, so a type mismatch leaves val unchanged.
// On success val holds exactly the decoded elements. A nil val becomes a non-nil slice, and an
// existing val is truncated and refilled, reusing its backing array.
func (pc PrimitiveCodecs) ArrayDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
	if !val.CanSet() || val.Type() != tArray {
		return bsoncodec.ValueDecoderError{Name: "ArrayDecodeValue", Types: []reflect.Type{tArray}, Received: val}
	}

	ar, err := vr.ReadArray()
	if err != nil {
		return err
	}

	if val.IsNil() {
		val.Set(reflect.MakeSlice(tArray, 0, 0))
	}
	val.SetLen(0)

	for {
		// evr reads the current element; named so the vr parameter is not shadowed.
		evr, err := ar.ReadValue()
		if err == bsonrw.ErrEOA {
			break
		}
		if err != nil {
			return err
		}

		var elem Val
		if err = pc.valueDecodeValue(dc, evr, &elem); err != nil {
			return err
		}

		val.Set(reflect.Append(val, reflect.ValueOf(elem)))
	}

	return nil
}

// ElementSliceEncodeValue is the ValueEncoderFunc for []Elem.
//
// val must be a valid reflect.Value of type []Elem; otherwise a bsoncodec.ValueEncoderError is
// returned. A nil slice is written as BSON null; otherwise the slice is converted to Doc and
// encoded as an embedded document by DocumentEncodeValue.
func (pc PrimitiveCodecs) ElementSliceEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Type() != tElementSlice {
		return bsoncodec.ValueEncoderError{Name: "ElementSliceEncodeValue", Types: []reflect.Type{tElementSlice}, Received: val}
	}

	if val.IsNil() {
		return vw.WriteNull()
	}

	// Relies on []Elem being convertible to Doc (same underlying type).
	return pc.DocumentEncodeValue(ec, vw, val.Convert(tDocument))
}

// ElementSliceDecodeValue is the ValueDecoderFunc for []Elem.
//
// val must be a settable reflect.Value of type []Elem; otherwise a bsoncodec.ValueDecoderError
// is returned. The document read from vr replaces val's contents with its elements, in order.
//
// val is only modified after the whole document has been decoded. A read failure, including vr
// not holding a document, therefore leaves the caller's slice untouched. On success a nil val
// becomes a non-nil slice, and an existing val is truncated and refilled, reusing its backing
// array when capacity allows.
func (pc PrimitiveCodecs) ElementSliceDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
	if !val.CanSet() || val.Type() != tElementSlice {
		return bsoncodec.ValueDecoderError{Name: "ElementSliceDecodeValue", Types: []reflect.Type{tElementSlice}, Received: val}
	}

	dr, err := vr.ReadDocument()
	if err != nil {
		return err
	}

	var elems []Elem
	for {
		key, evr, err := dr.ReadElement()
		if err == bsonrw.ErrEOD {
			break
		}
		if err != nil {
			return err
		}

		var elem Elem
		if err = pc.elementDecodeValue(dc, evr, key, &elem); err != nil {
			return err
		}
		elems = append(elems, elem)
	}

	if val.IsNil() {
		val.Set(reflect.MakeSlice(tElementSlice, 0, len(elems)))
	}
	val.SetLen(0)
	val.Set(reflect.AppendSlice(val, reflect.ValueOf(elems)))
	return nil
}

// ValueEncodeValue is the ValueEncoderFunc for Val.
//
// val must be a valid reflect.Value of type Val; otherwise a bsoncodec.ValueEncoderError is
// returned. The contained Val is written according to its BSON type by encodeValue.
func (pc PrimitiveCodecs) ValueEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Type() != tValue {
		return bsoncodec.ValueEncoderError{Name: "ValueEncodeValue", Types: []reflect.Type{tValue}, Received: val}
	}

	v := val.Interface().(Val)

	return pc.encodeValue(ec, vw, v)
}

// ValueDecodeValue is the ValueDecoderFunc for Val.
//
// val must be a settable reflect.Value of type Val; otherwise a bsoncodec.ValueDecoderError is
// returned. The value vr is positioned on is decoded into val by valueDecodeValue.
func (pc PrimitiveCodecs) ValueDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
	if !val.CanSet() || val.Type() != tValue {
		return bsoncodec.ValueDecoderError{Name: "ValueDecodeValue", Types: []reflect.Type{tValue}, Received: val}
	}

	// CanSet implies the value is addressable, so taking its address is safe here.
	return pc.valueDecodeValue(dc, vr, val.Addr().Interface().(*Val))
}

// encodeDocument writes every element of doc, in order, to an already-opened DocumentWriter and
// then closes it. It exists as a standalone helper because WriteCodeWithScope hands back a
// DocumentWriter rather than a ValueWriter, so the scope cannot be encoded through a Codec and
// needs the same per-element logic used for ordinary documents.
func (pc PrimitiveCodecs) encodeDocument(ec bsoncodec.EncodeContext, dw bsonrw.DocumentWriter, doc Doc) error {
	for i := range doc {
		elemWriter, err := dw.WriteDocumentElement(doc[i].Key)
		if err != nil {
			return err
		}
		if err = pc.encodeValue(ec, elemWriter, doc[i].Value); err != nil {
			return err
		}
	}
	return dw.WriteDocumentEnd()
}

// DecodeDocument handles decoding into a Doc from a bsonrw.DocumentReader.
//
// Every element remaining in dr is read and appended to *pdoc, replacing its previous contents.
// A nil *pdoc is replaced with a non-nil empty Doc first.
func (pc PrimitiveCodecs) DecodeDocument(dctx bsoncodec.DecodeContext, dr bsonrw.DocumentReader, pdoc *Doc) error {
	return pc.decodeDocument(dctx, dr, pdoc)
}

// decodeDocument reads every element from dr and appends it to *pdoc, which is first reset to
// length zero (its backing array is reused). A nil *pdoc is replaced by a non-nil empty Doc.
// Elements are appended as they are decoded, so on error *pdoc holds the elements decoded so far.
func (pc PrimitiveCodecs) decodeDocument(dctx bsoncodec.DecodeContext, dr bsonrw.DocumentReader, pdoc *Doc) error {
	if *pdoc == nil {
		*pdoc = Doc{}
	}
	*pdoc = (*pdoc)[:0]

	for {
		key, elemReader, err := dr.ReadElement()
		switch {
		case err == bsonrw.ErrEOD:
			return nil
		case err != nil:
			return err
		}

		var e Elem
		if err = pc.elementDecodeValue(dctx, elemReader, key, &e); err != nil {
			return err
		}
		*pdoc = append(*pdoc, e)
	}
}

// elementDecodeValue decodes the value vr is positioned on and stores it in *elem under key.
//
// The per-BSON-type decoding is delegated to valueDecodeValue. Previously this function carried
// a verbatim copy of that type switch; sharing it keeps the two paths from drifting apart. *elem
// is only assigned when decoding succeeds, and errors (including the unknown-type error) are
// returned unchanged.
func (pc PrimitiveCodecs) elementDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, key string, elem *Elem) error {
	var val Val
	if err := pc.valueDecodeValue(dc, vr, &val); err != nil {
		return err
	}

	*elem = Elem{Key: key, Value: val}
	return nil
}

// encodeValue does no validation; callers must validate val before calling this method.
//
// It writes val to vw according to val.Type(). Embedded documents and arrays are encoded with the
// encoders that ec returns for Doc and Arr. A CodeWithScope's scope is written directly with
// encodeDocument. Any other type produces an error naming the unsupported type.
func (pc PrimitiveCodecs) encodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val Val) error {
	var err error
	switch val.Type() {
	case bsontype.Double:
		err = vw.WriteDouble(val.Double())
	case bsontype.String:
		err = vw.WriteString(val.StringValue())
	case bsontype.EmbeddedDocument:
		var encoder bsoncodec.ValueEncoder
		encoder, err = ec.LookupEncoder(tDocument)
		if err != nil {
			break
		}
		err = encoder.EncodeValue(ec, vw, reflect.ValueOf(val.Document()))
	case bsontype.Array:
		var encoder bsoncodec.ValueEncoder
		encoder, err = ec.LookupEncoder(tArray)
		if err != nil {
			break
		}
		err = encoder.EncodeValue(ec, vw, reflect.ValueOf(val.Array()))
	case bsontype.Binary:
		// TODO: FIX THIS (╯°□°）╯︵ ┻━┻
		subtype, data := val.Binary()
		err = vw.WriteBinaryWithSubtype(data, subtype)
	case bsontype.Undefined:
		err = vw.WriteUndefined()
	case bsontype.ObjectID:
		err = vw.WriteObjectID(val.ObjectID())
	case bsontype.Boolean:
		err = vw.WriteBoolean(val.Boolean())
	case bsontype.DateTime:
		err = vw.WriteDateTime(val.DateTime())
	case bsontype.Null:
		err = vw.WriteNull()
	case bsontype.Regex:
		err = vw.WriteRegex(val.Regex())
	case bsontype.DBPointer:
		err = vw.WriteDBPointer(val.DBPointer())
	case bsontype.JavaScript:
		err = vw.WriteJavascript(val.JavaScript())
	case bsontype.Symbol:
		err = vw.WriteSymbol(val.Symbol())
	case bsontype.CodeWithScope:
		code, scope := val.CodeWithScope()

		var cwsw bsonrw.DocumentWriter
		cwsw, err = vw.WriteCodeWithScope(code)
		if err != nil {
			break
		}

		err = pc.encodeDocument(ec, cwsw, scope)
	case bsontype.Int32:
		err = vw.WriteInt32(val.Int32())
	case bsontype.Timestamp:
		err = vw.WriteTimestamp(val.Timestamp())
	case bsontype.Int64:
		err = vw.WriteInt64(val.Int64())
	case bsontype.Decimal128:
		err = vw.WriteDecimal128(val.Decimal128())
	case bsontype.MinKey:
		err = vw.WriteMinKey()
	case bsontype.MaxKey:
		err = vw.WriteMaxKey()
	default:
		// %s prints the offending type's name; the previous %T always printed
		// "bsontype.Type", which hid which type was rejected.
		err = fmt.Errorf("%s is not a valid BSON type to encode", val.Type())
	}

	return err
}

// valueDecodeValue reads the value vr is positioned on and stores the matching Val in *val.
//
// vr.Type() selects which Read* method is called. Embedded documents, arrays and code-with-scope
// scopes are decoded recursively through this package's helpers. An unknown type yields an error.
// The result is built in a local and assigned to *val only once the read succeeds, so on any
// error *val is left untouched.
func (pc PrimitiveCodecs) valueDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val *Val) error {
	var out Val

	switch vr.Type() {
	case bsontype.Double:
		f, err := vr.ReadDouble()
		if err != nil {
			return err
		}
		out = Double(f)
	case bsontype.String:
		s, err := vr.ReadString()
		if err != nil {
			return err
		}
		out = String(s)
	case bsontype.EmbeddedDocument:
		var sub Doc
		if err := pc.documentDecodeValue(dc, vr, &sub); err != nil {
			return err
		}
		out = Document(sub)
	case bsontype.Array:
		// ArrayDecodeValue needs a settable reflect.Value of type Arr.
		var items Arr
		if err := pc.ArrayDecodeValue(dc, vr, reflect.ValueOf(&items).Elem()); err != nil {
			return err
		}
		out = Array(items)
	case bsontype.Binary:
		payload, subtype, err := vr.ReadBinary()
		if err != nil {
			return err
		}
		out = Binary(subtype, payload)
	case bsontype.Undefined:
		if err := vr.ReadUndefined(); err != nil {
			return err
		}
		out = Undefined()
	case bsontype.ObjectID:
		id, err := vr.ReadObjectID()
		if err != nil {
			return err
		}
		out = ObjectID(id)
	case bsontype.Boolean:
		flag, err := vr.ReadBoolean()
		if err != nil {
			return err
		}
		out = Boolean(flag)
	case bsontype.DateTime:
		when, err := vr.ReadDateTime()
		if err != nil {
			return err
		}
		out = DateTime(when)
	case bsontype.Null:
		if err := vr.ReadNull(); err != nil {
			return err
		}
		out = Null()
	case bsontype.Regex:
		pattern, opts, err := vr.ReadRegex()
		if err != nil {
			return err
		}
		out = Regex(pattern, opts)
	case bsontype.DBPointer:
		ns, ptr, err := vr.ReadDBPointer()
		if err != nil {
			return err
		}
		out = DBPointer(ns, ptr)
	case bsontype.JavaScript:
		src, err := vr.ReadJavascript()
		if err != nil {
			return err
		}
		out = JavaScript(src)
	case bsontype.Symbol:
		sym, err := vr.ReadSymbol()
		if err != nil {
			return err
		}
		out = Symbol(sym)
	case bsontype.CodeWithScope:
		code, scopeReader, err := vr.ReadCodeWithScope()
		if err != nil {
			return err
		}
		var scope Doc
		if err = pc.decodeDocument(dc, scopeReader, &scope); err != nil {
			return err
		}
		out = CodeWithScope(code, scope)
	case bsontype.Int32:
		n32, err := vr.ReadInt32()
		if err != nil {
			return err
		}
		out = Int32(n32)
	case bsontype.Timestamp:
		secs, inc, err := vr.ReadTimestamp()
		if err != nil {
			return err
		}
		out = Timestamp(secs, inc)
	case bsontype.Int64:
		n64, err := vr.ReadInt64()
		if err != nil {
			return err
		}
		out = Int64(n64)
	case bsontype.Decimal128:
		dec, err := vr.ReadDecimal128()
		if err != nil {
			return err
		}
		out = Decimal128(dec)
	case bsontype.MinKey:
		if err := vr.ReadMinKey(); err != nil {
			return err
		}
		out = MinKey()
	case bsontype.MaxKey:
		if err := vr.ReadMaxKey(); err != nil {
			return err
		}
		out = MaxKey()
	default:
		return fmt.Errorf("Cannot read unknown BSON type %s", vr.Type())
	}

	*val = out
	return nil
}
