#!/usr/bin/env python3

# Copyright 2020-present Open Networking Foundation
#
# SPDX-License-Identifier: LicenseRef-ONF-Member-Only-1.0

import sys
import os
import json
import logging
import enum
import pycurl
import time
import serial
import subprocess
import time
from collections import namedtuple
from statistics import median
import xml.etree.ElementTree as ET

'''
"Simple" script that periodically checks the Aether network's operational
status by controlling the attached 4G/LTE modem with AT commands, and
reports the results to the central monitoring server.
'''

# Parse config with backwards compatibility with config.json pre 0.6.6:
# keys that were renamed since then are rewritten textually before parsing.
# Every JSON object becomes a namedtuple, so settings are read as attributes
# (e.g. CONF.modem.port) and optional keys are probed with `"key" in X._fields`.
# The file is read inside a `with` block so the handle is closed promptly.
with open(os.getenv('CONFIG_FILE', "./config.json")) as _config_file:
    config_file_contents = _config_file.read()
config_file_contents = config_file_contents.replace("user_plane_ping_test", "dns")
config_file_contents = config_file_contents.replace("speedtest_iperf", "iperf_server")
config_file_contents = config_file_contents.replace("\"speedtest_ping_dns\": \"1.1.1.1\",", "")
CONF = json.loads(
    config_file_contents, object_hook=lambda d: namedtuple('X', d.keys())(*d.values())
)

# Log to the configured file at the configured level (a level name such as
# "INFO", converted with logging.getLevelName).
logging.basicConfig(
    filename=CONF.log_file,
    format='%(asctime)s [%(levelname)s] %(message)s',
    level=logging.getLevelName(CONF.log_level)
)


class State(enum.Enum):
    """Modem attach state.

    "0" and "1" are the values parsed from the modem's +CGATT reply
    (see Modem.get_state); ``error`` ("-1") marks a failed or unusable query.
    """
    error = "-1"
    disconnected = "0"
    connected = "1"

    @classmethod
    def has_value(cls, value):
        """Return True if *value* is the raw value of one of the members."""
        known_values = {member.value for member in cls}
        return value in known_values


class Modem():
    """AT-command driver for the serial-attached 4G/LTE modem."""

    log = logging.getLogger('aether_edge_monitoring.Modem')

    # Seconds to wait once for a reply when nothing is buffered yet.
    read_timeout = 0.1

    def __init__(self, port, baudrate):
        """Remember the serial settings; the port is opened by connect()."""
        self.port = port
        self.baudrate = baudrate
        self._response = None

    def get_modem_port(self):
        """Resolve CONF.modem.port with a shell ``ls`` and return the result.

        The configured value is expanded by the shell, so it may be a glob
        (presumably something like /dev/ttyUSB* -- confirm against config).
        Anything ``ls`` writes to stderr is logged as an error, but the
        (possibly empty) stdout is still returned, stripped.
        """
        cmd = "ls " + CONF.modem.port
        sp = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE, universal_newlines=True)
        # communicate() drains both pipes and waits for exit; calling wait()
        # first can deadlock once a pipe buffer fills up.
        ret, err = sp.communicate()
        if err != "":
            self.log.error("unable to find serial port " + err)

        ret = ret.replace(CONF.modem.port, "").strip()
        self.log.info("Modem.get_modem_port found " + ret)
        return ret

    def connect(self):
        """Resolve the serial port and open it with a 1 s read timeout."""
        self.port = self.get_modem_port()
        self.log.info("modem.connect Port: %s, BaudRate: %i", self.port, self.baudrate)
        self.serial = serial.Serial(
            port=self.port,
            baudrate=self.baudrate,
            timeout=1)

    def _write(self, command):
        """Send *command* + CR and return the reply with CR/LF as spaces.

        Stale input is discarded first. The reply is read while bytes are
        buffered; if nothing has arrived yet the method waits read_timeout
        once, and it stops as soon as the input buffer is found empty.
        Non-ASCII bytes (e.g. line noise) are replaced with U+FFFD instead
        of raising, so one garbled reply cannot crash the monitoring loop.
        """
        if self.serial.inWaiting() > 0:
            self.serial.flushInput()

        self._response = b""

        self.serial.write(bytearray(command + "\r", "ascii"))
        read = self.serial.inWaiting()
        while True:
            self.log.debug("Waiting for write to complete...")
            if read > 0:
                self._response += self.serial.read(read)
            else:
                # Nothing buffered yet: give the modem a moment to answer.
                time.sleep(self.read_timeout)
            read = self.serial.inWaiting()
            if read == 0:
                break
        self.log.debug("Write complete...")
        return self._response.decode("ascii", errors="replace").replace('\r\n', ' ')

    def write(self, command, wait_resp=True):
        """Send *command* and return ``(success, response)``.

        With *wait_resp* set, a reply containing "ERROR" yields
        ``(False, None)``; otherwise ``(True, response)`` is returned.
        """
        response = self._write(command)
        self.log.debug("%s: %s", command, response)

        if wait_resp and "ERROR" in response:
            return False, None
        return True, response

    def get_state(self):
        """Query ``AT+CGATT?`` and return the attach state as a State.

        Returns State.error when the command fails, the reply has no
        "CGATT:", or the value after it is not a known State value.
        """
        success, result = self.write('AT+CGATT?')
        if not success or 'CGATT:' not in result:
            return State.error
        # First whitespace-separated token after "CGATT:", so both
        # "+CGATT:1" and the 3GPP-style "+CGATT: 1" parse.
        tokens = result.split('CGATT:')[1].split()
        if not tokens or not State.has_value(tokens[0]):
            self.log.warning("Unexpected CGATT response: %s", result)
            return State.error
        return State(tokens[0])

    def close(self):
        """Close the serial port opened by connect()."""
        self.serial.close()


