// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package command

import (
	"errors"

	"context"

	"fmt"

	"github.com/mongodb/mongo-go-driver/bson"
	"github.com/mongodb/mongo-go-driver/bson/bsontype"
	"github.com/mongodb/mongo-go-driver/bson/primitive"
	"github.com/mongodb/mongo-go-driver/mongo/readconcern"
	"github.com/mongodb/mongo-go-driver/mongo/writeconcern"
	"github.com/mongodb/mongo-go-driver/x/bsonx"
	"github.com/mongodb/mongo-go-driver/x/bsonx/bsoncore"
	"github.com/mongodb/mongo-go-driver/x/mongo/driver/session"
	"github.com/mongodb/mongo-go-driver/x/network/description"
	"github.com/mongodb/mongo-go-driver/x/network/result"
	"github.com/mongodb/mongo-go-driver/x/network/wiremessage"
)

// WriteBatch represents a single batch for a write operation. It embeds the
// *Write command for the batch together with the batch's document count.
type WriteBatch struct {
	*Write
	// numDocs is the number of documents in this batch (presumably set by the
	// code that builds the batches — confirm). roundTripBatches adds it to the
	// running offset applied to upserted indexes.
	numDocs int
}

// DecodeError attempts to decode the wiremessage as an error.
//
// For an OP_MSG, the document of the last body section is inspected. For an
// OP_REPLY, the first document is inspected, but only when the QueryFailure
// flag is set. If that document is valid BSON and describes a failed command,
// the resulting Error is returned. Otherwise DecodeError returns nil: for
// successful replies, unsupported message types, replies without a document,
// and documents that fail validation.
func DecodeError(wm wiremessage.WireMessage) error {
	var rdr bson.Raw
	switch msg := wm.(type) {
	case wiremessage.Msg:
		// When several body sections are present, the last one wins.
		for _, section := range msg.Sections {
			if body, ok := section.(wiremessage.SectionBody); ok {
				rdr = body.Document
			}
		}
	case wiremessage.Reply:
		if msg.ResponseFlags&wiremessage.QueryFailure != wiremessage.QueryFailure {
			return nil
		}
		// Guard against a malformed reply that sets the failure flag but
		// carries no documents; indexing would panic.
		if len(msg.Documents) == 0 {
			return nil
		}
		rdr = msg.Documents[0]
	}

	if err := rdr.Validate(); err != nil {
		return nil
	}

	// extractError returns nil for ok:1, a parsing error if the document
	// cannot be iterated, or an Error describing the command failure. Only
	// the last of these is a server error worth reporting.
	extractedError := extractError(rdr)
	if _, ok := extractedError.(Error); ok {
		return extractedError
	}

	return nil
}

// extractError inspects a command response document and returns:
//   - nil if it contains an "ok" field equal to 1 (int32, int64, or double);
//   - the error from rdr.Elements if the document cannot be iterated;
//   - otherwise an Error built from the "code", "codeName", "errmsg", and
//     "errorLabels" fields. Message defaults to "command failed" when errmsg
//     is absent or empty.
func extractError(rdr bson.Raw) error {
	var errmsg, codeName string
	var code int32
	var labels []string
	elems, err := rdr.Elements()
	if err != nil {
		return err
	}

	for _, elem := range elems {
		switch elem.Key() {
		case "ok":
			switch elem.Value().Type {
			case bson.TypeInt32:
				if elem.Value().Int32() == 1 {
					return nil
				}
			case bson.TypeInt64:
				if elem.Value().Int64() == 1 {
					return nil
				}
			case bson.TypeDouble:
				if elem.Value().Double() == 1 {
					return nil
				}
			}
		case "errmsg":
			if str, okay := elem.Value().StringValueOK(); okay {
				errmsg = str
			}
		case "codeName":
			if str, okay := elem.Value().StringValueOK(); okay {
				codeName = str
			}
		case "code":
			// Accept the same numeric representations as "ok" so that a code
			// sent as int64 or double is not silently dropped.
			val := elem.Value()
			switch val.Type {
			case bson.TypeInt32:
				code = val.Int32()
			case bson.TypeInt64:
				code = int32(val.Int64())
			case bson.TypeDouble:
				code = int32(val.Double())
			}
		case "errorLabels":
			if arr, okay := elem.Value().ArrayOK(); okay {
				labelElems, labelErr := arr.Elements()
				if labelErr != nil {
					// Malformed labels are skipped; the rest of the error is still useful.
					continue
				}
				for _, labelElem := range labelElems {
					if str, ok := labelElem.Value().StringValueOK(); ok {
						labels = append(labels, str)
					}
				}
			}
		}
	}

	if errmsg == "" {
		errmsg = "command failed"
	}

	return Error{
		Code:    code,
		Message: errmsg,
		Name:    codeName,
		Labels:  labels,
	}
}

