#!/usr/bin/env python3

# Copyright 2020-present Open Networking Foundation
#
# SPDX-License-Identifier: LicenseRef-ONF-Member-Only-1.0

import sys
import os
import json
import logging
import enum
import pycurl
import time
import serial
import subprocess
import time
from collections import namedtuple
from statistics import median
import xml.etree.ElementTree as ET

'''
"Simple" script that periodically checks the Aether network operational status
by controlling the attached 4G/LTE modem with AT commands and
reports the result to the central monitoring server.
'''

# Parse config with backwards compatibility with config.json pre 0.6.6:
# key names that changed since then are rewritten textually before parsing,
# and the old default "speedtest_ping_dns" entry is dropped.
# Every JSON object becomes a namedtuple, so settings are read as attributes
# (e.g. CONF.modem.port) and optional keys are detected via ``_fields``.
# The file is opened in a ``with`` block so the handle is closed right away.
with open(os.getenv('CONFIG_FILE', "./config.json"), encoding="utf-8") as config_file:
    config_file_contents = config_file.read()
config_file_contents = config_file_contents.replace("user_plane_ping_test", "dns")
config_file_contents = config_file_contents.replace("speedtest_iperf", "iperf_server")
config_file_contents = config_file_contents.replace("\"speedtest_ping_dns\": \"1.1.1.1\",", "")
CONF = json.loads(
    config_file_contents, object_hook=lambda d: namedtuple('X', d.keys())(*d.values())
)

logging.basicConfig(
    filename=CONF.log_file,
    format='%(asctime)s [%(levelname)s] %(message)s',
    level=logging.getLevelName(CONF.log_level)
)


class State(enum.Enum):
    '''
    Connectivity state of a plane.

    "0"/"1" are the attach values parsed from the modem's CGATT reply;
    "-1" is this script's own marker for a failed check.
    '''
    error = "-1"
    disconnected = "0"
    connected = "1"

    @classmethod
    def has_value(cls, value):
        '''Return True when ``value`` equals the value of one of the members.'''
        member_values = {member.value for member in cls}
        return value in member_values


class Modem():
    '''
    Minimal AT-command client for the LTE modem's serial port.

    connect() must be called before write()/get_state(); it resolves the
    device path with get_modem_port() and opens it with pyserial.
    '''
    log = logging.getLogger('aether_edge_monitoring.Modem')

    # Seconds to sleep in _write when no reply bytes have arrived yet.
    read_timeout = 0.1

    def __init__(self, port, baudrate):
        self.port = port
        self.baudrate = baudrate
        self._response = None
        # Opened by connect(); None until then so close() is safe to call early.
        self.serial = None

    def get_modem_port(self):
        '''
        Return the device path reported by ``ls CONF.modem.port``.

        The configured value itself is removed from the ``ls`` output and the
        rest is stripped. ``ls`` errors are logged, not raised, so the result
        may be "". NOTE(review): shell=True presumably exists so a glob in
        CONF.modem.port gets expanded — confirm against the deployed config.
        '''
        cmd = "ls " + CONF.modem.port
        sp = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE, universal_newlines=True)
        # communicate() drains both pipes and waits for exit; calling wait()
        # first with piped output can deadlock once a pipe buffer fills.
        ret, err = sp.communicate()
        if err != "":
            self.log.error("unable to find serial port " + err)

        ret = ret.replace(CONF.modem.port, "").strip()
        self.log.info("Modem.get_modem_port found " + ret)
        return ret

    def connect(self):
        '''Resolve the device path and open it at ``self.baudrate`` with a 1 s timeout.'''
        self.port = self.get_modem_port()
        self.log.info("modem.connect Port: %s, BaudRate: %i", self.port, self.baudrate)
        self.serial = serial.Serial(
            port=self.port,
            baudrate=self.baudrate,
            timeout=1)

    def _write(self, command):
        '''
        Send ``command`` terminated by CR and return what the modem sent back.

        Stale input is discarded first. Reading stops the first time the
        input buffer is empty after a read, or after a single
        ``read_timeout`` sleep if nothing has arrived yet. CRLF pairs become
        spaces so the reply is one line.
        '''
        if self.serial.inWaiting() > 0:
            self.serial.flushInput()

        self._response = b""

        self.serial.write(bytearray(command + "\r", "ascii"))
        read = self.serial.inWaiting()
        while True:
            if read > 0:
                self._response += self.serial.read(read)
            else:
                time.sleep(self.read_timeout)
            read = self.serial.inWaiting()
            if read == 0:
                break
        # Line noise can produce non-ASCII bytes; replace them rather than
        # letting UnicodeDecodeError escape the monitoring loop.
        return self._response.decode("ascii", errors="replace").replace('\r\n', ' ')

    def write(self, command, wait_resp=True):
        '''
        Send an AT command and return ``(success, response)``.

        With ``wait_resp`` set, a reply containing "ERROR" yields
        ``(False, None)``; otherwise ``(True, response)``.
        '''
        response = self._write(command)
        self.log.debug("%s: %s", command, response)

        if wait_resp and "ERROR" in response:
            return False, None
        return True, response

    def get_state(self):
        '''
        Return the attach state parsed from the ``AT+CGATT?`` reply.

        A failed command, or a reply whose first token after "CGATT:" is not
        a State value, gives State.error instead of raising ValueError.
        '''
        success, result = self.write('AT+CGATT?')
        if not success or 'CGATT:' not in result:
            return State.error
        # split() (not split(' ')) tolerates "+CGATT: 1" and a stray '\r'.
        tokens = result.split('CGATT:')[1].split()
        if not tokens or not State.has_value(tokens[0]):
            self.log.warning("Unexpected AT+CGATT? reply: %s", result)
            return State.error
        return State(tokens[0])

    def close(self):
        '''Close the serial port if connect() opened one.'''
        if self.serial is not None:
            self.serial.close()