def _echo_at_command(modem, command):
    '''
    Write the AT *command* to modem.port using a shell ``echo``.

    "echo" works more stable than serial for this action. Returns True on
    success; logs an error and returns False if the shell command fails.
    '''
    shell_cmd = "echo '" + command + "' > " + modem.port
    logging.debug(shell_cmd)
    try:
        subprocess.check_output(shell_cmd, shell=True)
    except subprocess.CalledProcessError:
        logging.error("Write '" + command + "' failed")
        return False
    return True


def _wait_for_state(modem, target, timeout, state=None):
    '''
    Poll modem.get_state(), sleeping 1 s between polls, until it is *target*.

    Makes at most *timeout* polls and returns the last state seen, or the
    given *state* unchanged if no poll was made (timeout <= 0).
    '''
    retry = 0
    while retry < timeout:
        state = modem.get_state()
        if state is target:
            break
        time.sleep(1)
        retry += 1
    return state


def get_control_plane_state(modem, dongle_stats=None):
    '''
    Determine whether the modem is attached to the network (control plane).

    If *dongle_stats* reports 'Connection' == 'Connected', that is trusted
    and State.connected is returned at once. Otherwise the radio is
    power-cycled (AT+CFUN=0, wait for detach, AT+CFUN=1, wait for attach)
    while polling the attach state via modem.get_state().

    Returns:
        State.error if a CFUN write fails or the modem does not detach
        within CONF.detach_timeout polls; otherwise the last state seen
        within CONF.attach_timeout polls, with State.error mapped to
        State.disconnected.
    '''
    # A failed dongle fetch returns a truthy dict without 'Connection',
    # so look the key up with .get() instead of indexing.
    if dongle_stats and dongle_stats.get('Connection') == 'Connected':
        return State.connected

    # Disable radio function
    if not _echo_at_command(modem, 'AT+CFUN=0'):
        return State.error

    # Wait until the modem is fully disconnected
    state = _wait_for_state(modem, State.disconnected, CONF.detach_timeout)
    if state is not State.disconnected:
        logging.error("Failed to disconnect")
        return State.error

    time.sleep(2)
    # Enable radio function
    if not _echo_at_command(modem, 'AT+CFUN=1'):
        return State.error

    # Wait up to attach_timeout polls for the modem to be fully connected
    state = _wait_for_state(modem, State.connected, CONF.attach_timeout, state)
    # CGATT sometimes returns None
    if state is State.error:
        state = State.disconnected

    return state


