#!/usr/bin/env python3

# Copyright 2020-present Open Networking Foundation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import sys
import os
import json
import logging
import enum
import requests
import time
import serial
import subprocess
from collections import namedtuple

"""
"Simple" script that checks Aether network operational status periodically
by controlling the attached 4G/LTE modem with AT commands and
report the result to the central monitoring server.
"""


# Load configuration once at import time from the path in the CONFIG_FILE
# environment variable (default: ./config.json). The object_hook turns every
# JSON object into a namedtuple, so nested settings read as attributes
# (e.g. CONF.modem.port). The file is closed as soon as it has been parsed.
with open(os.getenv("CONFIG_FILE", "./config.json")) as _config_file:
    CONF = json.load(
        _config_file,
        object_hook=lambda d: namedtuple("X", d.keys())(*d.values()),
    )

# Only records at ERROR level or above are written to CONF.log_file.
logging.basicConfig(
    filename=CONF.log_file,
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.ERROR,
)

# Report payload sent to the monitoring server; report_status() fills in the
# status and speedtest fields before every POST.
report = {
    "name": CONF.edge_name,
    "status": {"control_plane": None, "user_plane": None},
    "speedtest": {
        "ping": {"dns": {"min": None, "avg": None, "max": None, "stddev": None}}
    },
}


class State(enum.Enum):
    """Connectivity state of a network plane, keyed by a string code.

    ``connected`` ("1") and ``disconnected`` ("0") correspond to the values
    parsed from the modem's CGATT reply; ``error`` ("-1") marks a query that
    could not be completed.
    """

    error = "-1"
    disconnected = "0"
    connected = "1"

    @classmethod
    def has_value(cls, value):
        """Return True if *value* is the code of one of the members."""
        codes = {member.value for member in cls}
        return value in codes


class Modem:
    """Minimal AT-command client for a modem attached to a serial port."""

    log = logging.getLogger("aether_edge_monitoring.Modem")

    # Seconds of serial-line silence that mark the end of a reply.
    read_timeout = 0.1

    def __init__(self, port, baudrate):
        """Remember the port settings; the port itself is opened by connect()."""
        self.port = port
        self.baudrate = baudrate
        self._response = None

    def connect(self):
        """Open the serial port with a 1 second read timeout.

        Raises:
            serial.serialutil.SerialException: if the port cannot be opened.
        """
        self.serial = serial.Serial(port=self.port, baudrate=self.baudrate, timeout=1)

    def _write(self, command):
        """Send *command* followed by CR and collect the reply.

        Any stale input is discarded before sending. Reading continues until
        no new bytes arrive for ``read_timeout`` seconds, so a reply that is
        delivered in several bursts is not cut short after the first one.
        Bytes that are not valid ASCII are replaced instead of raising.

        Returns:
            The decoded reply with every CRLF collapsed to a single space.
        """
        if self.serial.inWaiting() > 0:
            self.serial.flushInput()

        self._response = b""

        self.serial.write(bytearray(command + "\r", "ascii"))
        while True:
            # Wait one quiet interval before deciding the reply is complete.
            time.sleep(self.read_timeout)
            waiting = self.serial.inWaiting()
            if waiting == 0:
                break
            self._response += self.serial.read(waiting)
        # errors="replace": line noise must not crash the monitoring loop.
        return self._response.decode("ascii", errors="replace").replace("\r\n", " ")

    def write(self, command, wait_resp=True):
        """Send an AT command and return ``(success, reply)``.

        Returns:
            ``(False, None)`` when *wait_resp* is true and the reply contains
            "ERROR"; otherwise ``(True, reply)``. The reply is always read,
            whatever the value of *wait_resp*.
        """
        response = self._write(command)
        self.log.debug("%s: %s", command, response)

        if wait_resp and "ERROR" in response:
            return False, None
        return True, response

    def is_connected(self):
        """Query the attach state with ``AT+CGATT?``.

        Returns:
            The State parsed from the first token after "CGATT:", or
            ``State.error`` when the command fails, the reply has no "CGATT:"
            field, or the field does not hold a known State code.
        """
        success, result = self.write("AT+CGATT?")
        if not success or "CGATT:" not in result:
            return State.error
        # Split on any whitespace so both "CGATT:1" and "CGATT: 1" parse.
        tokens = result.split("CGATT:", 1)[1].split()
        code = tokens[0] if tokens else ""
        if not State.has_value(code):
            return State.error
        return State(code)

    def close(self):
        """Close the serial port opened by connect()."""
        self.serial.close()


