#!/usr/bin/env python3

# SPDX-FileCopyrightText: © 2021 Open Networking Foundation <support@opennetworking.org>
# SPDX-License-Identifier: Apache-2.0

# device.py

import sys
import netaddr

from .utils import logger
from .container import DeviceContainer, VirtualMachineContainer, PrefixContainer


class AssignedObject:
    """
    Assigned Object is either a Device or Virtual Machine, which function
    nearly identically in the NetBox data model.

    This parent class holds common functions for those two child classes.

    An AssignedObject (device or VM) has the following attributes:
    - self.data: the original record fetched from NetBox
    - self.id: Device ID or VM ID
    - self.interfaces: a dictionary of the interfaces belonging to this object,
                       keyed by interface name; it looks like:

    {
        "eno1": {
            "addresses": ["192.168.0.1/24", "192.168.0.2/24"],
            "mac_address": <mac address from NetBox>,
            "instance": <interface_instance>,
            "isPrimary": True,
            "mgmtOnly": False,
            "isVirtual": False
        }
    }
    """

    objects = dict()

    def __init__(self, data):
        from .utils import netboxapi

        self.data = data
        self.nbapi = netboxapi

        # The AssignedObject attributes
        self.id = self.data.id
        self.tenant = None
        self.primary_ip = None
        self.primary_iface = None

        # In NetBox, we use the FQDN as the Device name, but in the script
        # we use the first segment as the name of the device.
        # For example, if the device is named "mgmtserver1.stage1.menlo" in
        # NetBox, then its name here is "mgmtserver1".
        self.fullname = self.data.name
        self.name = self.fullname.split(".")[0]

        # The device role slug, e.g. "server", "router", "switch", ...
        self.role = None

        # The NetBox objects related with this AssignedObject
        self.interfaces = dict()
        self.services = None

        # Generated configuration for ansible playbooks
        self.netplan_config = dict()
        self.extra_config = dict()

        # Dispatch on the ``type`` property instead of the exact class so
        # subclasses of Device / VirtualMachine are handled as well.
        is_vm = self.type == "VirtualMachine"
        if self.type == "Device":
            self.role = self.data.device_role.slug
            services = self.nbapi.ipam.services.filter(device_id=self.id)
            interfaces = self.nbapi.dcim.interfaces.filter(device_id=self.id)
            ip_addresses = self.nbapi.ipam.ip_addresses.filter(device_id=self.id)
        elif is_vm:
            self.role = self.data.role.slug
            services = self.nbapi.ipam.services.filter(virtual_machine_id=self.id)
            interfaces = self.nbapi.virtualization.interfaces.filter(
                virtual_machine_id=self.id
            )
            ip_addresses = self.nbapi.ipam.ip_addresses.filter(
                virtual_machine_id=self.id
            )
        else:
            raise TypeError(
                "AssignedObject must be created through Device or VirtualMachine"
            )

        # Materialize the query results: services are iterated by several
        # generate_* methods, and ip_addresses is checked for emptiness below.
        self.services = list(services)
        ip_addresses = list(ip_addresses)

        self.primary_ip = self.data.primary_ip

        for interface in interfaces:
            # The Device's interface structure is different from the VM's:
            # VM interfaces have neither mgmt_only nor type, so for them
            # mgmtOnly defaults to False and isVirtual defaults to True.
            self.interfaces[interface.name] = {
                "addresses": list(),
                "mac_address": interface.mac_address,
                "instance": interface,
                "isPrimary": False,
                "mgmtOnly": getattr(interface, "mgmt_only", False),
                "isVirtual": (
                    interface.type.value == "virtual"
                    if hasattr(interface, "type")
                    else True
                ),
            }

            # Hack: give every VM interface record the mgmt_only attribute
            # that Device interface records carry.
            if is_vm:
                interface.mgmt_only = False

        # ipam.ip_addresses has no "primary" flag; only the Device/VM record
        # knows its primary IP, so compare addresses to find the primary one.
        primary_address = getattr(self.primary_ip, "address", None)
        if ip_addresses and primary_address is None:
            logger.error("Error with primary address for device %s", self.fullname)

        for address in ip_addresses:
            iface_name = getattr(address.assigned_object, "name", None)
            interface = self.interfaces.get(iface_name)
            if interface is None:
                logger.warning(
                    "IP address %s is assigned to unknown interface %s of device %s",
                    address.address,
                    iface_name,
                    self.fullname,
                )
                continue

            interface["addresses"].append(address.address)

            if primary_address is not None and address.address == primary_address:
                interface["isPrimary"] = True
                self.primary_iface = interface

    def __repr__(self):
        return str(dict(self.data))

    @property
    def type(self):
        """Kind of NetBox object wrapped; subclasses override this."""
        return "AssignedObject"

    @property
    def internal_interfaces(self):
        """
        Return the internal interfaces of this object, keyed by name.

        An interface is internal when it is not the primary interface, is not
        management-only, and has at least one IP address assigned.
        """
        return {
            intf_name: interface
            for intf_name, interface in self.interfaces.items()
            if not interface["isPrimary"]
            and not interface["mgmtOnly"]
            and interface["addresses"]
        }

    def _role_name(self):
        """Return the display name of the NetBox role (e.g. "Router"), or None."""
        # Device records expose the role as device_role, VM records as role.
        if self.type == "Device":
            return self.data.device_role.name
        if self.type == "VirtualMachine":
            return self.data.role.name
        return None

    def _primary_instance(self):
        """Return the NetBox record of the (last) primary interface, or None."""
        primary = None
        for interface in self.interfaces.values():
            if interface["isPrimary"] is True:
                primary = interface["instance"]
        return primary

    def generate_netplan(self):
        """
        Generate the netplan configuration of this device.

        Returns the cached self.netplan_config when it was already generated.
        Returns an empty dict (after logging an error) when no primary
        interface is set. Returns None, leaving self.netplan_config
        untouched, when the role is neither "Router" nor a Device "Server".
        """
        if self.netplan_config:
            return self.netplan_config

        primary_if = self._primary_instance()
        if primary_if is None:
            logger.error("The primary interface wasn't set for device %s", self.name)
            return dict()

        role_name = self._role_name()
        is_router = role_name == "Router"
        is_server = self.type == "Device" and role_name == "Server"
        if not (is_router or is_server):
            # Exclude roles other than Router and Server. Return before
            # touching the cache so repeated calls also return None.
            return None

        # Initialize the "ethernets" part of the configuration
        ethernets = dict()
        self.netplan_config["ethernets"] = ethernets

        if is_router:
            prefix_sets = [
                netaddr.IPSet([prefix.subnet]) for prefix in PrefixContainer().all()
            ]
            for intf_name, interface in self.interfaces.items():
                if interface["mgmtOnly"] or interface["isVirtual"]:
                    continue

                # An interface whose addresses fall in no known prefix holds a
                # public address (e.g. "8.8.8.8" on eth0); netplan skips it.
                is_internal = any(
                    address in prefix_set
                    for prefix_set in prefix_sets
                    for address in interface["addresses"]
                )
                if not is_internal:
                    continue

                ethernets.setdefault(intf_name, {}).setdefault(
                    "addresses", []
                ).extend(interface["addresses"])
        else:
            # Server: DHCP on the primary interface with the preferred metric
            ethernets[primary_if.name] = {
                "dhcp4": "yes",
                "dhcp4-overrides": {"route-metric": 100},
            }

            for intf_name, interface in self.interfaces.items():
                if (
                    not interface["isVirtual"]
                    and intf_name != primary_if.name
                    and not interface["mgmtOnly"]
                    and interface["addresses"]
                ):
                    ethernets[intf_name] = {
                        "dhcp4": "yes",
                        "dhcp4-overrides": {"route-metric": 200},
                    }

        # Virtual interfaces with a tagged VLAN become netplan VLAN entries
        for intf_name, interface in self.interfaces.items():
            if not interface["isVirtual"] or not interface["instance"].tagged_vlans:
                continue

            instance = interface["instance"]
            vlans = self.netplan_config.setdefault("vlans", dict())
            vlan_object = self.nbapi.ipam.vlans.get(instance.tagged_vlans[0].id)

            routes = list()
            for address in interface["addresses"]:
                interface_ip = netaddr.IPNetwork(address).ip

                for reserved_ip in PrefixContainer().all_reserved_ips(address):
                    destination = reserved_ip["custom_fields"].get("rfc3442routes", "")
                    if not destination:
                        continue

                    for dest_addr in destination.split(","):
                        # The interface already sits inside this destination
                        # subnet, so no route is needed.
                        if interface_ip in netaddr.IPNetwork(dest_addr):
                            continue

                        new_route = {
                            "to": dest_addr,
                            "via": str(netaddr.IPNetwork(reserved_ip["ip4"]).ip),
                            "metric": 100,
                        }
                        if new_route not in routes:
                            routes.append(new_route)

            vlans[intf_name] = {
                "id": vlan_object.vid,
                "link": instance.label,
                # Copy so the output never aliases self.interfaces data
                "addresses": list(interface["addresses"]),
            }

            # Only fabric virtual interfaces need routes to other network segments
            if routes and "fab" in intf_name:
                vlans[intf_name]["routes"] = routes

        return self.netplan_config

    def generate_nftables(self):
        """
        Generate the nftables configuration of this device.

        The primary interface becomes the external interface. The first other
        interface that is neither virtual nor management-only becomes the
        internal interface. When the role is "Router", the per-interface
        subnets and the UE routing parameters are added as well.

        Logs an error and exits the process with status 1 when the external
        or the internal interface cannot be determined.
        """
        ret = dict()

        # Use isPrimary == True as the identifier to select external interface
        external_if = self._primary_instance()
        if external_if is None:
            logger.error("The primary interface wasn't set for device %s", self.name)
            sys.exit(1)

        internal_if = next(
            (
                interface["instance"]
                for interface in self.interfaces.values()
                if not interface["isVirtual"]
                and not interface["mgmtOnly"]
                and interface["instance"] is not external_if
            ),
            None,
        )
        if internal_if is None:
            logger.error("The internal interface wasn't found for device %s", self.name)
            sys.exit(1)

        ret["external_if"] = external_if.name
        ret["internal_if"] = internal_if.name

        if self.services:
            ret["services"] = [
                {
                    "name": service.name,
                    "protocol": service.protocol.value,
                    "port": service.port,
                }
                for service in self.services
            ]

        # Only the management server (Router) is configured with the whitelist
        # netranges of its internal interfaces and the UE routing parameters.
        if self._role_name() == "Router":
            prefixes = list(PrefixContainer().all())

            ret["interface_subnets"] = dict()
            # NOTE: pop() removes "ue_subnets" from the config context, so a
            # generate_extra_config() call made afterwards will not include it.
            # It raises KeyError when the key is absent, e.g. on a second call.
            ret["ue_routing"] = {
                "ue_subnets": self.data.config_context.pop("ue_subnets")
            }

            # Map each interface to the subnets (plus neighbor subnets) of the
            # described prefixes that contain its addresses.
            for intf_name, interface in self.interfaces.items():
                if interface["mgmtOnly"]:
                    continue

                for address in interface["addresses"]:
                    intf_addr = netaddr.IPNetwork(address).ip

                    for prefix in prefixes:
                        if intf_addr not in netaddr.IPNetwork(prefix.subnet):
                            continue

                        # Parent prefixes carry no description (no domain
                        # name is configured on them), so they are skipped.
                        if not prefix.data.description:
                            continue

                        subnets = ret["interface_subnets"].setdefault(
                            intf_name, list()
                        )
                        candidates = [prefix.subnet]
                        candidates.extend(neighbor.subnet for neighbor in prefix.neighbor)
                        for subnet in candidates:
                            if subnet not in subnets:
                                subnets.append(subnet)

            ue_routing = ret["ue_routing"]
            for prefix in prefixes:
                # A missing (None) description is treated as empty
                if "fab" not in (prefix.data.description or ""):
                    continue

                ue_routing.setdefault("src_subnets", []).append(prefix.data.prefix)

                if ue_routing.get("snat_addr"):
                    continue

                # The first interface address (over all interfaces) inside this
                # fabric prefix is used as the SNAT address.
                fabric_set = netaddr.IPSet([prefix.subnet])
                snat_addr = next(
                    (
                        str(netaddr.IPNetwork(address).ip)
                        for interface in self.interfaces.values()
                        for address in interface["addresses"]
                        if address in fabric_set
                    ),
                    None,
                )
                if snat_addr is not None:
                    ue_routing["snat_addr"] = snat_addr

        return ret

    def generate_extra_config(self):
        """
        Generate the extra configs needed by the management server configuration.
        This function should only be called when the device role is "Router".

        Extra config includes service configuring parameters (for the "dns"
        and "ntp" services) and the device's config context. The config
        context is merged last, so its keys override generated ones; a warning
        is logged for each overridden key. The result is cached in
        self.extra_config.
        """
        if self.extra_config:
            return self.extra_config

        service_names = [service.name for service in self.services]

        # Prefixes with a description are the client-facing ones; parent
        # prefixes carry no description.
        allowed_prefixes = []
        if "dns" in service_names or "ntp" in service_names:
            allowed_prefixes = [
                prefix.data.prefix
                for prefix in PrefixContainer().all()
                if prefix.data.description
            ]

        if "dns" in service_names:
            # Listen on every address of non-primary, non-management interfaces
            unbound_listen_ips = [
                address
                for interface in self.interfaces.values()
                if not interface["isPrimary"] and not interface["mgmtOnly"]
                for address in interface["addresses"]
            ]

            if unbound_listen_ips:
                self.extra_config["unbound_listen_ips"] = unbound_listen_ips

            if allowed_prefixes:
                # Separate copy: dns and ntp entries must not share one list
                self.extra_config["unbound_allow_ips"] = list(allowed_prefixes)

        if "ntp" in service_names and allowed_prefixes:
            self.extra_config["ntp_client_allow"] = list(allowed_prefixes)

        # If the key exists in generated config, warn with the key name
        for key in self.data.config_context.keys():
            if key in self.extra_config:
                logger.warning("Extra config Key %s was overwritten", key)

        self.extra_config.update(self.data.config_context)

        return self.extra_config