def get_user_plane_state(modem):
    '''
    Check that user-plane traffic flows through the modem.

    When CONF.ips.dry_run is set, a 10-packet run_ping_test() against it is
    the user-plane test and its result dict is returned as the latency.
    Otherwise CONF.ips.dns is pinged 3 times and only ping's exit status
    matters (latency is None).

    Returns:
        (State.connected or State.disconnected, latency dict or None)
    '''
    use_dry_run = "dry_run" in CONF.ips._fields and CONF.ips.dry_run
    if not use_dry_run:
        # Default user plane test: output is discarded, exit status decides.
        ping_cmd = "ping -c 3 " + CONF.ips.dns + ">/dev/null 2>&1"
        try:
            subprocess.check_output(ping_cmd, shell=True)
        except subprocess.CalledProcessError:
            logging.warning("User plane test failed")
            return State.disconnected, None
        return State.connected, None

    latency, passed = run_ping_test(CONF.ips.dry_run, 10)
    if not passed:
        logging.warning("User plane test failed")
        return State.disconnected, latency
    return State.connected, latency


def run_ping_test(ip, count):
    '''
    Ping *ip* *count* times and summarise the result.

    Returns (result, passed):
      result -- dict with 'transmitted' and 'received' (packet counts) and
                'median', 'min', 'avg', 'max', 'stddev' (round-trip ms);
                fields stay zero when they could not be determined.
      passed -- True if at least one reply was received, or if *ip* is
                empty (nothing configured, nothing to test); False when
                ping fails, gets no replies, or its output cannot be parsed.

    Counts and times are located with regular expressions rather than fixed
    token positions, so extra summary fields such as "+2 errors" or
    "+1 duplicates" do not turn a partially lossy run into a failure.
    '''
    import re

    result = {'transmitted': 0,
              'received': 0,
              'median': 0.0,
              'min': 0.0,
              'avg': 0.0,
              'max': 0.0,
              'stddev': 0.0}
    if not ip:
        return result, True
    try:
        # Argument list instead of a shell string: ip needs no quoting.
        ping_output = subprocess.check_output(
            ["ping", "-c", str(count), ip]).decode("UTF-8")

        # e.g. "10 packets transmitted, 9 received, +1 errors, 10% packet loss"
        # (BusyBox/BSD ping print "9 packets received").
        counts = re.search(
            r"(\d+) packets transmitted, (\d+) (?:packets )?received", ping_output)
        if counts is None:
            raise ValueError("packet counts not found in ping output")
        result['transmitted'] = int(counts.group(1))
        result['received'] = int(counts.group(2))
        if result['received'] == 0:
            logging.error("No packets received during ping " + ip)
            return result, False

        # Per-reply lines contain "time=12.3 ms"; the median needs them all.
        ping_values = [float(v) for v in re.findall(r"time=([\d.]+)", ping_output)]
        if not ping_values:
            raise ValueError("no per-reply times found in ping output")
        result['median'] = round(median(ping_values), 3)

        # e.g. "rtt min/avg/max/mdev = 10.1/11.2/12.3/0.5 ms"
        # (some ping implementations omit the fourth value).
        rtt = re.search(
            r"= ([\d.]+)/([\d.]+)/([\d.]+)(?:/([\d.]+))? ms", ping_output)
        if rtt is None:
            raise ValueError("round-trip summary not found in ping output")
        result['min'] = float(rtt.group(1))
        result['avg'] = float(rtt.group(2))
        result['max'] = float(rtt.group(3))
        if rtt.group(4) is not None:
            result['stddev'] = float(rtt.group(4))
    except Exception as e:
        logging.error("Ping test failed for " + ip + ": %s", e)
        return result, False
    return result, True


