
# Copyright 2017-present Open Networking Foundation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
Connection test cases

"""

import time
import sys
import logging

import unittest
import random

from oftest import config
import oftest.controller as controller
import ofp
import oftest.dataplane as dataplane

from oftest.testutils import *

@disabled
class BaseHandshake(unittest.TestCase):
    """
    Base handshake case to set up controller, but do not send hello.

    Subclasses call controllerSetup() to start one controller per host/port.
    Every controller started that way is shut down in tearDown().
    """

    def controllerSetup(self, host, port):
        """
        Create and start a controller on host:port and record it in
        self.controllers.

        The controller's initial hello is disabled so the test itself decides
        whether and when a hello is sent.
        """
        con = controller.Controller(host=host, port=port)

        # When clean_shutdown is True, tearDown() joins each controller after
        # shutting it down; set it to False to skip the join and force quit.
        self.clean_shutdown = True
        # disable initial hello so hello is under control of test
        con.initial_hello = False

        con.start()
        self.controllers.append(con)

    def setUp(self):
        logging.info("** START TEST CASE " + str(self))

        self.controllers = []
        # Default for tearDown(); controllerSetup() also sets this.
        self.clean_shutdown = True
        self.default_timeout = test_param_get('default_timeout',
                                              default=2)

    def tearDown(self):
        """Shut down every started controller, joining it if requested."""
        logging.info("** END TEST CASE " + str(self))
        for con in self.controllers:
            con.shutdown()
            if self.clean_shutdown:
                con.join()

    def runTest(self):
        # do nothing in the base case
        pass

    def assertTrue(self, cond, msg=None):
        """
        Same contract as unittest.TestCase.assertTrue, but also logs the
        failure message before the assertion is raised.

        msg is optional, as in unittest; str() keeps a missing message from
        turning the failure into a TypeError.
        """
        if not cond:
            logging.error("** FAILED ASSERTION: " + str(msg))
        unittest.TestCase.assertTrue(self, cond, msg)

class HandshakeNoHello(BaseHandshake):
    """
    TCP connect to switch, but do not send hello,
    and wait for disconnect.
    """
    def runTest(self):
        self.controllerSetup(config["controller_host"],
                             config["controller_port"])
        con = self.controllers[0]
        con.connect(self.default_timeout)

        logging.info("TCP Connected " + str(con.switch_addr))
        logging.info("Hello not sent, waiting for timeout")

        # wait for controller to die: the test passes only if a disconnect
        # is reported within 10 seconds
        self.assertTrue(con.wait_disconnected(timeout=10),
                        "Not notified of controller disconnect")

class HandshakeNoFeaturesRequest(BaseHandshake):
    """
    Open a TCP connection to the switch and send hello, but never send a
    features request; then wait for the connection to be dropped.
    """
    def runTest(self):
        host = config["controller_host"]
        port = config["controller_port"]
        self.controllerSetup(host, port)

        con = self.controllers[0]
        con.connect(self.default_timeout)
        logging.info("TCP Connected " + str(con.switch_addr))

        logging.info("Sending hello")
        con.message_send(ofp.message.hello())

        logging.info("Features request not sent, waiting for timeout")

        # The test passes only if a disconnect is reported within 10 seconds.
        disconnected = con.wait_disconnected(timeout=10)
        self.assertTrue(disconnected,
                        "Not notified of controller disconnect")

@disabled
class CompleteHandshake(BaseHandshake):
    """
    Set up multiple controllers and complete handshake, but otherwise do nothing.

    Each controller is driven through a per-connection state machine kept in
    ``con.cstate``:

      0 - connected, hello not yet sent
      1 - hello sent, waiting for the switch's hello
      2 - hello received, features request not yet sent
      3 - features request sent, waiting for the features reply
      4 - handshake complete; connection maintained, and dropped after
          controller_timeout seconds when controller_timeout >= 0

    A timeout in state 1 falls back to state 0; a timeout in state 3 falls
    back to state 2. Each completed handshake counts as one "cycle", and the
    test stops once more than cxn_cycles cycles have completed.
    """

    def buildControllerList(self):
        """
        Populate self.controller_list with (host, port) tuples.

        Uses the 'controller_list' test parameter (a list of "IP:port"
        strings) when present, otherwise the single controller from config.
        Fails the test if a port is not an integer.
        """
        con_list = test_param_get('controller_list')
        if con_list is None:
            self.controller_list = [(config["controller_host"],
                                     config["controller_port"])]
            return

        self.controller_list = []
        for con_spec in con_list:
            # Split on the last ':' so the port is always the final field.
            ip, portstr = con_spec.rsplit(':', 1)
            try:
                port = int(portstr)
            except ValueError:
                self.assertTrue(0, "failure converting port " +
                                portstr + " to integer")
            else:
                self.controller_list.append((ip, port))

    def __init__(self, keep_alive=True, cxn_cycles=5,
                 controller_timeout=-1, hello_timeout=5,
                 features_req_timeout=5, disconnected_timeout=3,
                 report_pkts=False):
        """
        keep_alive: copied onto every controller as con.keep_alive.
        cxn_cycles: number of completed handshakes after which the test ends.
        controller_timeout: seconds to hold an established connection before
            disconnecting; a negative value keeps it open indefinitely.
        hello_timeout / features_req_timeout: seconds to wait for the
            corresponding reply before retrying.
        disconnected_timeout: seconds allowed with no controller in state 4
            (doubled before the first handshake completes).
        report_pkts: print packet-in statistics while running.

        Every value except keep_alive and report_pkts can be overridden by a
        test parameter of the same name.
        """
        BaseHandshake.__init__(self)
        self.buildControllerList()
        self.keep_alive = keep_alive
        # Pass the constructor value as the default rather than using
        # "param or default", so an explicit test-param value of 0 is honoured.
        self.cxn_cycles = test_param_get('cxn_cycles', default=cxn_cycles)
        self.controller_timeout = test_param_get('controller_timeout',
                                                 default=controller_timeout)
        self.hello_timeout = test_param_get('hello_timeout',
                                            default=hello_timeout)
        self.features_req_timeout = test_param_get(
            'features_req_timeout', default=features_req_timeout)
        self.disconnected_timeout = test_param_get(
            'disconnected_timeout', default=disconnected_timeout)
        self.report_pkts = report_pkts

    # These functions provide per-tick processing
    def periodic_task_init(self, tick_time):
        """
        Reset packet-in statistics (only when report_pkts is set).

        tick_time is the tick period in seconds, usually 1/10 of a sec; it is
        currently unused.
        """
        if not self.report_pkts:
            return
        self.start_time = time.time()
        self.last_report = self.start_time
        self.pkt_in_count = 0 # Total packet in count
        self.periodic_pkt_in_count = 0 # Packet-ins this cycle

    def periodic_task_tick(self, con):
        """
        Process one tick for con.  Currently this just counts pkt-in msgs on
        controllers in state 4 and prints a rate roughly once a second.
        """
        if not self.report_pkts:
            return
        if con.cstate != 4:
            return

        # Gather packets from control cxn
        current_time = time.time()
        new_pkts = con.packet_in_count - self.pkt_in_count
        self.pkt_in_count = con.packet_in_count
        self.periodic_pkt_in_count += new_pkts
        con.clear_queue()

        # Report every second or so
        if current_time - self.last_report >= 1:
            if self.periodic_pkt_in_count:
                print("%7.2f: pkt/sec last period:  %6d.  Total %10d." % (
                    current_time - self.start_time,
                    self.periodic_pkt_in_count/(current_time - self.last_report),
                    self.pkt_in_count))
            self.last_report = current_time
            self.periodic_pkt_in_count = 0

    def periodic_task_done(self):
        """Print the overall packet-in total (only when report_pkts is set)."""
        if not self.report_pkts:
            return
        print("Received %d pkt-ins over %d seconds" % (
            self.pkt_in_count, time.time() - self.start_time))

    def runTest(self):
        for conspec in self.controller_list:
            self.controllerSetup(conspec[0], conspec[1])
        for con in self.controllers:
            con.cstate = 0
            con.keep_alive = self.keep_alive
            con.saved_switch_addr = None
        tick = 0.1  # time period in seconds at which controllers are handled
        self.periodic_task_init(tick)

        disconnected_count = 0
        cycle = 0
        while True:
            states = []
            for con in self.controllers:
                condesc = con.host + ":" + str(con.port) + ": "
                logging.debug("Checking " + condesc)

                if con.switch_socket:
                    # A connection from a different switch address restarts
                    # the handshake from state 0.
                    if con.switch_addr != con.saved_switch_addr:
                        con.saved_switch_addr = con.switch_addr
                        con.cstate = 0

                    if con.cstate == 0:
                        logging.info(condesc + "Sending hello to " +
                                     str(con.switch_addr))
                        con.message_send(ofp.message.hello())
                        con.cstate = 1
                        con.count = 0
                    elif con.cstate == 1:
                        reply, pkt = con.poll(exp_msg=ofp.OFPT_HELLO,
                                              timeout=0)
                        if reply is not None:
                            logging.info(condesc +
                                         "Hello received from " +
                                         str(con.switch_addr))
                            con.cstate = 2
                        else:
                            con.count += 1
                            # fall back to previous state on timeout
                            if con.count >= self.hello_timeout/tick:
                                logging.info(condesc +
                                             "Timeout hello from " +
                                             str(con.switch_addr))
                                con.cstate = 0
                    elif con.cstate == 2:
                        logging.info(condesc + "Sending features request to " +
                                     str(con.switch_addr))
                        con.message_send(ofp.message.features_request())
                        con.cstate = 3
                        con.count = 0
                    elif con.cstate == 3:
                        reply, pkt = con.poll(exp_msg=ofp.OFPT_FEATURES_REPLY,
                                              timeout=0)
                        if reply is not None:
                            logging.info(condesc +
                                         "Features reply received from " +
                                         str(con.switch_addr))
                            con.cstate = 4
                            con.count = 0
                            cycle += 1
                            logging.info("Cycle " + str(cycle))
                        else:
                            con.count += 1
                            # fall back to previous state on timeout
                            if con.count >= self.features_req_timeout/tick:
                                logging.info(condesc +
                                             "Timeout features request from " +
                                             str(con.switch_addr))
                                con.cstate = 2
                    elif con.cstate == 4:
                        # A negative controller_timeout keeps the connection
                        # open indefinitely.
                        if (self.controller_timeout < 0 or
                            con.count < self.controller_timeout/tick):
                            logging.debug(condesc +
                                          "Maintaining connection to " +
                                          str(con.switch_addr))
                            con.count += 1
                        else:
                            logging.info(condesc +
                                         "Disconnecting from " +
                                         str(con.switch_addr))
                            con.disconnect()
                            con.cstate = 0
                else:
                    con.cstate = 0

                states.append(con.cstate)
                self.periodic_task_tick(con)

            logging.debug("Cycle " + str(cycle) +
                          ", states " + str(states) +
                          ", disconnected_count " + str(disconnected_count))
            # Count consecutive ticks in which no controller is in state 4.
            if 4 in states:
                disconnected_count = 0
            else:
                disconnected_count += 1
            if cycle != 0:
                self.assertTrue(disconnected_count < self.disconnected_timeout/tick,
                                "Timeout expired connecting to controller")
            else:
                # on first cycle, allow more time for initial connect
                self.assertTrue(disconnected_count < 2*self.disconnected_timeout/tick,
                                "Timeout expired connecting to controller on init")

            if cycle > self.cxn_cycles:
                break
            time.sleep(tick)
        self.periodic_task_done()

@disabled
class HandshakeAndKeepalive(CompleteHandshake):
    """
    Complete handshake and respond to echo request, but otherwise do nothing.
    Good for manual testing.
    """

    def __init__(self):
        """Run CompleteHandshake with keep_alive enabled."""
        CompleteHandshake.__init__(self, keep_alive=True)

@disabled
class MonitorPacketIn(CompleteHandshake):
    """
    Complete handshake and respond to echo request.  As packet-in messages
    arrive, report the count and pkts/second
    """

    def __init__(self):
        """Run CompleteHandshake with keep_alive and packet reporting on."""
        CompleteHandshake.__init__(self, keep_alive=True, report_pkts=True)

@disabled
class HandshakeNoEcho(CompleteHandshake):
    """
    Complete handshake, but otherwise do nothing, and do not respond to echo.
    """

    def __init__(self):
        """Run CompleteHandshake with keep_alive disabled."""
        CompleteHandshake.__init__(self, keep_alive=False)

@disabled
class HandshakeAndDrop(CompleteHandshake):
    """
    Complete handshake, but otherwise do nothing, and drop connection after a while.
    """

    def __init__(self):
        """Run CompleteHandshake, disconnecting 10 s after each handshake."""
        CompleteHandshake.__init__(self, keep_alive=True, controller_timeout=10)

