#!/usr/bin/env python3

# Copyright 2020-present Open Networking Foundation
#
# SPDX-License-Identifier: LicenseRef-ONF-Member-Only-1.0

import sys
import os
import json
import logging
import enum
import pycurl
import serial
import subprocess
import time
from collections import namedtuple
from statistics import median
import xml.etree.ElementTree as ET
import traceback

'''
"Simple" script that checks Aether network operational status periodically
by controlling the attached 4G/LTE modem with AT commands (or, when
USE_MODEM_CMDS is False, by reading the LTE dongle's status page) and
reports the result to the central monitoring server.
'''

USE_MODEM_CMDS = False

# Parse config with backwards compatibility with config.json pre 0.6.6.
# Read through a context manager so the file handle is closed promptly.
with open(os.getenv('CONFIG_FILE', "./config.json")) as _config_fp:
    config_file_contents = _config_fp.read()
# Rename keys used by older config layouts to their current names.
config_file_contents = config_file_contents.replace("user_plane_ping_test", "dns")
config_file_contents = config_file_contents.replace("speedtest_iperf", "iperf_server")
config_file_contents = config_file_contents.replace("\"speedtest_ping_dns\": \"8.8.8.8\",", "")
# Replace the 1.1.1.1 target with 8.8.8.8. Match the quoted JSON string only,
# so addresses that merely contain "1.1.1.1" (e.g. "11.1.1.1", "1.1.1.10")
# are not corrupted.
config_file_contents = config_file_contents.replace('"1.1.1.1"', '"8.8.8.8"')
# Every JSON object becomes a namedtuple, so config is read as CONF.a.b and
# optional keys are probed with `"key" in CONF._fields`.
CONF = json.loads(
    config_file_contents, object_hook=lambda d: namedtuple('X', d.keys())(*d.values())
)

logging.basicConfig(
    filename=CONF.log_file,
    format='%(asctime)s [%(levelname)s] %(message)s',
    level=logging.getLevelName(CONF.log_level)
)


class State(enum.Enum):
    '''
    Connectivity state reported for the control and user planes.

    Member values are strings, so a raw AT+CGATT? reply token ("0"/"1")
    can be converted with State(token). "-1" marks a failed query.
    '''
    error = "-1"
    disconnected = "0"
    connected = "1"

    @classmethod
    def has_value(cls, value):
        '''Return True when `value` equals the raw value of one of the members.'''
        known_values = {member.value for member in cls}
        return value in known_values