// responseClusterTime copies the "$clusterTime" field of a server response
// into a standalone document of the form {$clusterTime: <value>}. It returns
// nil when the field cannot be looked up.
func responseClusterTime(response bson.Raw) bson.Raw {
	ct, lookupErr := response.LookupErr("$clusterTime")
	if lookupErr != nil {
		return nil // $clusterTime not included by the server
	}

	start, wrapped := bsoncore.AppendDocumentStart(nil)
	wrapped = bsoncore.AppendHeader(wrapped, ct.Type, "$clusterTime")
	wrapped = append(wrapped, ct.Value...)
	wrapped, _ = bsoncore.AppendDocumentEnd(wrapped, start)
	return wrapped
}

// updateClusterTimes advances the cluster time held by sess and by clock
// (each only when non-nil) to the $clusterTime reported in response. A
// response without $clusterTime is a no-op. Only the session update can fail.
// When it does, its error is returned and clock is left untouched.
func updateClusterTimes(sess *session.Client, clock *session.ClusterClock, response bson.Raw) error {
	ct := responseClusterTime(response)
	if ct == nil {
		return nil
	}

	if sess != nil {
		if advanceErr := sess.AdvanceClusterTime(ct); advanceErr != nil {
			return advanceErr
		}
	}

	if clock != nil {
		clock.AdvanceClusterTime(ct)
	}
	return nil
}

// updateOperationTime advances the session's operation time to the
// "operationTime" timestamp reported in response.
//
// It is a no-op when sess is nil or when the response has no operationTime.
// An operationTime of a non-timestamp BSON type is ignored the same way,
// instead of panicking in RawValue.Timestamp. The error from
// sess.AdvanceOperationTime is returned.
func updateOperationTime(sess *session.Client, response bson.Raw) error {
	if sess == nil {
		return nil
	}

	opTimeElem, err := response.LookupErr("operationTime")
	if err != nil {
		// operationTime not included by the server
		return nil
	}

	t, i, ok := opTimeElem.TimestampOK()
	if !ok {
		// Malformed operationTime: there is no meaningful time to advance to.
		return nil
	}

	return sess.AdvanceOperationTime(&primitive.Timestamp{
		T: t,
		I: i,
	})
}

// marshalCommand encodes cmd as raw BSON. A nil command encodes as the empty
// BSON document.
func marshalCommand(cmd bsonx.Doc) (bson.Raw, error) {
	if cmd != nil {
		return cmd.MarshalBSON()
	}
	// Empty document: little-endian int32 length 5 followed by the 0x00 terminator.
	return bson.Raw{0x05, 0x00, 0x00, 0x00, 0x00}, nil
}

// addSessionFields adds session-related fields to a command document.
//
// Any existing "lsid" is replaced by the client's session ID. When a
// transaction is running or a commit is being retried, the transaction fields
// are appended via addTransaction. Finally client.ApplyCommand is called to
// advance the session's state machine.
//
// cmd is returned unchanged when client is nil, the selected server does not
// support sessions, or desc.SessionTimeoutMinutes is 0. If the session has
// been terminated, cmd is returned with session.ErrSessionEnded.
func addSessionFields(cmd bsonx.Doc, desc description.SelectedServer, client *session.Client) (bsonx.Doc, error) {
	if client == nil || !description.SessionsSupported(desc.WireVersion) || desc.SessionTimeoutMinutes == 0 {
		return cmd, nil
	}

	if client.Terminated {
		return cmd, session.ErrSessionEnded
	}

	// Drop any existing lsid so the command carries exactly one, belonging to
	// this client. (The previous check was inverted and only deleted when the
	// key was absent, which duplicated an existing lsid.) Delete is a no-op
	// for a missing key, matching how the other add* helpers use it.
	cmd = cmd.Delete("lsid")
	cmd = append(cmd, bsonx.Elem{Key: "lsid", Value: bsonx.Document(client.SessionID)})

	if client.TransactionRunning() || client.RetryingCommit {
		cmd = addTransaction(cmd, client)
	}

	client.ApplyCommand() // advance the state machine based on a command executing

	return cmd, nil
}

