// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package topology

import (
	"context"
	"errors"
	"fmt"
	"math"
	"sync"
	"sync/atomic"
	"time"

	"github.com/mongodb/mongo-go-driver/event"
	"github.com/mongodb/mongo-go-driver/x/mongo/driver/auth"
	"github.com/mongodb/mongo-go-driver/x/network/address"
	"github.com/mongodb/mongo-go-driver/x/network/command"
	"github.com/mongodb/mongo-go-driver/x/network/connection"
	"github.com/mongodb/mongo-go-driver/x/network/description"
)

// Limits used by the server monitor.
const (
	// minHeartbeatInterval is the period of the rate-limiting ticker in
	// update; two consecutive heartbeats are separated by a tick of it.
	minHeartbeatInterval = 500 * time.Millisecond

	// connectionSemaphoreSize is the largest signed 64-bit integer. It is not
	// referenced in this file.
	connectionSemaphoreSize = math.MaxInt64
)

// Sentinel errors reported by the Server lifecycle methods.
var (
	// ErrServerClosed is returned when a connection is requested from, or
	// Disconnect is called on, a server that is not in the connected state.
	ErrServerClosed = errors.New("server is closed")

	// ErrServerConnected is returned when Connect is called on a server that
	// is not in the disconnected state.
	ErrServerConnected = errors.New("server is connected")
)

// SelectedServer represents a specific server that was selected during server selection.
// It contains the kind of the topology it was selected from.
type SelectedServer struct {
	// Server is embedded, so its methods are promoted onto SelectedServer
	// except where SelectedServer declares its own (Description).
	*Server

	// Kind is the topology kind that SelectedServer.Description reports
	// alongside the server description.
	Kind description.TopologyKind
}

// Description pairs the embedded server's current description with the
// topology kind recorded on this SelectedServer.
func (ss *SelectedServer) Description() description.SelectedServer {
	var selected description.SelectedServer
	selected.Server = ss.Server.Description()
	selected.Kind = ss.Kind
	return selected
}

// These constants represent the connection states of a server. A Server
// keeps one of them in its connectionstate field.
const (
	disconnected int32 = iota
	disconnecting
	connected
	connecting
)

// connectionStateString returns the human-readable name of a connection
// state. It returns the empty string when state is not one of the constants
// declared above.
func connectionStateString(state int32) string {
	// Switch on the named constants rather than their numeric values so the
	// names cannot drift out of step with the declaration order above.
	switch state {
	case disconnected:
		return "Disconnected"
	case disconnecting:
		return "Disconnecting"
	case connected:
		return "Connected"
	case connecting:
		return "Connecting"
	}

	return ""
}

// Server is a single server within a topology. It owns a connection pool for
// the server's address and, once Connect has been called, a background
// goroutine (update) that heartbeats the server and publishes the resulting
// descriptions to subscribers.
type Server struct {
	// cfg is the configuration built by newServerConfig in NewServer.
	cfg     *serverConfig
	// address is the address given to NewServer; it is used for the pool, for
	// heartbeat connections, and as the Addr of stored descriptions.
	address address.Address

	// connectionstate holds one of the disconnected/disconnecting/connected/
	// connecting constants. Connect, Disconnect, Connection and Subscribe
	// access it through sync/atomic.
	connectionstate int32
	// done receives a value from Disconnect to tell the update goroutine to
	// shut down.
	done            chan struct{}
	// checkNow is a one-slot buffered channel; RequestImmediateCheck performs
	// a non-blocking send on it to trigger a heartbeat ahead of the ticker.
	checkNow        chan struct{}
	// closewg tracks the update goroutine; Disconnect waits on it.
	closewg         sync.WaitGroup
	// pool is the connection pool created in NewServer.
	pool            connection.Pool

	desc atomic.Value // holds a description.Server

	// averageRTTSet is consulted by updateAverageRTT to decide whether
	// averageRTT already holds a sample to average against.
	averageRTTSet bool
	// averageRTT is the round-trip time maintained by updateAverageRTT.
	averageRTT    time.Duration

	// subLock is held by Subscribe, Unsubscribe, updateDescription and the
	// shutdown path of update while they touch the subscription fields below.
	subLock             sync.Mutex
	// subscribers maps a subscription id to the channel descriptions are
	// delivered on.
	subscribers         map[uint64]chan description.Server
	// currentSubscriberID is the id handed to the next subscriber.
	currentSubscriberID uint64

	// subscriptionsClosed is set when update shuts down; after that Subscribe
	// fails with ErrSubscribeAfterClosed and Unsubscribe is a no-op.
	subscriptionsClosed bool
}