class Modem():
    '''
    Serial-port wrapper for sending AT commands to the attached modem.

    main() only creates one when USE_MODEM_CMDS is True.
    '''
    log = logging.getLogger('aether_edge_monitoring.Modem')

    # Seconds to sleep in _write() when no reply bytes are pending yet.
    read_timeout = 0.1

    def __init__(self, port, baudrate):
        self.port = port
        self.baudrate = baudrate
        self._response = None
        # Opened by connect(); None until then so close() is safe to call early.
        self.serial = None

    def get_modem_port(self):
        '''
        Resolve the serial device by running `ls` on CONF.modem.port.

        Returns the `ls` stdout with CONF.modem.port removed and whitespace
        stripped. NOTE(review): the intent of removing CONF.modem.port from
        the output is not evident from this file; confirm against the
        expected config value.
        '''
        cmd = "ls " + CONF.modem.port
        sp = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE, universal_newlines=True)
        # communicate() waits for exit while draining both pipes; calling
        # wait() first can deadlock if the child fills a pipe buffer.
        ret, err = sp.communicate()
        if err != "":
            self.log.error("unable to find serial port " + err)

        ret = ret.replace(CONF.modem.port, "").strip()
        self.log.info("Modem.get_modem_port found " + ret)
        return ret

    def connect(self):
        '''Resolve the port via get_modem_port() and open it with a 1 s read timeout.'''
        self.port = self.get_modem_port()
        self.log.info("modem.connect Port: %s, BaudRate: %i", self.port, self.baudrate)
        self.serial = serial.Serial(
            port=self.port,
            baudrate=self.baudrate,
            timeout=1)

    def _write(self, command):
        '''
        Send `command` followed by CR and collect the reply.

        Stale input is flushed first. The reply is gathered by polling
        inWaiting(): pending bytes are read immediately, otherwise the loop
        sleeps read_timeout; it stops the first time a poll finds nothing.
        Returns the reply decoded as ASCII (undecodable bytes replaced) with
        CRLF turned into spaces.
        '''
        if self.serial.inWaiting() > 0:
            self.serial.flushInput()

        self._response = b""

        self.serial.write(bytearray(command + "\r", "ascii"))
        read = self.serial.inWaiting()
        while True:
            self.log.debug("Waiting for write to complete...")
            if read > 0:
                self._response += self.serial.read(read)
            else:
                time.sleep(self.read_timeout)
            read = self.serial.inWaiting()
            if read == 0:
                break
        self.log.debug("Write complete...")
        # errors="replace": line noise must not crash the monitoring loop.
        return self._response.decode("ascii", errors="replace").replace('\r\n', ' ')

    def write(self, command, wait_resp=True):
        '''
        Send an AT command and return (ok, response).

        When wait_resp is True, a reply containing "ERROR" yields (False, None).
        '''
        response = self._write(command)
        self.log.debug("%s: %s", command, response)

        if wait_resp and "ERROR" in response:
            return False, None
        return True, response

    def get_state(self, counters):
        '''
        Query the attach status with AT+CGATT? and map it to a State.

        Returns State.error and increments counters['modem_cgatt_error']
        when the command fails or the reply cannot be parsed.
        '''
        success, result = self.write('AT+CGATT?')
        if not success or 'CGATT:' not in result:
            self.log.error("AT+CGATT modem cmd failed")
            counters['modem_cgatt_error'] += 1
            return State.error
        # split() with no separator skips leading whitespace, so both
        # "+CGATT:1" and "+CGATT: 1" yield the state token.
        tokens = result.split('CGATT:')[1].split()
        state = tokens[0] if tokens else ''
        if not State.has_value(state):
            self.log.error("Unexpected AT+CGATT response: %s", result)
            counters['modem_cgatt_error'] += 1
            return State.error
        return State(state)

    def close(self):
        '''Close the serial port if connect() opened one.'''
        if self.serial is not None:
            self.serial.close()


def get_control_plane_state(modem, counters, dongle_stats=None):
    '''
    Determine the control-plane (attach) state.

    Without a modem, the result comes from dongle_stats['Connection']:
    'Connected' gives State.connected; anything else, including a missing
    key, gives State.disconnected. With neither a modem nor dongle stats,
    State.error is returned.

    With a modem, the radio is turned off (AT+CFUN=0). AT+CGATT? is then
    polled about once per second, for up to CONF.detach_timeout tries, until
    the modem is detached. The radio is turned back on (AT+CFUN=1) and
    polled for up to CONF.attach_timeout tries until attached. A final
    State.error reading is reported as State.disconnected.

    Failures increment the matching entry in `counters`.
    '''
    if not modem:
        if not dongle_stats:
            # Nothing to query: the modem branch below would fail on `modem.port`.
            logging.error("No modem and no dongle stats; cannot check control plane")
            counters['dongle_connect_error'] += 1
            return State.error
        connection = dongle_stats.get('Connection', '')
        if connection == 'Connected':
            return State.connected
        logging.error("Dongle not connected: {}".format(connection))
        counters['dongle_connect_error'] += 1
        return State.disconnected

    # Disable radio function
    # "echo" works more stable than serial for this action
    try:
        logging.debug("echo 'AT+CFUN=0' > " + modem.port)
        subprocess.check_output(
            "echo 'AT+CFUN=0' > " + modem.port, shell=True)
    except subprocess.CalledProcessError as e:
        logging.error("Write 'AT+CFUN=0' failed: {}".format(e))
        counters['modem_cfun0_error'] += 1
        return State.error

    # Wait until the modem is fully disconnected
    retry = 0
    state = None
    while retry < CONF.detach_timeout:
        state = modem.get_state(counters)
        if state is State.disconnected:
            break
        time.sleep(1)
        retry += 1

    if state is not State.disconnected:
        logging.error("Failed to disconnect")
        return State.error

    time.sleep(2)
    # Enable radio function
    # "echo" works more stable than serial for this action
    try:
        logging.debug("echo 'AT+CFUN=1' > " + modem.port)
        subprocess.check_output(
            "echo 'AT+CFUN=1' > " + modem.port, shell=True)
    except subprocess.CalledProcessError as e:
        logging.error("Write 'AT+CFUN=1' failed: {}".format(e))
        counters['modem_cfun1_error'] += 1
        return State.error

    # Wait attach_timeout tries for the modem to be fully connected
    retry = 0
    while retry < CONF.attach_timeout:
        state = modem.get_state(counters)
        if state is State.connected:
            break
        time.sleep(1)
        retry += 1
    # CGATT sometimes returns None
    if state is State.error:
        state = State.disconnected

    return state