// addTransaction appends the transaction fields to cmd, in this order:
//   - "txnNumber" with the client's TxnNumber;
//   - "startTransaction": true, but only while the transaction is starting;
//   - "autocommit": false.
func addTransaction(cmd bsonx.Doc, client *session.Client) bsonx.Doc {
	txnFields := []bsonx.Elem{{Key: "txnNumber", Value: bsonx.Int64(client.TxnNumber)}}
	if client.TransactionStarting() {
		txnFields = append(txnFields, bsonx.Elem{Key: "startTransaction", Value: bsonx.Boolean(true)})
	}
	txnFields = append(txnFields, bsonx.Elem{Key: "autocommit", Value: bsonx.Boolean(false)})
	return append(cmd, txnFields...)
}

// addClusterTime attaches the most recent known cluster time to cmd.
//
// Candidate times come from clock (when non-nil) and sess.ClusterTime (when
// sess is non-nil). When both exist, session.MaxClusterTime chooses between
// them. The elements of the chosen document (expected to be
// {$clusterTime: ...}) are appended after any existing "$clusterTime" is
// removed.
//
// cmd is returned unchanged when neither source is given, the server does not
// support sessions, no cluster time is known, or the stored cluster time
// cannot be parsed.
func addClusterTime(cmd bsonx.Doc, desc description.SelectedServer, sess *session.Client, clock *session.ClusterClock) bsonx.Doc {
	noSource := clock == nil && sess == nil
	if noSource || !description.SessionsSupported(desc.WireVersion) {
		return cmd
	}

	var latest bson.Raw
	if clock != nil {
		latest = clock.GetClusterTime()
	}
	if sess != nil {
		if latest != nil {
			latest = session.MaxClusterTime(latest, sess.ClusterTime)
		} else {
			latest = sess.ClusterTime
		}
	}
	if latest == nil {
		return cmd
	}

	parsed, readErr := bsonx.ReadDoc(latest)
	if readErr != nil {
		return cmd // broken clusterTime: send the command without one
	}

	return append(cmd.Delete("$clusterTime"), parsed...)
}

// addReadConcern sets the "readConcern" field of a command document.
//
// Choosing the read concern:
//   - When sess is starting a transaction, the session's CurrentRc (if set)
//     overrides rc.
//   - If no read concern results but the starting transaction is causally
//     consistent and has an operation time, a fresh readconcern.New() is used
//     so that afterClusterTime can be attached.
//
// Building the field: afterClusterTime is added whenever the server supports
// sessions and sess is causally consistent with a known operation time. Any
// existing readConcern is removed first, and nothing is appended if the
// resulting document is empty. With no read concern at all, cmd is returned
// unchanged. Marshaling errors are returned alongside the unmodified cmd.
func addReadConcern(cmd bsonx.Doc, desc description.SelectedServer, rc *readconcern.ReadConcern, sess *session.Client) (bsonx.Doc, error) {
	if sess != nil && sess.TransactionStarting() {
		switch {
		case sess.CurrentRc != nil:
			// Starting transaction's read concern overrides all others.
			rc = sess.CurrentRc
		case rc == nil && sess.Consistent && sess.OperationTime != nil:
			// Start transaction must carry afterClusterTime when causally consistent.
			rc = readconcern.New()
		}
	}

	if rc == nil {
		return cmd, nil
	}

	bt, raw, err := rc.MarshalBSONValue()
	if err != nil {
		return cmd, err
	}

	var rcDoc bsonx.Doc
	if err = rcDoc.UnmarshalBSONValue(bt, raw); err != nil {
		return cmd, err
	}

	causal := description.SessionsSupported(desc.WireVersion) && sess != nil && sess.Consistent && sess.OperationTime != nil
	if causal {
		afterCT := bsonx.Timestamp(sess.OperationTime.T, sess.OperationTime.I)
		rcDoc = append(rcDoc, bsonx.Elem{Key: "afterClusterTime", Value: afterCT})
	}

	cmd = cmd.Delete("readConcern")
	if len(rcDoc) == 0 {
		return cmd, nil
	}
	return append(cmd, bsonx.Elem{Key: "readConcern", Value: bsonx.Document(rcDoc)}), nil
}