def _echo_at_command(port, command):
    '''Write ``command`` to ``port`` via shell ``echo``; return False if that shell command failed.'''
    shell_cmd = "echo '" + command + "' > " + port
    logging.debug(shell_cmd)
    try:
        subprocess.check_output(shell_cmd, shell=True)
    except subprocess.CalledProcessError:
        logging.error("Write '" + command + "' failed")
        return False
    return True


def _poll_modem_state(modem, wanted, timeout, state=None):
    '''
    Poll modem.get_state() once per second until it returns ``wanted``.

    Polling also stops after ``timeout`` attempts. Returns the last state
    seen, or ``state`` if no poll happened.
    '''
    attempts = 0
    while attempts < timeout:
        state = modem.get_state()
        if state is wanted:
            return state
        time.sleep(1)
        attempts += 1
    return state


def get_control_plane_state(modem):
    '''
    Detach (AT+CFUN=0) and re-attach (AT+CFUN=1) the radio and report the result.

    Returns State.error if either CFUN write fails, or if the modem is not
    disconnected within CONF.detach_timeout polls. Otherwise returns the
    state after up to CONF.attach_timeout polls, with State.error mapped to
    State.disconnected.
    '''
    # "echo" works more stable than serial for toggling the radio function.
    if not _echo_at_command(modem.port, 'AT+CFUN=0'):
        return State.error

    detached = _poll_modem_state(modem, State.disconnected, CONF.detach_timeout)
    if detached is not State.disconnected:
        logging.error("Failed to disconnect")
        return State.error

    time.sleep(2)
    if not _echo_at_command(modem.port, 'AT+CFUN=1'):
        return State.error

    # Wait for the modem to be fully connected; it is known to be detached here.
    attached = _poll_modem_state(modem, State.connected, CONF.attach_timeout,
                                 State.disconnected)
    # CGATT sometimes returns None
    return State.disconnected if attached is State.error else attached


def get_user_plane_state(modem):
    '''
    Basic user-plane check: ping CONF.ips.dns three times.

    Returns State.connected when the ping command succeeds, otherwise
    State.disconnected. ``modem`` is not used.
    '''
    ping_cmd = "ping -c 3 " + CONF.ips.dns + ">/dev/null 2>&1"
    try:
        subprocess.check_output(ping_cmd, shell=True)
    except subprocess.CalledProcessError:
        logging.warning("User plane test failed")
        return State.disconnected
    return State.connected