def dry_run_ping_test(counters):
    '''
    User-plane reachability probe.

    If CONF.ips.dry_run is configured, run a 10-packet latency test against
    it via do_ping() and return its (result, success) pair. Otherwise send
    3 quiet pings to CONF.ips.dns over CONF.modem.iface and return
    (None, success).

    `counters` is accepted for call-site symmetry but is not used here.
    '''
    dry_run_ip = CONF.ips.dry_run if "dry_run" in CONF.ips._fields else None
    if dry_run_ip:
        # Run the dry_run latency test as the user plane test.
        return do_ping(dry_run_ip, 10)

    # Default user plane test.
    try:
        subprocess.check_output(
            "ping -I {} -c 3 {} >/dev/null 2>&1".format(CONF.modem.iface, CONF.ips.dns),
            shell=True)
    except subprocess.CalledProcessError as err:
        logging.warning("Ping failed for {}: {}".format(CONF.ips.dns, err))
        return None, False
    return None, True


def do_ping(ip, count):
    '''
    Runs the ping test over CONF.modem.iface.

    Input: IP to ping, number of echo requests to send
    Returns: (result, success)
      result  - dict with
                transmitted: transmitted packets
                received:    received packets
                median:      median per-packet time (ms)
                min/avg/max/stddev: values from ping's rtt summary line (ms)
      success - False if the ping command fails, no replies arrive, or the
                output cannot be parsed. An empty ip is skipped and reported
                as success with an all-zero result.
    '''
    import re

    result = {'transmitted': 0,
              'received': 0,
              'median': 0.0,
              'min': 0.0,
              'avg': 0.0,
              'max': 0.0,
              'stddev': 0.0}
    if not ip:
        return result, True
    try:
        logging.debug("Pinging {}".format(ip))
        pingCmd = "ping -I {} -c {} {}".format(CONF.modem.iface, str(count), ip)
        pingOutput = subprocess.check_output(pingCmd, shell=True).decode("UTF-8")

        # Parse the summary by pattern, not by fixed word offsets: offsets
        # shift when ping adds "+N errors" / "+N duplicates" to the summary.
        counts = re.search(
            r'(\d+) packets transmitted, (\d+) (?:packets )?received', pingOutput)
        if counts is None:
            raise ValueError("no packet summary in ping output")
        result['transmitted'] = int(counts.group(1))
        result['received'] = int(counts.group(2))
        if result['received'] == 0:
            logging.error("No packets received during ping " + str(ip))
            return result, False

        pingValues = [float(v) for v in re.findall(r'time=([\d.]+)', pingOutput)]
        if not pingValues:
            raise ValueError("no per-packet times in ping output")
        result['median'] = round(median(pingValues), 3)

        # e.g. "rtt min/avg/max/mdev = 1.1/2.2/3.3/0.4 ms"
        rtt = re.search(r'= ([\d.]+)/([\d.]+)/([\d.]+)/([\d.]+)', pingOutput)
        if rtt is None:
            raise ValueError("no rtt summary in ping output")
        result['min'], result['avg'], result['max'], result['stddev'] = (
            float(v) for v in rtt.groups())
    except Exception as e:
        # exc_info puts the traceback in the configured log file, not stderr.
        logging.error("Ping test failed for " + str(ip) + ": %s", e, exc_info=True)
        return result, False
    return result, True