def get_ping_test(modem, dry_run_latency=None):
    '''
    Run run_ping_test() (10 pings) for every target in CONF.ips, in order.

    'dry_run' is skipped because it already ran as the user plane test; its
    result, when passed in as *dry_run_latency*, is copied into the output.
    Every target is pinged even after an earlier one fails.

    Returns (dict mapping target name -> ping result dict,
             True only if every ping test passed).
    '''
    results = {}
    all_passed = True

    if dry_run_latency:
        results["dry_run"] = dry_run_latency

    for name, target_ip in zip(CONF.ips._fields, CONF.ips):
        if name == "dry_run":
            continue
        results[name], passed = run_ping_test(target_ip, 10)
        if not passed:
            all_passed = False
            # NOTE(review): the loop itself continues; main() is what skips
            # the iperf test when this function reports a failure.
            logging.error("Ping test failed. Not running further tests.")
    return results, all_passed

def run_iperf_test(ip, port, time_duration, is_downlink):
    '''
    Run an iperf3 throughput test against *ip*:*port*.

    - Runs for *time_duration* seconds. With *is_downlink* the test is
      reversed (-R) and the received rate is reported; otherwise the sent
      rate is reported.
    - Makes up to 2 attempts, waiting 5 seconds between them (no wait after
      the last attempt).

    Returns the throughput in Mbit/s, or 0.0 when the test is not configured
    (empty *ip* or *port* 0) or every attempt failed.
    '''
    result = 0.0
    if not ip or port == 0:
        return result
    max_retries = 2
    err = None
    cmd = ["iperf3", "-c", ip, "-p", str(port), "-t", str(time_duration)]
    if is_downlink:
        cmd.append("-R")
    cmd.append("--json")
    for attempt in range(max_retries):
        if attempt > 0:
            time.sleep(5)  # give the server a moment before retrying
        try:
            iperf_result = json.loads(subprocess.check_output(cmd).decode("UTF-8"))
            summary_key = 'sum_received' if is_downlink else 'sum_sent'
            return iperf_result['end'][summary_key]['bits_per_second'] / 1000000.0
        except Exception as e:
            err = e
    logging.error("After " + str(max_retries) + " retries, iperf test failed for " + ip + ": %s", err)
    return result


def get_iperf_test(modem):
    '''
    Run the downlink and uplink iperf tests against CONF.ips.iperf_server.

    If CONF.iperf_schedule is a non-empty sequence of hours (local time, as
    returned by time.strftime("%H")), the tests run at most once per listed
    hour; outside the schedule, or if they already ran this hour, None is
    returned. Without a schedule the tests run on every call.

    Returns {'cluster': {'downlink': Mbit/s, 'uplink': Mbit/s}} or None.
    '''
    global hour_iperf_scheduled_time_last_ran

    # Read the clock once so every check below sees the same hour, even if
    # the call straddles an hour boundary.
    current_hour = int(time.strftime("%H"))

    if "iperf_schedule" in CONF._fields and len(CONF.iperf_schedule) > 0:
        if current_hour not in CONF.iperf_schedule: # not in the schedule
            hour_iperf_scheduled_time_last_ran = -1
            return None
        if current_hour == hour_iperf_scheduled_time_last_ran: # already ran this hour
            return None
    hour_iperf_scheduled_time_last_ran = current_hour

    speedtest_iperf = {'cluster': {}}
    speedtest_iperf['cluster']['downlink'] = run_iperf_test(CONF.ips.iperf_server, CONF.iperf_port, 10, True)
    speedtest_iperf['cluster']['uplink'] = run_iperf_test(CONF.ips.iperf_server, CONF.iperf_port, 10, False)

    return speedtest_iperf