def run_ping_test(ip, count):
    '''
    Runs the ping test
    Input: IP to ping (empty/None skips the test), # times to ping
    Returns: (result, passed) where result holds
             Transmitted packets
             Received packets
             Median ping ms
             Min ping ms
             Avg ping ms
             Max ping ms
             Std Dev ping ms (left at 0.0 if this ping does not report it)
             passed is False when ping failed, nothing was received, or the
             output could not be parsed. A skipped test returns the all-zero
             result with passed=True.
    '''
    import re

    result = {'transmitted': 0,
              'received': 0,
              'median': 0.0,
              'min': 0.0,
              'avg': 0.0,
              'max': 0.0,
              'stddev': 0.0}
    if not ip:
        return result, True
    try:
        # Argument list, no shell: the IP is passed to ping verbatim.
        ping_output = subprocess.check_output(
            ["ping", "-c", str(count), ip]).decode("UTF-8")
        # Match the summary by pattern rather than fixed token offsets. The
        # offsets shift when ping adds "+N errors"/"+N duplicates", or uses
        # the BusyBox wording ("N packets received", no mdev column).
        counts = re.search(
            r"(\d+) packets transmitted, (\d+) (?:packets )?received", ping_output)
        if counts is None:
            raise ValueError("unrecognized ping summary")
        result['transmitted'] = int(counts.group(1))
        result['received'] = int(counts.group(2))
        if result['received'] > 0:
            # Per-reply round-trip times feed the median.
            ping_values = [float(v) for v in re.findall(r"time=([\d.]+)", ping_output)]
            result['median'] = round(median(ping_values), 3)

            rtt = re.search(
                r"min/avg/max(?:/\w+)? = ([\d.]+)/([\d.]+)/([\d.]+)(?:/([\d.]+))?",
                ping_output)
            if rtt is None:
                raise ValueError("unrecognized ping rtt summary")
            result['min'] = float(rtt.group(1))
            result['avg'] = float(rtt.group(2))
            result['max'] = float(rtt.group(3))
            if rtt.group(4) is not None:
                result['stddev'] = float(rtt.group(4))
        else:
            logging.error("No packets received during ping " + ip)
            return result, False
    except Exception as e:
        logging.error("Ping test failed for " + ip + ": %s", e)
        return result, False
    return result, True


def get_ping_test(modem):
    '''
    Prepares the ping test.
    Runs ping tests for every entry of 'ips' in config.json, in config order.
    Note: the 'dry_run' IP entry runs 3 ping iterations; other IPs run 10.
    Returns: (results keyed by the 'ips' field name,
              True only if every individual ping test passed).
    ``modem`` is not used.
    '''
    speedtest_ping = {}
    ping_test_passed = True

    for name, ip in zip(CONF.ips._fields, CONF.ips):
        count = 3 if name == "dry_run" else 10
        speedtest_ping[name], status = run_ping_test(ip, count)
        if not status:
            ping_test_passed = False
            # The remaining IPs are still pinged; "further tests" means the
            # iperf stage, which main() skips when this returns False.
            logging.error("Ping test failed. Not running further tests.")
    return speedtest_ping, ping_test_passed

def run_iperf_test(ip, port, time_duration, is_downlink, max_retries=2, retry_delay=5):
    '''
    Runs an iperf3 client test against the iperf server in the config file.
    - Runs for ``time_duration`` seconds (iperf3 ``-t``). ``is_downlink``
      adds ``-R`` and reports the received rate; otherwise the sent rate.
    - Makes up to ``max_retries`` attempts, sleeping ``retry_delay``
      seconds between attempts (not after the last one).
    Returns: throughput in Mbit/s from iperf3's JSON output, or 0.0 when
             skipped (no IP or port 0) or when every attempt failed.
    '''
    result = 0.0
    if not ip or port == 0:
        return result
    # Argument list, no shell: config values are passed to iperf3 verbatim.
    cmd = ["iperf3", "-c", ip, "-p", str(port), "-t", str(time_duration)]
    if is_downlink:
        cmd.append("-R")
    cmd.append("--json")
    summary_key = 'sum_received' if is_downlink else 'sum_sent'
    err = None
    for attempt in range(max_retries):
        if attempt > 0:
            time.sleep(retry_delay)
        try:
            iperf_result = json.loads(subprocess.check_output(cmd).decode("UTF-8"))
            return iperf_result['end'][summary_key]['bits_per_second'] / 1000000.0
        except Exception as e:
            err = e
    logging.error("After " + str(max_retries) + " retries, iperf test failed for " + ip + ": %s", err)
    return result