def ping_test(modem, dry_run_latency=None, count=10):
    '''
    Prepares the ping test.

    Pings every entry of 'ips' in config.json, in order, with `count`
    echo requests each, via do_ping(). The 'dry_run' entry is skipped
    here; its result (from the user plane test) is passed in as
    `dry_run_latency` and stored under "dry_run" when truthy.

    `modem` is unused; it is kept for call-site compatibility.

    Returns (results keyed by ips field name, True only if every ping passed).
    '''
    speedtest_ping = {}
    ping_test_passed = True

    if dry_run_latency:
        speedtest_ping["dry_run"] = dry_run_latency

    # CONF.ips is a namedtuple: pair each field name with its value.
    for name, ip in zip(CONF.ips._fields, CONF.ips):
        if name == "dry_run":
            continue
        speedtest_ping[name], status = do_ping(ip, count)
        if not status:
            ping_test_passed = False
            logging.error("Ping failed: {}".format(ip))
    return speedtest_ping, ping_test_passed

def run_iperf_test(ip, port, time_duration, is_downlink, counters, max_retries=2, retry_delay=5):
    '''
    Runs an iperf3 client test against ip:port and returns throughput in Mbps.

    - Runs for `time_duration` seconds; `is_downlink` adds -R (server sends).
    - Reads the JSON output: receiver-side rate for downlink, sender-side
      rate for uplink.
    - Makes up to `max_retries` attempts, sleeping `retry_delay` seconds
      between them and incrementing counters['iperf_error'] per failure.
    - Returns 0.0 when ip is empty, port is 0, or every attempt fails.
    '''
    result = 0.0
    if not ip or port == 0:
        return result

    # Argument list instead of a shell string: no shell parsing of config values.
    cmd = ["iperf3", "-c", str(ip), "-p", str(port), "-t", str(time_duration)]
    if is_downlink:
        cmd.append("-R")
    cmd.append("--json")

    err = None
    for attempt in range(max_retries):
        try:
            iperfResult = json.loads(subprocess.check_output(cmd).decode("UTF-8"))
            end = iperfResult['end']
            if is_downlink:
                return end['sum_received']['bits_per_second'] / 1000000.0
            return end['sum_sent']['bits_per_second'] / 1000000.0
        except Exception as e:
            err = e
            counters['iperf_error'] += 1
            # No point waiting after the final attempt.
            if attempt < max_retries - 1:
                time.sleep(retry_delay)
    logging.error("After " + str(max_retries) + " retries, iperf test failed for " + str(ip) + ": %s", err)
    return result


def iperf_test(counters):
    '''
    Prepares the iperf test.

    Runs a 10 s downlink and a 10 s uplink test against
    CONF.ips.iperf_server:CONF.iperf_port. If CONF.iperf_schedule exists and
    is non-empty, tests run at most once during each listed hour (local
    time, 0-23); otherwise they run on every call.

    Returns {'cluster': {'downlink': Mbps, 'uplink': Mbps}}, or None when
    skipped by the schedule.
    '''
    global hour_iperf_scheduled_time_last_ran
    # Read the hour once so the schedule check and the "last ran" marker agree
    # even if the clock rolls over to the next hour mid-function.
    current_hour = int(time.strftime("%H"))

    if "iperf_schedule" in CONF._fields and len(CONF.iperf_schedule) > 0:
        if current_hour not in CONF.iperf_schedule:  # not in the schedule
            hour_iperf_scheduled_time_last_ran = -1
            return None
        if current_hour == hour_iperf_scheduled_time_last_ran:  # already ran this hour
            return None
    hour_iperf_scheduled_time_last_ran = current_hour

    speedtest_iperf = {'cluster': {}}
    speedtest_iperf['cluster']['downlink'] = run_iperf_test(CONF.ips.iperf_server, CONF.iperf_port, 10, True, counters)
    speedtest_iperf['cluster']['uplink'] = run_iperf_test(CONF.ips.iperf_server, CONF.iperf_port, 10, False, counters)

    return speedtest_iperf