def get_signal_quality(modem, dongle_stats=None):
    '''
    Return the serving cell's signal quality as {'rsrq': int, 'rsrp': int}.

    When *dongle_stats* carries non-empty RSRQ and RSRP readings they are
    mapped to the AT+CESQ index scale via (RSRQ + 19.5) * 2 and RSRP + 140
    (readings presumably in dB / dBm -- TODO confirm against the dongle).
    Otherwise the modem is queried with AT+CESQ and the 5th and 6th fields
    (rsrq, rsrp) are used, with 255 reported as 0. Any query or parse
    failure yields {'rsrq': 0, 'rsrp': 0}.
    '''
    # A failed dongle fetch returns a dict without these keys, and an empty
    # XML element may come through as None, so read them defensively.
    rsrq_raw = dongle_stats.get('RSRQ') if dongle_stats else None
    rsrp_raw = dongle_stats.get('RSRP') if dongle_stats else None
    if rsrq_raw not in (None, '') and rsrp_raw not in (None, ''):
        rsrq = int((float(rsrq_raw) + 19.5) * 2)
        rsrp = int(float(rsrp_raw) + 140)
        return {'rsrq': rsrq, 'rsrp': rsrp}

    success, result = modem.write('AT+CESQ')
    logging.debug("get_signal_quality success %i result %s", success, result)
    if not success or 'CESQ: ' not in result:
        logging.error("Failed to get signal quality")
        return {'rsrq': 0, 'rsrp': 0}

    logging.debug("%s", result)
    fields = result.split('CESQ:')[1].split(',')
    try:
        rsrq = int(fields[4].strip())
        # The last field is followed by the trailing " OK" of the reply.
        rsrp = int(fields[5].split()[0])
    except (IndexError, ValueError):
        logging.error("Failed to parse signal quality from %s", result)
        return {'rsrq': 0, 'rsrp': 0}

    # 255 means "not known or not detectable" in 3GPP TS 27.007 +CESQ.
    return {
        'rsrq': 0 if rsrq == 255 else rsrq,
        'rsrp': 0 if rsrp == 255 else rsrp,
    }


def get_dongle_stats():
    '''
    Fetch the LTE dongle's status page and extract selected fields.

    Runs curl against the dongle's web UI (http://192.168.0.1:8080), parses
    the XML reply and stores the text of each element listed in xml_keys;
    elements that are missing or have no text are stored as "".

    Always returns a dict with 'SuccessfulFetch' (True only when the XML was
    fetched and parsed) and 'inBandReporting' (CONF.report_in_band, default
    False). When the fetch or parse fails, no other keys are present.
    '''
    result = {'SuccessfulFetch' : False}
    if "report_in_band" in CONF._fields:
        result['inBandReporting'] = CONF.report_in_band
    else:
        result['inBandReporting'] = False
    xml_keys = ["MAC",
                "PLMNStatus",
                "UICCStatus",
                "IMEI",
                "IMSI",
                "PLMNSelected",
                "MCC",
                "MNC",
                "PhyCellID",
                "CellGlobalID",
                "Band",
                "EARFCN",
                "BandWidth",
                "RSRP",
                "RSRQ",
                "ServCellState",
                "Connection",
                "IPv4Addr"]
    try:
        # --max-time bounds the request so an unresponsive dongle cannot stall
        # the monitoring loop indefinitely.
        # NOTE(review): address and credentials are hard-coded.
        dongle_xml = ET.fromstring(subprocess.check_output(
            "curl -u admin:admin -s --max-time 10 "
            "'http://192.168.0.1:8080/cgi-bin/ltestatus.cgi?Command=Status'",
            shell=True).decode("UTF-8"))
    except Exception as e:
        logging.error("Failed to fetch dongle stats from URL: " + str(e))
        return result

    for key in xml_keys:
        element = dongle_xml.find(key)
        if element is None:
            logging.warning("Failed to find " + key + " in XML.")
            result[key] = ""
        else:
            # An empty element has .text None; store "" so callers comparing
            # against '' never receive None.
            result[key] = element.text or ""
    result["SuccessfulFetch"] = True
    return result