def get_iperf_test(modem):
    '''
    Prepares the iperf test and runs the downlink and uplink measurements.

    If CONF.iperf_schedule is present and non-empty, the test runs only
    during the listed local hours, and at most once per such hour. Entries
    are compared against the integer hour, so they are presumably ints 0-23.
    Without a schedule it runs on every call.
    Returns: {'cluster': {'downlink': Mbps, 'uplink': Mbps}}, or None when skipped.
    ``modem`` is not used.
    '''
    global hour_iperf_scheduled_time_last_ran
    # Read the hour once so the schedule check and the bookkeeping below
    # cannot disagree if the clock crosses an hour boundary in between.
    current_hour = time.localtime().tm_hour

    if "iperf_schedule" in CONF._fields and len(CONF.iperf_schedule) > 0:
        if current_hour not in CONF.iperf_schedule:  # not in the schedule
            hour_iperf_scheduled_time_last_ran = -1
            return None
        if current_hour == hour_iperf_scheduled_time_last_ran:  # already ran this hour
            return None
    hour_iperf_scheduled_time_last_ran = current_hour

    return {
        'cluster': {
            'downlink': run_iperf_test(CONF.ips.iperf_server, CONF.iperf_port, 10, True),
            'uplink': run_iperf_test(CONF.ips.iperf_server, CONF.iperf_port, 10, False),
        }
    }


def get_signal_quality(modem):
    '''
    Query AT+CESQ and return the RSRQ/RSRP indices as ints.

    Fields 4 and 5 of the "+CESQ:" reply are rsrq and rsrp. The value 255
    ("not known or not detectable" in 3GPP TS 27.007) is reported as 0.
    Returns {'rsrq': 0, 'rsrp': 0} when the command fails or the reply
    cannot be parsed.
    '''
    success, result = modem.write('AT+CESQ')
    logging.debug("get_signal_quality success %i result %s", success, result)
    if not success or 'CESQ: ' not in result:
        logging.error("Failed to get signal quality")
        return {'rsrq': 0, 'rsrp': 0}

    logging.debug("%s", result)
    fields = result.split('CESQ:')[1].split(',')
    try:
        rsrq = int(fields[4].strip())
        # rsrp is the last field, so trailing reply text (e.g. "OK") follows it.
        rsrp = int(fields[5].split()[0])
    except (IndexError, ValueError):
        logging.error("Failed to parse signal quality from %s", result)
        return {'rsrq': 0, 'rsrp': 0}

    return {
        'rsrq': 0 if rsrq == 255 else rsrq,
        'rsrp': 0 if rsrp == 255 else rsrp
    }

DONGLE_STATUS_URL = "http://192.168.0.1:8080/cgi-bin/ltestatus.cgi?Command=Status"


def get_dongle_stats(modem, url=DONGLE_STATUS_URL, timeout=10):
    '''
    Fetch the dongle's XML status page with curl and extract selected fields.

    ``url`` defaults to the dongle status endpoint. ``timeout`` caps the
    whole curl transfer in seconds, so an unresponsive dongle cannot stall
    the monitoring loop. Credentials are hard-coded as admin:admin.
    Returns a dict with 'SuccessfulFetch' plus, when the fetch and XML parse
    succeed, every key below. A missing element maps to "" and an empty
    element to None. ``modem`` is not used.
    '''
    result = {'SuccessfulFetch': False}
    xml_keys = ["MAC",
                "PLMNStatus",
                "UICCStatus",
                "IMEI",
                "IMSI",
                "PLMNSelected",
                "MCC",
                "MNC",
                "PhyCellID",
                "CellGlobalID",
                "Band",
                "EARFCN",
                "BandWidth",
                "ServCellState",
                "Connection",
                "IPv4Addr"]
    try:
        raw = subprocess.check_output(
            ["curl", "--max-time", str(timeout), "-u", "admin:admin", url])
        dongle_stats_xml = ET.fromstring(raw.decode("UTF-8"))
    except Exception as e:
        logging.error("Failed to fetch dongle stats from URL: " + str(e))
        return result

    for key in xml_keys:
        node = dongle_stats_xml.find(key)
        if node is None:
            logging.warning("Failed to find " + key + " in XML.")
            result[key] = ""
        else:
            result[key] = node.text
    result["SuccessfulFetch"] = True
    return result