def _echo_at_command(command):
    """Write an AT command to the modem port with the shell's ``echo``.

    Returns None on success, or the (already logged) error message when the
    shell command exits with a non-zero status.
    """
    shell_command = "echo '" + command + "' > " + CONF.modem.port
    logging.debug(shell_command)
    status = os.system(shell_command)
    logging.debug("result: %s", status)
    if status != 0:
        msg = "Write '" + command + "' failed"
        logging.error(msg)
        return msg
    return None


def _wait_for_state(modem, target, attempts):
    """Poll ``modem.is_connected()`` up to *attempts* times, one second apart.

    Stops as soon as *target* is observed and returns the last state seen
    (None when *attempts* is 0).
    """
    state = None
    for _ in range(attempts):
        state = modem.is_connected()
        if state is target:
            break
        time.sleep(1)
    return state


def get_control_plane_state(modem):
    """Detach and re-attach the modem and report the resulting attach state.

    Sends ``AT+CFUN=0``, waits for the modem to report disconnected, then
    sends ``AT+CFUN=1`` and waits for it to report connected.

    Returns:
        ``(state, msg)``: ``(State.error, message)`` if a command write fails
        or the modem never reports disconnected; otherwise the state observed
        after the reconnect wait (``State.connected`` or
        ``State.disconnected``) and ``None``.
    """
    # Delete the existing session. "echo" works more reliably than the
    # serial connection for this action.
    msg = _echo_at_command("AT+CFUN=0")
    if msg is not None:
        return State.error, msg

    # Consider the modem unresponsive if it does not detach within 5 polls.
    if _wait_for_state(modem, State.disconnected, 5) is not State.disconnected:
        msg = "Failed to disconnect."
        logging.error(msg)
        return State.error, msg

    time.sleep(2)

    # Create a new session, again via "echo".
    msg = _echo_at_command("AT+CFUN=1")
    if msg is not None:
        return State.error, msg

    # Allow up to 30 one-second polls for the modem to be fully connected.
    state = _wait_for_state(modem, State.connected, 30)
    # CGATT sometimes returns no usable value; report that as disconnected
    # rather than error so the caller does not treat it as a modem failure.
    if state is State.error:
        state = State.disconnected

    time.sleep(2)
    return state, None


def get_user_plane_state(modem):
    """Check user-plane reachability by pinging the configured test address.

    Sends 3 pings to ``CONF.ips.user_plane_ping_test`` with all output
    discarded.

    Args:
        modem: unused here; accepted for signature parity with the other
            probes.

    Returns:
        ``(State.connected, None)`` if the ping command exits with status 0,
        otherwise ``(State.disconnected, None)``.
    """
    status = os.system(
        "ping -c 3 " + CONF.ips.user_plane_ping_test + " >/dev/null 2>&1"
    )
    # Compare by value: "is 0" relies on CPython small-int caching.
    state = State.connected if status == 0 else State.disconnected
    return state, None