// ConnectServer builds a Server for addr with NewServer and immediately calls
// Connect on it. It returns the first error from either step, in which case
// the returned server is nil.
func ConnectServer(ctx context.Context, addr address.Address, opts ...ServerOption) (*Server, error) {
	server, err := NewServer(addr, opts...)
	if err != nil {
		return nil, err
	}
	if err = server.Connect(ctx); err != nil {
		return nil, err
	}
	return server, nil
}

// NewServer builds a Server for addr from the supplied options, including its
// connection pool. The returned server is not monitored until Connect is
// called. An error is returned when the options cannot be applied or the pool
// cannot be created.
func NewServer(addr address.Address, opts ...ServerOption) (*Server, error) {
	cfg, err := newServerConfig(opts...)
	if err != nil {
		return nil, err
	}

	// A configured maximum of zero is translated to math.MaxInt64, which is
	// effectively unbounded.
	poolLimit := uint64(math.MaxInt64)
	if cfg.maxConns != 0 {
		poolLimit = uint64(cfg.maxConns)
	}

	srv := &Server{
		cfg:         cfg,
		address:     addr,
		done:        make(chan struct{}),
		checkNow:    make(chan struct{}, 1),
		subscribers: make(map[uint64]chan description.Server),
	}
	// Seed the description so Description never observes an empty value.
	srv.desc.Store(description.Server{Addr: addr})

	srv.pool, err = connection.NewPool(addr, uint64(cfg.maxIdleConns), poolLimit, cfg.connectionOpts...)
	if err != nil {
		return nil, err
	}
	return srv, nil
}

// Connect initializes the Server by starting the background monitoring
// goroutine and connecting the connection pool. This method must be called
// before a Server can be used.
//
// It returns ErrServerConnected when the server is not in the disconnected
// state, and otherwise the result of connecting the pool.
func (s *Server) Connect(ctx context.Context) error {
	if !atomic.CompareAndSwapInt32(&s.connectionstate, disconnected, connected) {
		return ErrServerConnected
	}
	s.desc.Store(description.Server{Addr: s.address})

	// Register with the WaitGroup before starting the goroutine: update
	// defers closewg.Done, so the Add must happen first or a concurrent Wait
	// (or an early Done) would race with it.
	s.closewg.Add(1)
	go s.update()

	return s.pool.Connect(ctx)
}

// Disconnect shuts this Server down. It stops the monitoring goroutine (which
// closes every subscription), then disconnects the connection pool: the idle
// connections are closed and the call waits for in-use connections to be
// returned and closed. If ctx is cancelled or expires first, the in-use
// connections are closed anyway, failing any in-flight reads or writes. A nil
// return means every connection associated with this Server has been closed.
//
// ErrServerClosed is returned when the server is not in the connected state.
// When the pool reports an error it is returned as is, and the server is left
// in the disconnecting state.
func (s *Server) Disconnect(ctx context.Context) error {
	wasConnected := atomic.CompareAndSwapInt32(&s.connectionstate, connected, disconnecting)
	if !wasConnected {
		return ErrServerClosed
	}

	// Every successful Connect leaves at least one goroutine receiving from
	// done, so this send hands the shutdown signal to it.
	s.done <- struct{}{}

	if err := s.pool.Disconnect(ctx); err != nil {
		return err
	}

	// Wait for the monitoring goroutine to finish before reporting the server
	// as disconnected.
	s.closewg.Wait()
	atomic.StoreInt32(&s.connectionstate, disconnected)
	return nil
}