// addWriteConcern sets the "writeConcern" field of cmd to the marshaled wc,
// replacing any existing one.
//
// cmd is returned untouched when wc is nil or marshals to
// writeconcern.ErrEmptyWriteConcern. Other marshaling errors are returned
// alongside the unmodified cmd.
func addWriteConcern(cmd bsonx.Doc, wc *writeconcern.WriteConcern) (bsonx.Doc, error) {
	if wc == nil {
		return cmd, nil
	}

	bt, raw, err := wc.MarshalBSONValue()
	switch {
	case err == writeconcern.ErrEmptyWriteConcern:
		return cmd, nil
	case err != nil:
		return cmd, err
	}

	var wcVal bsonx.Val
	if err = wcVal.UnmarshalBSONValue(bt, raw); err != nil {
		return cmd, err
	}

	// Replace rather than duplicate an existing write concern.
	return append(cmd.Delete("writeConcern"), bsonx.Elem{Key: "writeConcern", Value: wcVal}), nil
}

// getErrorLabels returns the string entries of the "errorLabels" array in a
// command response.
//
// It returns (nil, nil) when rdr is nil, the field is absent, or the field is
// not an array. Non-string array entries are skipped, as extractError does.
// Other lookup or parsing errors are returned.
func getErrorLabels(rdr *bson.Raw) ([]string, error) {
	if rdr == nil {
		return nil, nil
	}

	labelsElem, err := rdr.LookupErr("errorLabels")
	if err == bsoncore.ErrElementNotFound {
		// No labels is the common case, not an error. (The previous check
		// was inverted and returned early whenever labels *were* present.)
		return nil, nil
	}
	if err != nil {
		return nil, err
	}

	arr, ok := labelsElem.ArrayOK()
	if !ok {
		return nil, nil
	}

	elems, err := arr.Elements()
	if err != nil {
		return nil, err
	}

	var labels []string
	for _, elem := range elems {
		if label, isString := elem.Value().StringValueOK(); isString {
			labels = append(labels, label)
		}
	}
	return labels, nil
}

// opmsgRemoveArray detaches the write payload from an insert, update, or
// delete command so it can be sent as an OP_MSG Section 1 document sequence.
//
// The keys "documents", "updates", and "deletes" are tried in that order. The
// first one found is removed from cmd, and the function returns the shortened
// command, the removed array, and the key it lived under. If none is present,
// cmd is returned with a nil array and an empty identifier.
func opmsgRemoveArray(cmd bsonx.Doc) (bsonx.Doc, bsonx.Arr, string) {
	for _, key := range [...]string{"documents", "updates", "deletes"} {
		val, lookupErr := cmd.LookupErr(key)
		if lookupErr != nil {
			continue
		}
		payload := val.Array()
		return cmd.Delete(key), payload, key
	}
	return cmd, nil, ""
}

// opmsgAddGlobals appends the OP_MSG global arguments to cmd and returns the
// marshaled result. "$db" is always set to dbName. "$readPreference" is set
// to rpDoc only when rpDoc is non-nil, so pass nil for commands without a read
// preference.
func opmsgAddGlobals(cmd bsonx.Doc, dbName string, rpDoc bsonx.Doc) (bson.Raw, error) {
	globals := bsonx.Doc{{Key: "$db", Value: bsonx.String(dbName)}}
	if rpDoc != nil {
		globals = append(globals, bsonx.Elem{Key: "$readPreference", Value: bsonx.Document(rpDoc)})
	}

	// bsonx.Doc.MarshalBSON never returns an error.
	return append(cmd, globals...).MarshalBSON()
}