def report_status(signal_quality, dongle_stats, cp_state=None, up_state=None, speedtest_ping=None, speedtest_iperf=None):
    '''
    POST a JSON status report to CONF.report_url, then sleep.

    Status and speedtest sections that are not supplied keep their
    "disconnected"/zero defaults; signal_quality and dongle_stats are always
    taken from the arguments. The report is sent over:
      - CONF.modem.iface when CONF.report_in_band is set, falling back to
        CONF.report_iface (if set) when that attempt fails;
      - otherwise CONF.report_iface, if set;
      - otherwise the default route.
    Send failures are logged, never raised. Increments the global ``cycles``
    counter and always sleeps CONF.report_interval seconds before returning.
    '''
    global cycles

    report = {
        'name': CONF.edge_name,
        'status': {
            'control_plane': "disconnected",
            'user_plane': "disconnected"
        },
        'dongle_stats': {
            'SuccessfulFetch' : False
        },
        'speedtest': {
            'ping': {
                'dns': {
                    'transmitted' : 0,
                    'received' : 0,
                    'median' : 0.0,
                    'min': 0.0,
                    'avg': 0.0,
                    'max': 0.0,
                    'stddev': 0.0
                }
            },
            'iperf': {
                'cluster': {
                    'downlink': 0.0,
                    'uplink': 0.0
                }
            }
        },
        'signal_quality': {
            'rsrq': 0,
            'rsrp': 0
        }
    }

    if cp_state is not None:
        report['status']['control_plane'] = cp_state.name
    if up_state is not None:
        report['status']['user_plane'] = up_state.name
    if speedtest_ping is not None:
        report['speedtest']['ping'] = speedtest_ping
    if speedtest_iperf is not None:
        report['speedtest']['iperf'] = speedtest_iperf
    report['signal_quality'] = signal_quality
    report['dongle_stats'] = dongle_stats

    logging.info("Sending report %s", report)
    cycles += 1
    logging.info("Number of cycles since modem restart %i", cycles)

    c = None
    try:
        report_via_modem = "report_in_band" in CONF._fields and CONF.report_in_band and \
            "iface" in CONF.modem._fields and CONF.modem.iface
        report_via_given_iface = "report_iface" in CONF._fields and CONF.report_iface

        c = pycurl.Curl()
        c.setopt(pycurl.URL, CONF.report_url)
        c.setopt(pycurl.POST, True)
        c.setopt(pycurl.HTTPHEADER, ['Content-Type: application/json'])
        c.setopt(pycurl.TIMEOUT, 10)
        c.setopt(pycurl.POSTFIELDS, json.dumps(report))
        c.setopt(pycurl.WRITEFUNCTION, lambda x: None) # don't output to console

        if report_via_modem: # report in-band
            interface = CONF.modem.iface
        elif report_via_given_iface: # report over given interface
            interface = CONF.report_iface
        else: # report over the default interface
            interface = None
        if interface is not None:
            c.setopt(pycurl.INTERFACE, interface)

        try:
            c.perform()
        except Exception:
            if not (report_via_modem and report_via_given_iface):
                raise  # no fallback available; logged by the outer handler
            logging.warning("Sending report via modem failed. Attempting to send report via " + str(CONF.report_iface) + ".")
            interface = CONF.report_iface
            c.setopt(pycurl.INTERFACE, interface)
            c.perform()
        # interface is None on the default route; string concatenation used
        # to raise TypeError here and log a false "Failed to send report".
        logging.info("Report sent via %s!", interface if interface is not None else "default interface")
    except Exception as e:
        logging.error("Failed to send report: " + str(e))
    finally:
        if c is not None:
            c.close()

    time.sleep(CONF.report_interval)

def reset_usb():
    '''
    Power-cycle the USB hubs with uhubctl to reset the modem.

    uhubctl is looked up on PATH and the binary found there is the one run
    (previously a hard-coded /usr/sbin/uhubctl was run regardless of where
    `which` found it). Hub location 2 is switched off, then on again after
    10 s, followed by another 10 s wait; on success the global ``cycles``
    counter is reset to 0. If uhubctl is not installed or anything raises,
    the host is rebooted via reboot(120) instead.
    '''
    import shutil

    global cycles
    try:
        uhubctl = shutil.which("uhubctl")
        if uhubctl is None:
            reboot(120)
            return
        # -l 2: location = bus 2 on pi controls power to all hubs
        ret = subprocess.call([uhubctl, "-a", "0", "-l", "2"])  # -a 0 = shutdown
        logging.info("Shutting down usb hub 2 results %s", ret)
        time.sleep(10)  # let power down process settle out
        ret = subprocess.call([uhubctl, "-a", "1", "-l", "2"])  # -a 1 = start
        logging.info("Starting up usb hub 2 results %s", ret)
        time.sleep(10)  # allow dbus to finish
        cycles = 0
    except Exception as e:
        logging.error("Failed to run uhubctl: %s", e)
        reboot(120)