def run_ping_test(ip, count):
    """Ping *ip* *count* times and return the round-trip summary.

    Args:
        ip: host name or address to ping.
        count: number of echo requests (passed to ``ping -c``).

    Returns:
        dict with float "min", "avg", "max" and "stddev" values taken from
        the last line of ping's output. All four are 0.0 if ping cannot be
        run or its output cannot be parsed; the failure is logged.
    """
    result = {"min": 0.0, "avg": 0.0, "max": 0.0, "stddev": 0.0}
    try:
        # Argument list, no shell: *ip* cannot inject extra commands.
        # The exit status is deliberately not checked, so output is parsed
        # whether or not ping reports failure.
        completed = subprocess.run(
            ["ping", "-c", str(count), ip], stdout=subprocess.PIPE
        )
        lines = completed.stdout.decode("UTF-8").strip().splitlines()
        # Summary line, e.g. "rtt min/avg/max/mdev = 1.0/2.0/3.0/0.5 ms":
        # its 4th whitespace-separated field holds the four numbers.
        ping_result = lines[-1].split()[3].split("/")
        result = {
            "min": float(ping_result[0]),
            "avg": float(ping_result[1]),
            "max": float(ping_result[2]),
            "stddev": float(ping_result[3]),
        }
    except Exception as e:
        logging.error("Ping test failed for %s: %s", ip, e)
    return result


def get_ping_test(modem):
    """Collect ping statistics for the speedtest section of the report.

    Pings the address configured as ``CONF.ips.speedtest_ping_dns`` 10 times
    and stores its min/avg/max/stddev dict under the "dns" key. *modem* is
    not used.

    TODO: also ping a device on the network.
    """
    dns_stats = run_ping_test(CONF.ips.speedtest_ping_dns, 10)
    return {"dns": dns_stats}


def report_status(cp_state, up_state, speedtest_ping, timeout=30):
    """Fill in the shared ``report`` payload and POST it to ``CONF.report_url``.

    Reporting is best-effort. Any request failure (connection error, timeout
    or HTTP error status) is logged and swallowed, so the caller's loop keeps
    running.

    Args:
        cp_state: control-plane State; its ``name`` is reported.
        up_state: user-plane State; its ``name`` is reported.
        speedtest_ping: dict stored as ``report["speedtest"]["ping"]``.
        timeout: seconds to wait for the server before giving up, so an
            unresponsive server cannot stall the caller indefinitely.
    """
    report["status"]["control_plane"] = cp_state.name
    report["status"]["user_plane"] = up_state.name
    report["speedtest"]["ping"] = speedtest_ping

    logging.info("Sending report %s", report)
    try:
        result = requests.post(CONF.report_url, json=report, timeout=timeout)
        result.raise_for_status()
    except requests.exceptions.RequestException as e:
        # Base class of ConnectionError, Timeout and HTTPError.
        logging.error("Failed to report for %s", e)


def main():
    """Connect to the modem, install test routes, then probe and report forever.

    Exits with status 1 if the modem cannot be opened, a test route cannot be
    installed, or the control-plane probe returns ``State.error``. Once the
    modem is connected, the serial port is closed on every exit path.
    """
    modem = Modem(CONF.modem.port, CONF.modem.baud)
    try:
        modem.connect()
    except serial.serialutil.SerialException as e:
        logging.error("Failed to connect the modem for %s", e)
        sys.exit(1)

    try:
        # Route each configured test IP via the modem's address.
        for ip in CONF.ips:
            status = os.system(
                "sudo ip route replace {}/32 via {}".format(ip, CONF.modem.ip_addr)
            )
            if status != 0:
                logging.error("Failed to add test routing to %s", ip)
                sys.exit(1)

        while True:
            # get_control_plane_state already logs its own failure message.
            cp_state, _ = get_control_plane_state(modem)
            if cp_state is State.error:
                logging.error("Modem is in error state.")
                sys.exit(1)

            up_state, _ = get_user_plane_state(modem)
            speedtest_ping = get_ping_test(modem)

            report_status(cp_state, up_state, speedtest_ping)
            time.sleep(CONF.report_interval)
    finally:
        # Also runs on sys.exit() (SystemExit) and Ctrl-C.
        modem.close()


if __name__ == "__main__":
    main()