// opmsgCreateDocSequence builds an OP_MSG document-sequence section named
// identifier. Each element of arr is marshaled as a document, and Size is set
// from the section's PayloadLen. The returned error is always nil.
func opmsgCreateDocSequence(arr bsonx.Arr, identifier string) (wiremessage.SectionDocumentSequence, error) {
	docs := make([]bson.Raw, len(arr))
	for i := range arr {
		// MarshalBSON's error is deliberately discarded here.
		docs[i], _ = arr[i].Document().MarshalBSON()
	}

	seq := wiremessage.SectionDocumentSequence{
		PayloadType: wiremessage.DocumentSequence,
		Identifier:  identifier,
		Documents:   docs,
	}
	seq.Size = int32(seq.PayloadLen())
	return seq, nil
}

// splitBatches partitions docs, in order, into consecutive batches. Each batch
// holds at most maxCount documents, and the marshaled sizes in a batch sum to
// at most targetBatchSize.
//
// When targetBatchSize exceeds reservedCommandBufferBytes, that many bytes are
// first reserved out of it. A maxCount <= 0 is treated as 1.
// ErrDocumentTooLarge is returned if any single document exceeds the effective
// target size. An empty docs slice yields one empty batch.
func splitBatches(docs []bsonx.Doc, maxCount, targetBatchSize int) ([][]bsonx.Doc, error) {
	if targetBatchSize > reservedCommandBufferBytes {
		targetBatchSize -= reservedCommandBufferBytes
	}
	if maxCount <= 0 {
		maxCount = 1
	}

	batches := [][]bsonx.Doc{}
	next := 0 // index of the first document not yet placed in a batch
	for {
		current := []bsonx.Doc{}
		bytesUsed := 0
		for next < len(docs) && len(current) < maxCount {
			raw, _ := docs[next].MarshalBSON()
			if len(raw) > targetBatchSize {
				return nil, ErrDocumentTooLarge
			}
			if bytesUsed+len(raw) > targetBatchSize {
				break // this document starts the next batch
			}
			bytesUsed += len(raw)
			current = append(current, docs[next])
			next++
		}

		batches = append(batches, current)
		if next == len(docs) {
			return batches, nil
		}
	}
}

// encodeBatch builds the write command document for one batch:
//
//	{<insert|update|delete>: collName, <documents|updates|deletes>: [docs...], opts...}
//
// The command name and array key are chosen by cmdKind. An unrecognized
// cmdKind returns an error; previously it silently produced a command with
// empty keys.
func encodeBatch(
	docs []bsonx.Doc,
	opts []bsonx.Elem,
	cmdKind WriteCommandKind,
	collName string,
) (bsonx.Doc, error) {
	var cmdName, docString string
	switch cmdKind {
	case InsertCommand:
		cmdName, docString = "insert", "documents"
	case UpdateCommand:
		cmdName, docString = "update", "updates"
	case DeleteCommand:
		cmdName, docString = "delete", "deletes"
	default:
		return nil, fmt.Errorf("unknown write command kind: %d", cmdKind)
	}

	vals := make(bsonx.Arr, 0, len(docs))
	for _, doc := range docs {
		vals = append(vals, bsonx.Document(doc))
	}

	cmd := make(bsonx.Doc, 0, 2+len(opts))
	cmd = append(cmd,
		bsonx.Elem{Key: cmdName, Value: bsonx.String(collName)},
		bsonx.Elem{Key: docString, Value: bsonx.Array(vals)},
	)
	return append(cmd, opts...), nil
}

// batchesToWireMessage encodes each write batch for the selected server and
// returns the wire messages in batch order. It stops at the first encoding
// error and returns it.
func batchesToWireMessage(batches []*WriteBatch, desc description.SelectedServer) ([]wiremessage.WireMessage, error) {
	// Reserve capacity only. A non-zero length here would leave len(batches)
	// nil messages in front of the appended ones.
	wms := make([]wiremessage.WireMessage, 0, len(batches))
	for _, batch := range batches {
		wm, err := batch.Encode(desc)
		if err != nil {
			return nil, err
		}
		wms = append(wms, wm)
	}

	return wms, nil
}