def report_status(signal_quality, dongle_stats, cp_state=None, up_state=None, speedtest_ping=None, speedtest_iperf=None):
    '''
    POST one JSON monitoring report to CONF.report_url, then sleep CONF.report_interval.

    signal_quality and dongle_stats are reported as given. cp_state and
    up_state are State members reported by name; when omitted they stay
    "disconnected". speedtest_ping and speedtest_iperf replace the all-zero
    defaults when given. Increments the global ``cycles`` counter.
    Send failures, including HTTP status >= 400, are logged and not raised.
    '''
    global cycles

    report = {
        'name': CONF.edge_name,
        'status': {
            'control_plane': "disconnected",
            'user_plane': "disconnected"
        },
        'dongle_stats': dongle_stats,
        'speedtest': {
            'ping': {
                'dns': {
                    'transmitted': 0,
                    'received': 0,
                    'median': 0.0,
                    'min': 0.0,
                    'avg': 0.0,
                    'max': 0.0,
                    'stddev': 0.0
                }
            },
            'iperf': {
                'cluster': {
                    'downlink': 0.0,
                    'uplink': 0.0
                }
            }
        },
        'signal_quality': signal_quality
    }

    if cp_state is not None:
        report['status']['control_plane'] = cp_state.name
    if up_state is not None:
        report['status']['user_plane'] = up_state.name
    if speedtest_ping is not None:
        report['speedtest']['ping'] = speedtest_ping
    if speedtest_iperf is not None:
        report['speedtest']['iperf'] = speedtest_iperf

    logging.info("Sending report %s", report)
    cycles += 1
    logging.info("Number of cycles since modem restart %i", cycles)

    c = None
    try:
        c = pycurl.Curl()
        c.setopt(pycurl.URL, CONF.report_url)
        c.setopt(pycurl.POST, True)
        c.setopt(pycurl.HTTPHEADER, ['Content-Type: application/json'])
        c.setopt(pycurl.TIMEOUT, 10)
        c.setopt(pycurl.POSTFIELDS, json.dumps(report))
        if "report_in_band" in CONF._fields and CONF.report_in_band and \
                "iface" in CONF.modem._fields and CONF.modem.iface:
            c.setopt(pycurl.INTERFACE, CONF.modem.iface)
        elif "report_iface" in CONF._fields and CONF.report_iface:
            c.setopt(pycurl.INTERFACE, CONF.report_iface)
        c.perform()
        # perform() succeeds for any HTTP reply, so surface server-side rejections.
        status_code = c.getinfo(pycurl.RESPONSE_CODE)
        if status_code >= 400:
            logging.error("Report rejected with HTTP status %s", status_code)
    except Exception as e:
        logging.error("Failed to send report: " + str(e))
    finally:
        # Release the curl handle even when perform() raised.
        if c is not None:
            c.close()

    time.sleep(CONF.report_interval)

def reset_usb():
    '''
    Power-cycle USB hub location 2 with uhubctl to recover the modem.

    Resets the global ``cycles`` counter on success. If uhubctl is not on
    PATH, or anything raises, falls back to reboot(120).
    '''
    import shutil

    global cycles
    try:
        # Run the same uhubctl that the PATH lookup found, instead of
        # checking PATH but then executing a hard-coded /usr/sbin copy.
        uhubctl = shutil.which("uhubctl")
        if uhubctl is None:
            reboot(120)
            return
        # -a 0 = action is shutdown; -l 2 = location bus 2, which on the Pi
        # controls power to all hubs.
        ret = subprocess.call([uhubctl, "-a", "0", "-l", "2"])
        logging.info("Shutting down usb hub 2 results %s", ret)
        time.sleep(10)  # let power down process settle out
        # -a 1 = action is start; same hub location.
        ret = subprocess.call([uhubctl, "-a", "1", "-l", "2"])
        logging.info("Starting up usb hub 2 results %s", ret)
        time.sleep(10)  # allow dbus to finish
        cycles = 0
    except Exception as e:
        logging.error("Failed to run uhubctl: %s", e)
        reboot(120)