def get_signal_quality(modem, counters, dongle_stats=None):
    '''
    Return {'rsrq': int, 'rsrp': int} for the serving cell; zeros when unavailable.

    Without a modem, the values come from dongle_stats['RSRQ'] and
    dongle_stats['RSRP'] (presumably dB / dBm strings; TODO confirm). They
    are mapped as rsrq = int((RSRQ + 19.5) * 2) and rsrp = int(RSRP + 140).
    Missing, empty or unparsable values increment
    counters['dongle_rsrp_rsrq_error'].

    With a modem, they come from fields 5 and 6 of the AT+CESQ reply, with
    255 reported as 0. Failures increment counters['modem_cesq_error'].
    '''
    if not modem:
        stats = dongle_stats or {}
        raw_rsrq = stats.get('RSRQ', '')
        raw_rsrp = stats.get('RSRP', '')
        if raw_rsrq != '' and raw_rsrp != '':
            try:
                rsrq = int((float(raw_rsrq) + 19.5) * 2)
                rsrp = int(float(raw_rsrp) + 140)
                return {'rsrq': rsrq, 'rsrp': rsrp}
            except (TypeError, ValueError, OverflowError):
                logging.error("Unparsable dongle RSRQ/RSRP: %r / %r", raw_rsrq, raw_rsrp)
        counters['dongle_rsrp_rsrq_error'] += 1
        return {'rsrq': 0, 'rsrp': 0}

    # Fall back to modem cmds

    success, result = modem.write('AT+CESQ')
    logging.debug("get_signal_quality success %i result %s", success, result)
    if not success or 'CESQ: ' not in result:
        logging.error("Failed to get signal quality")
        counters['modem_cesq_error'] += 1
        return {'rsrq': 0, 'rsrp': 0}

    fields = result.split('CESQ:')[1].split(',')
    try:
        rsrq = int(fields[4].strip())
        # The last field may be followed by " OK"; keep only the number.
        rsrp = int(fields[5].strip().split(' ')[0])
    except (IndexError, ValueError):
        logging.error("Failed to parse signal quality: %s", result)
        counters['modem_cesq_error'] += 1
        return {'rsrq': 0, 'rsrp': 0}

    return {
        'rsrq': 0 if rsrq == 255 else rsrq,
        'rsrp': 0 if rsrp == 255 else rsrp
    }