def reboot(delay):
    """Log the reason, wait *delay* seconds, then reboot the host with sudo.

    Raises subprocess.CalledProcessError if the shutdown command fails.
    """
    logging.error("Failed to run uhubctl. Reboot system in %s second(s).", delay)
    time.sleep(delay)
    reboot_cmd = "sudo shutdown -r now"
    subprocess.check_output(reboot_cmd, shell=True)

def main():
    '''
    Entry point of the monitor.

    Sets up routes towards the modem and opens its serial port, then loops
    forever. Each pass checks the control plane, signal quality, user plane,
    ping and (scheduled) iperf tests, and sends a report; report_status()
    sleeps CONF.report_interval between passes.

    Exits with status 1 when a test route cannot be added, the modem port
    cannot be opened, or the modem is in the error state (after a USB reset).
    '''
    global cycles
    global hour_iperf_scheduled_time_last_ran
    cycles = 0
    hour_iperf_scheduled_time_last_ran = -1

    try:
        if "report_in_band" in CONF._fields and \
        "iface" in CONF.modem._fields and CONF.modem.iface:
            if CONF.report_in_band: # need to add default gateway if reporting in-band
                subprocess.check_output("sudo route add default gw " + CONF.modem.ip_addr + " " + CONF.modem.iface + " || true", shell=True)
            else:
                subprocess.check_output("sudo route del default gw " + CONF.modem.ip_addr + " " + CONF.modem.iface + " || true", shell=True)
    except Exception as e:
        logging.error("Failed to change default route for modem: " + str(e))

    # Route every configured test target through the modem.
    for ip in CONF.ips:
        if not ip:
            continue
        try:
            subprocess.check_output("sudo ip route replace {}/32 via {}".format(
                ip, CONF.modem.ip_addr), shell=True)
        except subprocess.CalledProcessError as e:
            logging.error("Failed to add routes: " + str(e.returncode) + str(e.output))
            time.sleep(10) # Sleep for 10 seconds before retry
            sys.exit(1)

    modem = Modem(CONF.modem.port, CONF.modem.baud)
    try:
        modem.connect()
    except serial.serialutil.SerialException as e:
        logging.error("Failed to connect the modem for %s", e)
        sys.exit(1)

    try:
        while True:
            dongle_stats = get_dongle_stats()
            cp_state = get_control_plane_state(modem, dongle_stats)
            # Previously an unconditional "continue" for every non-connected
            # state made the error and disconnected branches unreachable.
            if cp_state is State.error:
                logging.error("Modem is in error state.")
                reset_usb()
                sys.exit(1)

            signal_quality = get_signal_quality(modem, dongle_stats)
            if cp_state is not State.connected:
                # Failed to attach, don't need to run other tests
                logging.error("Check control plane - fail")
                report_status(signal_quality, dongle_stats)
                continue

            up_state, dry_run_latency = get_user_plane_state(modem)
            if up_state is State.disconnected:
                logging.error("Check user plane - fail")
                # Basic user plane test failed, don't need to run the rest of tests
                report_status(signal_quality, dongle_stats, cp_state)
                continue

            speedtest_ping, speedtest_status = get_ping_test(modem, dry_run_latency)
            if not speedtest_status:
                # A ping test failed, so the iperf test is skipped this pass.
                logging.error("Check ping test - fail")
                report_status(signal_quality, dongle_stats, cp_state, up_state, speedtest_ping)
                continue

            speedtest_iperf = get_iperf_test(modem)
            report_status(signal_quality, dongle_stats, cp_state, up_state, speedtest_ping, speedtest_iperf)
    finally:
        # Reached on sys.exit() or an unexpected exception.
        modem.close()


# Start the monitoring loop only when run as a script.
if __name__ == "__main__":
    main()