def reboot(delay):
    '''Log the uhubctl fallback, wait ``delay`` seconds, then reboot with ``sudo shutdown -r now``.'''
    notice = "Failed to run uhubctl. Reboot system in {} second(s).".format(delay)
    logging.error(notice)
    time.sleep(delay)
    reboot_cmd = "sudo shutdown -r now"
    subprocess.check_output(reboot_cmd, shell=True)

def _configure_default_route():
    '''Best-effort: add (report_in_band) or remove the default route through the modem.'''
    try:
        if "report_in_band" in CONF._fields and \
                "iface" in CONF.modem._fields and CONF.modem.iface:
            if CONF.report_in_band:  # need to add default gateway if reporting in-band
                subprocess.check_output("sudo route add default gw " + CONF.modem.ip_addr + " " + CONF.modem.iface + " || true", shell=True)
            else:
                subprocess.check_output("sudo route del default gw " + CONF.modem.ip_addr + " " + CONF.modem.iface + " || true", shell=True)
    except Exception as e:
        logging.error("Failed to change default route for modem: " + str(e))


def _add_test_ip_routes():
    '''Route every non-empty configured test IP via the modem; exit(1) on failure.'''
    for ip in CONF.ips:
        if not ip:
            continue
        try:
            subprocess.check_output("sudo ip route replace {}/32 via {}".format(
                ip, CONF.modem.ip_addr), shell=True)
        except subprocess.CalledProcessError as e:
            logging.error("Failed to add routes: " + str(e.returncode) + str(e.output))
            # Sleep for 10 seconds before exiting; the retry is presumably
            # done by whatever restarts this script — confirm.
            time.sleep(10)
            sys.exit(1)


def _run_monitoring_cycle(modem):
    '''
    Run one monitoring cycle and send its report.

    Stages run in order (control plane, user plane, ping, iperf); the first
    failing stage ends the cycle with a partial report. A control-plane
    error resets USB and exits the process.
    '''
    dongle_stats = get_dongle_stats(modem)

    signal_quality = get_signal_quality(modem)

    cp_state = get_control_plane_state(modem)
    if cp_state is State.error:
        logging.error("Modem is in error state.")
        reset_usb()
        sys.exit(1)
    if cp_state is State.disconnected:
        # Failed to attach, don't need to run other tests
        report_status(signal_quality, dongle_stats)
        return

    up_state = get_user_plane_state(modem)
    if up_state is State.disconnected:
        # Basic user plane test failed, don't need to run the rest of tests
        report_status(signal_quality, dongle_stats, cp_state)
        return

    speedtest_ping, speedtest_status = get_ping_test(modem)
    if not speedtest_status:
        report_status(signal_quality, dongle_stats, cp_state, up_state, speedtest_ping)
        return

    speedtest_iperf = get_iperf_test(modem)
    report_status(signal_quality, dongle_stats, cp_state, up_state, speedtest_ping, speedtest_iperf)


def main():
    '''
    Set up routing, connect to the modem, and run monitoring cycles forever.

    Exits with status 1 if routes cannot be added, the modem cannot be
    opened, or the modem enters the error state. The serial port is closed
    on any exit from the loop.
    '''
    global cycles
    global hour_iperf_scheduled_time_last_ran
    cycles = 0
    hour_iperf_scheduled_time_last_ran = -1

    _configure_default_route()
    _add_test_ip_routes()

    modem = Modem(CONF.modem.port, CONF.modem.baud)
    try:
        modem.connect()
    except serial.serialutil.SerialException as e:
        logging.error("Failed to connect the modem for %s", e)
        sys.exit(1)

    try:
        while True:
            _run_monitoring_cycle(modem)
    finally:
        # Previously unreachable after `while True`; runs on sys.exit() too.
        try:
            modem.close()
        except Exception as e:
            logging.error("Failed to close the modem: %s", e)


if __name__ == "__main__":
    main()