def get_dongle_stats(counters,
                     url="http://192.168.0.1:8080/cgi-bin/ltestatus.cgi?Command=Status",
                     credentials="admin:admin",
                     timeout=10):
    '''
    Fetch the LTE dongle's status XML with curl and extract selected fields.

    Returns a dict with:
      SuccessfulFetch - True once the XML was fetched and parsed
      inBandReporting - CONF.report_in_band if configured, else False
      <key>           - text of each tag in XMLkeys ("" when missing or empty)

    Increments counters['dongle_read_error'] on a fetch/parse failure, and
    once per missing tag. `timeout` bounds the whole curl transfer in seconds.
    '''
    result = {'SuccessfulFetch': False}
    if "report_in_band" in CONF._fields:
        result['inBandReporting'] = CONF.report_in_band
    else:
        result['inBandReporting'] = False
    XMLkeys = ["MAC",
               "PLMNStatus",
               "UICCStatus",
               "IMEI",
               "IMSI",
               "PLMNSelected",
               "MCC",
               "MNC",
               "PhyCellID",
               "CellGlobalID",
               "Band",
               "EARFCN",
               "BandWidth",
               "RSRP",
               "RSRQ",
               "ServCellState",
               "Connection",
               "IPv4Addr"]
    try:
        # -m bounds the transfer so an unresponsive dongle cannot stall the
        # monitoring loop; an argument list avoids shell quoting of the URL.
        raw = subprocess.check_output(
            ["curl", "-u", credentials, "-s", "-m", str(timeout), url])
        # NOTE: xml.etree is not hardened against malicious XML; the dongle's
        # local status page is assumed trusted.
        dongleStatsXML = ET.fromstring(raw.decode("UTF-8"))
    except Exception as e:
        logging.error("Failed to fetch dongle stats from URL: " + str(e))
        counters['dongle_read_error'] += 1
        return result

    for key in XMLkeys:
        element = dongleStatsXML.find(key)
        if element is None:
            logging.error("Failed to find " + key + " in XML")
            counters['dongle_read_error'] += 1
            result[key] = ""
        else:
            # An empty element (e.g. <RSRQ/>) has text None. Normalise it to ""
            # so callers' `!= ''` checks treat it as missing.
            result[key] = element.text if element.text is not None else ""
    result["SuccessfulFetch"] = True
    return result


def report_status(counters, signal_quality, dongle_stats, cp_state=None, up_state=None, speedtest_ping=None, speedtest_iperf=None):
    '''
    Assemble the status report, log it, bump the global `cycles` count,
    and POST it as JSON to CONF.report_url.

    Interface choice:
      - CONF.modem.iface when in-band reporting is enabled;
      - otherwise CONF.report_iface if set;
      - otherwise the default route.
    If the in-band send fails and CONF.report_iface is set, the send is
    retried once over that interface. Each failed first attempt increments
    counters['report_send_error']. Send errors are logged rather than raised.
    '''
    report = {
        'name': CONF.edge_name,
        'status': {
            'control_plane': "disconnected",
            'user_plane': "disconnected"
        },
        'dongle_stats': {
            'SuccessfulFetch': False
        },
        'speedtest': {
            'ping': {
                'dns': {
                    'transmitted': 0,
                    'received': 0,
                    'median': 0.0,
                    'min': 0.0,
                    'avg': 0.0,
                    'max': 0.0,
                    'stddev': 0.0
                }
            },
            'iperf': {
                'cluster': {
                    'downlink': 0.0,
                    'uplink': 0.0
                }
            }
        },
        'signal_quality': {
            'rsrq': 0,
            'rsrp': 0
        }
    }

    if cp_state is not None:
        report['status']['control_plane'] = cp_state.name
    if up_state is not None:
        report['status']['user_plane'] = up_state.name
    if speedtest_ping is not None:
        report['speedtest']['ping'] = speedtest_ping
    if speedtest_iperf is not None:
        report['speedtest']['iperf'] = speedtest_iperf
    report['signal_quality'] = signal_quality
    report['dongle_stats'] = dongle_stats
    report['counters'] = counters

    logging.info("Sending report %s", report)
    global cycles
    cycles += 1
    logging.info("Number of cycles since modem restart %i", cycles)

    # Initialised before the try so the finally clause can always see it.
    c = None
    try:
        interface = None
        report_via_modem = "report_in_band" in CONF._fields and CONF.report_in_band and \
            "iface" in CONF.modem._fields and CONF.modem.iface
        report_via_given_iface = "report_iface" in CONF._fields and CONF.report_iface

        c = pycurl.Curl()
        c.setopt(pycurl.URL, CONF.report_url)
        c.setopt(pycurl.POST, True)
        c.setopt(pycurl.HTTPHEADER, ['Content-Type: application/json'])
        c.setopt(pycurl.TIMEOUT, 10)
        c.setopt(pycurl.POSTFIELDS, json.dumps(report))
        c.setopt(pycurl.WRITEFUNCTION, lambda x: None)  # don't output to console

        if report_via_modem:  # report in-band
            interface = CONF.modem.iface
            c.setopt(pycurl.INTERFACE, interface)
        elif report_via_given_iface:  # report over given interface
            interface = CONF.report_iface
            c.setopt(pycurl.INTERFACE, interface)
        else:
            interface = "default"

        try:
            c.perform()
            logging.info("Report sent via " + str(interface) + "!")
        except Exception as e:
            logging.error("Failed to send report via " + str(interface) + ": " + str(e))
            counters['report_send_error'] += 1
            if report_via_modem and report_via_given_iface:
                logging.warning("Attempting to send report via " + str(CONF.report_iface) + ".")
                interface = CONF.report_iface
                c.setopt(pycurl.INTERFACE, interface)
                c.perform()
                logging.info("Report sent via " + str(interface) + "!")
            else:
                logging.error("No fallback interface available; report not sent.")
    except Exception as e:
        logging.error("Failed to send report: " + str(e))
    finally:
        # Release the curl handle on every path, including a failed fallback.
        if c is not None:
            c.close()