// roundTripBatches sends each write batch, in order, over rw and merges the
// decoded replies into one result. The result's concrete type depends on
// cmdKind (result.Insert, result.Update, or result.Delete) and is returned as
// an interface{}. It may be nil if the very first round trip fails.
//
// The returned batches slice is trimmed from the front by one element for
// each successfully processed batch, but only for retryable-write sessions
// (sess.RetryWrite). This lets the caller see which batches were not yet
// applied. Otherwise the slice is returned unchanged.
//
// Early exits:
//   - a RoundTrip or decode error is returned as the error;
//   - on a retryable-write session, a write concern error is reported in the
//     result with a nil error, so the caller can decide to retry;
//   - when continueOnError is false, processing stops (nil error) once the
//     accumulated result contains any write errors.
func roundTripBatches(
	ctx context.Context,
	desc description.SelectedServer,
	rw wiremessage.ReadWriter,
	batches []*WriteBatch,
	continueOnError bool,
	sess *session.Client,
	cmdKind WriteCommandKind,
) (interface{}, []*WriteBatch, error) {
	var res interface{}
	var upsertIndex int64 // the operation index for the upserted IDs map

	// hold onto txnNumber, reset it when loop exits to ensure reuse of same
	// transaction number if retry is needed
	var txnNumber int64
	if sess != nil && sess.RetryWrite {
		txnNumber = sess.TxnNumber
	}
	// range evaluates the original slice once, so trimming batches inside the
	// loop does not affect iteration; j is the index into the original slice.
	for j, cmd := range batches {
		rdr, err := cmd.RoundTrip(ctx, desc, rw)
		if err != nil {
			if sess != nil && sess.RetryWrite {
				// Restore the number used by the failed batch: the starting
				// number plus one per completed batch (assuming
				// IncrementTxnNumber adds 1 — confirm).
				sess.TxnNumber = txnNumber + int64(j)
			}
			return res, batches, err
		}

		// TODO can probably DRY up this code
		switch cmdKind {
		case InsertCommand:
			if res == nil {
				res = result.Insert{}
			}

			conv, _ := res.(result.Insert)
			insertCmd := &Insert{}
			r, err := insertCmd.decode(desc, rdr).Result()
			if err != nil {
				return res, batches, err
			}

			conv.WriteErrors = append(conv.WriteErrors, r.WriteErrors...)

			if r.WriteConcernError != nil {
				conv.WriteConcernError = r.WriteConcernError
				if sess != nil && sess.RetryWrite {
					// NOTE(review): this resets to the starting txnNumber, not
					// txnNumber+j, even though batches has already been trimmed
					// by j. Confirm the retrying caller expects that.
					sess.TxnNumber = txnNumber
					return conv, batches, nil // report writeconcernerror for retry
				}
			}

			conv.N += r.N

			if !continueOnError && len(conv.WriteErrors) > 0 {
				return conv, batches, nil
			}

			res = conv
		case UpdateCommand:
			if res == nil {
				res = result.Update{}
			}

			conv, _ := res.(result.Update)
			updateCmd := &Update{}
			r, err := updateCmd.decode(desc, rdr).Result()
			if err != nil {
				return conv, batches, err
			}

			conv.WriteErrors = append(conv.WriteErrors, r.WriteErrors...)

			if r.WriteConcernError != nil {
				conv.WriteConcernError = r.WriteConcernError
				if sess != nil && sess.RetryWrite {
					// NOTE(review): resets to the starting txnNumber; see the insert case.
					sess.TxnNumber = txnNumber
					return conv, batches, nil // report writeconcernerror for retry
				}
			}

			conv.MatchedCount += r.MatchedCount
			conv.ModifiedCount += r.ModifiedCount
			// Upsert indexes in a reply are relative to its batch; shift them
			// by the number of documents in earlier batches.
			for _, upsert := range r.Upserted {
				conv.Upserted = append(conv.Upserted, result.Upsert{
					Index: upsert.Index + upsertIndex,
					ID:    upsert.ID,
				})
			}

			if !continueOnError && len(conv.WriteErrors) > 0 {
				return conv, batches, nil
			}

			res = conv
			upsertIndex += int64(cmd.numDocs)
		case DeleteCommand:
			if res == nil {
				res = result.Delete{}
			}

			conv, _ := res.(result.Delete)
			deleteCmd := &Delete{}
			r, err := deleteCmd.decode(desc, rdr).Result()
			if err != nil {
				return conv, batches, err
			}

			conv.WriteErrors = append(conv.WriteErrors, r.WriteErrors...)

			if r.WriteConcernError != nil {
				conv.WriteConcernError = r.WriteConcernError
				if sess != nil && sess.RetryWrite {
					// NOTE(review): resets to the starting txnNumber; see the insert case.
					sess.TxnNumber = txnNumber
					return conv, batches, nil // report writeconcernerror for retry
				}
			}

			conv.N += r.N

			if !continueOnError && len(conv.WriteErrors) > 0 {
				return conv, batches, nil
			}

			res = conv
		}

		// Increment txnNumber for each batch
		if sess != nil && sess.RetryWrite {
			sess.IncrementTxnNumber()
			batches = batches[1:] // batch round-tripped successfully; drop it from the pending slice
		}
	}

	if sess != nil && sess.RetryWrite {
		// if retryable write succeeded, transaction number will be incremented one extra time,
		// so we decrement it here
		sess.TxnNumber--
	}

	return res, batches, nil
}