class Device(AssignedObject):
    """
    Wraps a single NetBox device.

    Every instance registers itself in the DeviceContainer, keyed by its ID.
    """

    def __init__(self, data):

        super().__init__(data)
        DeviceContainer().add(self.id, self)

    @property
    def type(self):
        return "Device"

    def get_interfaces(self):
        """
        Return this device's interfaces.

        Normally this is the name-keyed ``self.interfaces`` dict built in
        ``__init__``. When that dict is empty, the interfaces are queried from
        NetBox instead. The query result is returned but NOT stored: replacing
        ``self.interfaces`` with it would break the ``.items()``/``.values()``
        lookups done by the generate_* methods.
        """
        if not self.interfaces:
            return self.nbapi.dcim.interfaces.filter(device_id=self.id)

        return self.interfaces


class VirtualMachine(AssignedObject):
    """
    VM equivalent of Device.

    Every instance registers itself in the VirtualMachineContainer, keyed by
    its ID.
    """

    def __init__(self, data):

        super().__init__(data)
        VirtualMachineContainer().add(self.id, self)

    @property
    def type(self):
        return "VirtualMachine"

    def get_interfaces(self):
        """
        Return this VM's interfaces.

        Normally this is the name-keyed ``self.interfaces`` dict built in
        ``__init__``. When that dict is empty, the interfaces are queried from
        NetBox instead. The query result is returned but NOT stored: replacing
        ``self.interfaces`` with it would break the ``.items()``/``.values()``
        lookups done by the generate_* methods.
        """
        if not self.interfaces:
            return self.nbapi.virtualization.interfaces.filter(
                virtual_machine_id=self.id
            )

        return self.interfaces