def reset_usb():
    '''
    Power-cycle USB hub location 2 with uhubctl, then reset the global `cycles`.

    If uhubctl is not on PATH, or running it raises, falls back to reboot(120).
    '''
    import shutil

    global cycles
    try:
        # Run the binary that was actually found instead of assuming
        # /usr/sbin/uhubctl.
        uhubctl = shutil.which("uhubctl")
        if uhubctl is None:
            reboot(120)
            return
        # -a 0 = action is shutdown; -l 2 = location, bus 2 on the Pi controls power to all hubs
        ret = subprocess.call([uhubctl, "-a", "0", "-l", "2"])
        logging.info("Shutting down usb hub 2 results %s", ret)
        time.sleep(10)  # let power down process settle out
        # -a 1 = action is start; -l 2 = same location as above
        ret = subprocess.call([uhubctl, "-a", "1", "-l", "2"])
        logging.info("Starting up usb hub 2 results %s", ret)
        time.sleep(10)  # allow dbus to finish
        cycles = 0
    except Exception as e:
        logging.error("Failed to run uhubctl: %s", e)
        reboot(120)

def reboot(delay):
    '''Log the pending reboot, wait `delay` seconds, then reboot the host via sudo.'''
    message = "Failed to run uhubctl. Reboot system in " + str(delay) + " second(s)."
    logging.error(message)
    time.sleep(delay)
    reboot_cmd = "sudo shutdown -r now"
    subprocess.check_output(reboot_cmd, shell=True)

def init_counters():
    '''Return a fresh dict of per-cycle error counters, all set to zero.'''
    counter_names = (
        'dongle_read_error',
        'dongle_connect_error',
        'dongle_rsrp_rsrq_error',
        'modem_cfun0_error',
        'modem_cfun1_error',
        'modem_cgatt_error',
        'modem_cesq_error',
        'dry_run_ping_error',
        'ping_error',
        'iperf_error',
        'report_send_error',
    )
    return dict.fromkeys(counter_names, 0)

def clear_counters(counters):
    '''
    Reset every counter in `counters` to 0, in place.

    Returns the same dict, so callers may either ignore the return value or
    write `counters = clear_counters(counters)` and still hold a usable dict.
    '''
    for name in counters:
        counters[name] = 0
    return counters

def _reboot_pi(reason):
    '''Log `reason`, request an immediate host reboot, and exit the script.'''
    logging.warning("Rebooting Pi: " + reason)
    # Linux shutdown syntax (the same command reboot() uses); the Windows-style
    # `shutdown /r /t 0` does not reboot a Pi.
    subprocess.call("sudo shutdown -r now", shell=True)
    sys.exit(1)