// Connection checks a connection out of the pool and wraps it so it is tied
// to this Server. It returns ErrServerClosed when the server is not in the
// connected state, and otherwise any error reported by the pool.
//
// Pool errors have side effects: an authentication error drains the pool; a
// network error marks the server Unknown through updateDescription when the
// pool supplied a description, and drains the pool when it did not.
func (s *Server) Connection(ctx context.Context) (connection.Connection, error) {
	if atomic.LoadInt32(&s.connectionstate) != connected {
		return nil, ErrServerClosed
	}

	conn, desc, err := s.pool.Get(ctx)
	if err == nil {
		if desc != nil {
			// Publish the description that came with the connection without
			// delaying the caller.
			go s.updateDescription(*desc, false)
		}
		return &sconn{Connection: conn, s: s}, nil
	}

	switch err.(type) {
	case *auth.Error:
		// authentication error --> drain connection
		_ = s.pool.Drain()
	case *connection.NetworkError:
		if desc == nil {
			_ = s.pool.Drain()
			break
		}
		// Marking the server Unknown also clears the pool, see
		// updateDescription.
		desc.Kind = description.Unknown
		desc.LastError = err
		s.updateDescription(*desc, false)
	}
	return nil, err
}

// Description returns the description.Server most recently stored for this
// server, either the initial placeholder or the latest update.
func (s *Server) Description() description.Server {
	current := s.desc.Load()
	return current.(description.Server)
}

// SelectedDescription wraps the current description in a
// description.SelectedServer whose Kind is Single. It is meant for callers
// that monitor a batch of servers and want to run one-off commands against
// each of them directly.
func (s *Server) SelectedDescription() description.SelectedServer {
	return description.SelectedServer{Server: s.Description(), Kind: description.Single}
}

// Subscribe returns a ServerSubscription which has a channel on which all
// updated server descriptions will be sent. The channel will have a buffer
// size of one, and will be pre-populated with the current description.
//
// ErrSubscribeAfterClosed is returned when the server is not in the connected
// state or its subscriptions have already been closed.
func (s *Server) Subscribe() (*ServerSubscription, error) {
	if atomic.LoadInt32(&s.connectionstate) != connected {
		return nil, ErrSubscribeAfterClosed
	}

	s.subLock.Lock()
	defer s.subLock.Unlock()
	if s.subscriptionsClosed {
		return nil, ErrSubscribeAfterClosed
	}

	// Take the snapshot while holding subLock. updateDescription stores the
	// new description before it takes the same lock to notify subscribers, so
	// either the snapshot already contains that update or the update is
	// delivered to this subscriber once the lock is released. Snapshotting
	// before the lock would let an update slip in between unnoticed.
	// The send cannot block: the channel is new and has a one-slot buffer.
	ch := make(chan description.Server, 1)
	ch <- s.desc.Load().(description.Server)

	id := s.currentSubscriberID
	s.subscribers[id] = ch
	s.currentSubscriberID++

	ss := &ServerSubscription{
		C:  ch,
		s:  s,
		id: id,
	}

	return ss, nil
}

// RequestImmediateCheck asks the monitoring goroutine to heartbeat the server
// now rather than at the next heartbeat tick. It never blocks: if a request
// is already pending, this call does nothing.
func (s *Server) RequestImmediateCheck() {
	request := struct{}{}
	select {
	case s.checkNow <- request:
		// Queued; update picks it up on its next select.
	default:
		// The one-slot buffer is full, so a check is already pending.
	}
}