// getCursorValues extracts the firstBatch values, namespace, and cursor ID
// from a command response of the form
//
//	{cursor: {firstBatch: [...], ns: "<db>.<coll>", id: <int64>}}
//
// An error is returned in these cases:
//   - "cursor" is missing or is not an embedded document;
//   - firstBatch is not an array, ns is not a string, or id is not an int64;
//   - ns fails Namespace.Validate;
//   - the firstBatch array cannot be read.
//
// Unrecognized keys inside the cursor document are ignored.
func getCursorValues(result bson.Raw) ([]bson.RawValue, Namespace, int64, error) {
	cur, err := result.LookupErr("cursor")
	if err != nil {
		return nil, Namespace{}, 0, err
	}
	if cur.Type != bson.TypeEmbeddedDocument {
		return nil, Namespace{}, 0, fmt.Errorf("cursor should be an embedded document but it is a BSON %s", cur.Type)
	}

	elems, err := cur.Document().Elements()
	if err != nil {
		return nil, Namespace{}, 0, err
	}

	var (
		arr       bson.Raw
		namespace Namespace
		cursorID  int64
	)

	for _, elem := range elems {
		val := elem.Value()
		switch elem.Key() {
		case "firstBatch":
			var ok bool
			if arr, ok = val.ArrayOK(); !ok {
				return nil, Namespace{}, 0, fmt.Errorf("firstBatch should be an array but it is a BSON %s", val.Type)
			}
		case "ns":
			if val.Type != bson.TypeString {
				return nil, Namespace{}, 0, fmt.Errorf("namespace should be a string but it is a BSON %s", val.Type)
			}
			namespace = ParseNamespace(val.StringValue())
			if err := namespace.Validate(); err != nil {
				return nil, Namespace{}, 0, err
			}
		case "id":
			var ok bool
			if cursorID, ok = val.Int64OK(); !ok {
				return nil, Namespace{}, 0, fmt.Errorf("id should be an int64 but it is a BSON %s", val.Type)
			}
		}
	}

	vals, err := arr.Values()
	if err != nil {
		return nil, Namespace{}, 0, err
	}

	return vals, namespace, cursorID, nil
}

// getBatchSize returns the value of the first "batchSize" option in opts, or
// 0 when no such option is present. The option is read with Val.Int32 and no
// type check, so it is assumed to hold an int32.
func getBatchSize(opts []bsonx.Elem) int32 {
	for i := range opts {
		if opts[i].Key != "batchSize" {
			continue
		}
		return opts[i].Value.Int32()
	}
	return 0
}

// ErrUnacknowledgedWrite is returned from functions that have an unacknowledged
// write concern. It is a sentinel value, so callers can compare against it
// directly to tell an intentionally unacknowledged write apart from other
// errors.
var ErrUnacknowledgedWrite = errors.New("unacknowledged write")

// WriteCommandKind is the type of command represented by a Write: an insert,
// an update, or a delete.
type WriteCommandKind int8

// These constants represent the valid types of write commands. encodeBatch
// maps each one to its command name and payload array key.
const (
	InsertCommand WriteCommandKind = iota // "insert", payload key "documents"
	UpdateCommand                         // "update", payload key "updates"
	DeleteCommand                         // "delete", payload key "deletes"
)