def main():
    '''
    Entry point: set up routing, optionally open the modem, then run the
    monitoring cycle forever.

    Each cycle does the following:
      - Fetches dongle stats; after more than 10 failed retries, reboots.
      - Checks the control plane; after more than 10 consecutive failures,
        reboots.
      - Reads signal quality.
      - Runs the ping and iperf tests.
      - Sends the report, resets the counters, and sleeps
        CONF.report_interval seconds.
    '''
    global cycles
    global hour_iperf_scheduled_time_last_ran
    cycles = 0
    hour_iperf_scheduled_time_last_ran = -1

    counters = init_counters()

    try:
        if "report_in_band" in CONF._fields and \
        "iface" in CONF.modem._fields and CONF.modem.iface:
            if CONF.report_in_band:  # need to add default gateway if reporting in-band
                subprocess.check_output("sudo route add default gw " + CONF.modem.ip_addr + " " + CONF.modem.iface + " || true", shell=True)
            else:
                subprocess.check_output("sudo route del default gw " + CONF.modem.ip_addr + " " + CONF.modem.iface + " || true", shell=True)
    except Exception as e:
        logging.error("Failed to change default route for modem: " + str(e))

    # Host routes so traffic to every configured test target goes via
    # CONF.modem.ip_addr.
    for ip in CONF.ips:
        if not ip:
            continue
        try:
            subprocess.check_output("sudo ip route replace {}/32 via {}".format(
                ip, CONF.modem.ip_addr), shell=True)
        except subprocess.CalledProcessError as e:
            logging.error("Failed to add routes: " + str(e.returncode) + str(e.output))
            time.sleep(10)  # Sleep for 10 seconds before retry
            sys.exit(1)

    if USE_MODEM_CMDS:
        modem = Modem(CONF.modem.port, CONF.modem.baud)
        try:
            modem.connect()
        except serial.serialutil.SerialException as e:
            logging.error("Failed to connect the modem for %s", e)
            sys.exit(1)
    else:
        modem = None

    connect_retries = 0
    try:
        while True:
            dongle_retries = 0
            dongle_stats = get_dongle_stats(counters)
            while not dongle_stats['SuccessfulFetch']:
                dongle_retries += 1
                if dongle_retries > 10:
                    _reboot_pi("dongle not readable")
                # Pause so retries span a short dongle outage instead of
                # firing back-to-back.
                time.sleep(5)
                dongle_stats = get_dongle_stats(counters)

            cp_state = get_control_plane_state(modem, counters, dongle_stats)

            if cp_state != State.connected:
                logging.error("Dongle not connected")
                connect_retries += 1
                if connect_retries > 10:
                    _reboot_pi("dongle not connected")
            else:
                # Only consecutive failures should trigger a reboot.
                connect_retries = 0

            signal_quality = get_signal_quality(modem, counters, dongle_stats)

            dry_run_ping_latency, dry_run_ping_result = dry_run_ping_test(counters)
            if not dry_run_ping_result:
                logging.error("Dry run ping failed")
                counters['dry_run_ping_error'] += 1

            ping_latency, ping_result = ping_test(modem, dry_run_ping_latency)
            if not ping_result:
                logging.error("Ping test failed")
                counters['ping_error'] += 1

            # If either of the ping tests pass, then declare user plane connected
            if dry_run_ping_result or ping_result:
                up_state = State.connected
            else:
                up_state = State.disconnected

            speedtest_iperf = iperf_test(counters)

            send_errors_before = counters['report_send_error']
            report_status(counters, signal_quality, dongle_stats, cp_state, up_state, ping_latency, speedtest_iperf)
            # report_status() counts a failed send only after the payload was
            # serialized; carry that count into the next report, not the void.
            unreported_send_errors = counters['report_send_error'] - send_errors_before
            # Reset in place; the dict object itself is kept.
            clear_counters(counters)
            counters['report_send_error'] = unreported_send_errors
            time.sleep(CONF.report_interval)
    finally:
        # Runs on sys.exit() and on unexpected exceptions alike.
        if modem is not None:
            modem.close()


# Start the monitoring loop only when run as a script, not when imported.
if __name__ == "__main__":
    main()