// update handles performing heartbeats and updating any subscribers of the
// newest description.Server retrieved.
//
// It runs as the monitoring goroutine started by Connect and returns once a
// value is received from s.done. On the way out it closes every subscriber
// channel, marks subscriptions as closed, and closes the heartbeat connection
// if one is open.
func (s *Server) update() {
	// Deferred calls run in reverse order, so closewg.Done is signalled last,
	// after the tickers are stopped and the recover handler below has run.
	defer s.closewg.Done()
	heartbeatTicker := time.NewTicker(s.cfg.heartbeatInterval)
	rateLimiter := time.NewTicker(minHeartbeatInterval)
	defer heartbeatTicker.Stop()
	defer rateLimiter.Stop()
	checkNow := s.checkNow
	done := s.done

	// doneOnce records that closeServer has run, i.e. that the value sent on
	// done by Disconnect has already been received.
	var doneOnce bool
	defer func() {
		if r := recover(); r != nil {
			if doneOnce {
				return
			}
			// We keep this goroutine alive attempting to read from the done channel.
			// Disconnect sends on done; if nothing received after a panic,
			// that send would block.
			<-done
		}
	}()

	var conn connection.Connection
	var desc description.Server

	// The first heartbeat runs immediately on a fresh connection. Its result
	// is published with initial=true, so the pool is not drained even if the
	// server turns out to be Unknown.
	desc, conn = s.heartbeat(nil)
	s.updateDescription(desc, true)

	// closeServer performs the shutdown work. It reads conn when it is
	// called, so it closes whichever heartbeat connection is current then.
	closeServer := func() {
		doneOnce = true
		s.subLock.Lock()
		for id, c := range s.subscribers {
			close(c)
			delete(s.subscribers, id)
		}
		s.subscriptionsClosed = true
		s.subLock.Unlock()
		if conn == nil {
			return
		}
		conn.Close()
	}
	for {
		// Wait for the next scheduled heartbeat, an explicit request from
		// RequestImmediateCheck, or shutdown.
		select {
		case <-heartbeatTicker.C:
		case <-checkNow:
		case <-done:
			closeServer()
			return
		}

		// Rate limit: wait for a tick of the minHeartbeatInterval ticker
		// before checking again, while still reacting to shutdown.
		select {
		case <-rateLimiter.C:
		case <-done:
			closeServer()
			return
		}

		// Reuse the previous heartbeat connection; heartbeat replaces it when
		// it is nil or expired.
		desc, conn = s.heartbeat(conn)
		s.updateDescription(desc, false)
	}
}

// updateDescription handles updating the description on the Server, notifying
// subscribers, and potentially draining the connection pool. The initial
// parameter is used to determine if this is the first description from the
// server; the pool is never drained for the first description.
//
// Any panic raised while doing so is swallowed, so callers are never taken
// down by a failed notification or drain.
func (s *Server) updateDescription(desc description.Server, initial bool) {
	defer func() {
		// Deliberately best effort: discard any panic from the steps below.
		_ = recover()
	}()
	s.desc.Store(desc)

	// Notify inside a function literal so subLock is released by a defer.
	// Because panics are recovered above, a manual Unlock skipped by a panic
	// would leave the mutex locked for every later caller.
	func() {
		s.subLock.Lock()
		defer s.subLock.Unlock()
		for _, c := range s.subscribers {
			select {
			// drain the channel if it isn't empty
			case <-c:
			default:
			}
			// Replace any unread description with the newest one.
			c <- desc
		}
	}()

	// We don't clear the pool on the first update on the description, and only
	// an Unknown server invalidates its pooled connections.
	if initial || desc.Kind != description.Unknown {
		return
	}
	_ = s.pool.Drain()
}

// heartbeat runs an isMaster round trip against the server and converts the
// reply into a description.Server. conn may be nil; a nil or expired
// connection is replaced by a freshly dialed one. Up to maxRetry attempts are
// made, and a connection that fails is closed and discarded.
//
// On success the description carries the updated average RTT and the
// configured heartbeat interval, and the connection that was used is returned
// for reuse. When every attempt fails, the description holds only the address
// and the last error, and the returned connection is nil.
func (s *Server) heartbeat(conn connection.Connection) (description.Server, connection.Connection) {
	const maxRetry = 2
	ctx := context.Background()

	// lastErr is the failure reported if no attempt succeeds.
	var lastErr error

	for attempt := 0; attempt < maxRetry; attempt++ {
		if conn != nil && conn.Expired() {
			conn.Close()
			conn = nil
		}

		if conn == nil {
			// Connect, read and write all use the heartbeat timeout.
			timeout := func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }
			dialOpts := []connection.Option{
				connection.WithConnectTimeout(timeout),
				connection.WithReadTimeout(timeout),
				connection.WithWriteTimeout(timeout),
			}
			dialOpts = append(dialOpts, s.cfg.connectionOpts...)
			// The overrides go last so they win over the configured options:
			// a nil handshaker so no authentication is attempted, and a nil
			// command monitor so heartbeats are not reported as commands.
			dialOpts = append(dialOpts,
				connection.WithHandshaker(func(connection.Handshaker) connection.Handshaker { return nil }),
				connection.WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { return nil }),
			)

			var dialErr error
			conn, _, dialErr = connection.New(ctx, s.address, dialOpts...)
			if dialErr != nil {
				lastErr = dialErr
				if conn != nil {
					conn.Close()
				}
				conn = nil
				continue
			}
		}

		started := time.Now()

		cmd := &command.IsMaster{Compressors: s.cfg.compressionOpts}
		reply, err := cmd.RoundTrip(ctx, conn)
		if err != nil {
			lastErr = err
			conn.Close()
			conn = nil
			continue
		}

		if s.cfg.clock != nil {
			s.cfg.clock.AdvanceClusterTime(reply.ClusterTime)
		}

		elapsed := time.Since(started)
		desc := description.NewServer(s.address, reply).SetAverageRTT(s.updateAverageRTT(elapsed))
		desc.HeartbeatInterval = s.cfg.heartbeatInterval
		return desc, conn
	}

	failed := description.Server{
		Addr:      s.address,
		LastError: lastErr,
	}
	return failed, conn
}

// updateAverageRTT folds a new round-trip sample into the server's running
// average and returns the updated average. The first sample seeds the
// average; every later sample is blended in as an exponentially weighted
// moving average with a weight of 0.2 on the new sample.
func (s *Server) updateAverageRTT(delay time.Duration) time.Duration {
	if !s.averageRTTSet {
		s.averageRTT = delay
		// Record that a seed exists. Without this the branch below is never
		// reached and the "average" is always just the latest sample.
		s.averageRTTSet = true
		return s.averageRTT
	}

	alpha := 0.2
	s.averageRTT = time.Duration(alpha*float64(delay) + (1-alpha)*float64(s.averageRTT))
	return s.averageRTT
}

// Drain drains this server's connection pool. It exists so that the pool does
// not have to be exposed directly: a client that sees a read or write error
// can drain the pool through the Server. That avoids wrapping the Connection
// type to sniff responses for pool-clearing errors and lets the Client type
// keep its error handling in one place.
func (s *Server) Drain() error {
	err := s.pool.Drain()
	return err
}

// String implements the Stringer interface. The result always contains the
// address, server type and connection state. Tag sets, the average RTT and
// the last error are appended only when the description has tags, the server
// is connected, and an error was recorded, respectively.
func (s *Server) String() string {
	desc := s.Description()
	// Load the state once, atomically, like every other reader and writer of
	// connectionstate. This also keeps the printed state and the connected
	// check below consistent with each other.
	state := atomic.LoadInt32(&s.connectionstate)

	str := fmt.Sprintf("Addr: %s, Type: %s, State: %s",
		s.address, desc.Kind, connectionStateString(state))
	if len(desc.Tags) != 0 {
		str += fmt.Sprintf(", Tag sets: %s", desc.Tags)
	}
	if state == connected {
		// NOTE(review): averageRTT is written by the monitoring goroutine and
		// read here without synchronization. %d prints the duration as an
		// integer number of nanoseconds.
		str += fmt.Sprintf(", Average RTT: %d", s.averageRTT)
	}
	if desc.LastError != nil {
		str += fmt.Sprintf(", Last error: %s", desc.LastError)
	}

	return str
}

// ServerSubscription represents a subscription to the description.Server updates for
// a specific server. Values are created by Server.Subscribe.
type ServerSubscription struct {
	// C delivers description updates. It has a one-slot buffer and is closed
	// by Unsubscribe or when the server's monitoring goroutine shuts down.
	C  <-chan description.Server
	// s is the server this subscription is registered with.
	s  *Server
	// id is the key of this subscription in the server's subscribers map.
	id uint64
}

// Unsubscribe removes this subscription from its server and closes the
// subscription channel. It is a no-op when the server has already closed all
// subscriptions or when this subscription was removed earlier, so it is safe
// to call more than once. The returned error is always nil.
func (ss *ServerSubscription) Unsubscribe() error {
	srv := ss.s
	srv.subLock.Lock()
	defer srv.subLock.Unlock()

	if srv.subscriptionsClosed {
		// Shutdown already closed and removed every channel.
		return nil
	}

	if ch, registered := srv.subscribers[ss.id]; registered {
		close(ch)
		delete(srv.subscribers, ss.id)
	}
	return nil
}
