#!/usr/bin/env python3

# SPDX-FileCopyrightText: © 2020 Open Networking Foundation <support@opennetworking.org>
# SPDX-License-Identifier: Apache-2.0

# TODO:
#  Fix issues where IPMI is given the primary IP for a node

from __future__ import absolute_import

import argparse
import json
import logging
import netaddr
import re
import ssl
import urllib.parse
import urllib.request
from ruamel import yaml

# shared logger for the whole script; basicConfig() attaches a default
# stderr handler to the root logger
logging.basicConfig()
logger = logging.getLogger("nbht")

# jsonpath expression -> compiled jsonpath parser; kept global because
# reparsing expressions in each loop results in 100x longer execution time
jpathexpr = dict()

# (name, value) header tuples sent with every API request; filled in by the
# __main__ block once settings are loaded
headers = list()

# parsed settings file; replaced by the __main__ block
settings = dict()

# cached data from the API
devices = dict()
interfaces = dict()


def parse_nb_args():
    """
    Parse the command line.

    :returns: argparse.Namespace with ``settings`` (an open, readable file
        object for the YAML settings file) and ``debug`` (bool)
    """

    parser = argparse.ArgumentParser(description="NetBox Host Descriptions")

    # (positional names/flags, keyword options) for each argument
    argument_specs = (
        (
            ("settings",),
            {
                "type": argparse.FileType("r"),
                "help": "YAML ansible inventory file w/netbox info",
            },
        ),
        (
            ("--debug",),
            {
                "action": "store_true",
                "help": "Print additional debugging information",
            },
        ),
    )

    for flags, options in argument_specs:
        parser.add_argument(*flags, **options)

    return parser.parse_args()


def json_api_get(
    url,
    headers,
    data=None,
    trim_prefix=False,
    allow_failure=False,
    validate_certs=False,
):
    """
    Call a JSON API endpoint and return the decoded response.

    If ``data`` is truthy it is JSON-encoded and sent as the body of a POST
    request; otherwise a GET request is made.

    :param url: full URL to request
    :param headers: iterable of ``(name, value)`` tuples added to the request
    :param data: optional object to JSON-encode and POST
    :param trim_prefix: accepted for compatibility; not used by this function
    :param allow_failure: if True, return None on an HTTP error status or on
        a response body that is not valid JSON instead of raising
    :param validate_certs: if False, TLS certificate and hostname
        verification are disabled for this request
    :returns: the decoded JSON value, or None (see ``allow_failure``)
    :raises urllib.error.HTTPError: on an HTTP error status, unless
        ``allow_failure``
    :raises urllib.error.URLError: if the request could not be made at all
    :raises json.JSONDecodeError: if the body is not valid JSON, unless
        ``allow_failure``
    """

    logger.debug("json_api_get url: %s", url)

    # if data included, encode it as JSON
    if data:
        request = urllib.request.Request(
            url, data=json.dumps(data).encode("utf-8"), method="POST"
        )
        request.add_header("Content-Type", "application/json; charset=UTF-8")
    else:
        request = urllib.request.Request(url)

    # add headers tuples
    for header in headers:
        request.add_header(*header)

    if validate_certs:
        ctx = None  # urlopen's default: verify certificates
    else:
        ctx = ssl.create_default_context()
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE

    try:
        # "with" closes the response (and its socket) once read
        with urllib.request.urlopen(request, context=ctx) as response:
            jsondata = response.read()
    except urllib.error.HTTPError:
        # asking for data that doesn't exist results in a 404, just return nothing
        if allow_failure:
            return None
        logger.exception("Server encountered an HTTPError at URL: '%s'", url)
        # re-raise: there is no response body to decode
        raise
    except urllib.error.URLError:
        logger.exception("An URLError occurred at URL: '%s'", url)
        raise

    logger.debug("API response: %s", jsondata)

    try:
        decoded = json.loads(jsondata)
    except json.decoder.JSONDecodeError:
        # allow return of no data
        if allow_failure:
            return None
        logger.exception("Unable to decode JSON")
        # re-raise rather than hand back something that isn't the response
        raise

    logger.debug("JSON decoded: %s", decoded)

    return decoded


def create_dns_zone(extension, devs):
    """
    Build the DNS records for one zone from the devices found in a prefix.

    :param extension: zone domain suffix, without leading or trailing dot
        (the caller passes the NetBox prefix description)
    :param devs: mapping of device name -> device dict as produced by
        get_prefix_devices(); each dict is read for "ip4", "dns_name" and
        "services". The special key "prefix_dhcp" instead holds
        ``{"dhcp_range": <CIDR>}``.
    :returns: dict with "a" (name -> IP string), "cname" (alias -> FQDN with
        trailing dot), "ns" (list of FQDNs, each listed once), and
        "srv"/"txt" (always empty here)
    """

    a_recs = {}  # PTR records created by inverting this
    cname_recs = {}
    srv_recs = {}
    ns_recs = []
    txt_recs = {}

    # an IP-assigned DNS name belongs to this zone only if it ends with
    # ".<extension>"; an unescaped regex suffix match would also accept
    # names such as "www.fooexample.com" for the zone "example.com"
    zone_suffix = ".%s" % extension

    # scan through devs and look for dns_name, if not, make from name and
    # extension
    for name, value in devs.items():

        # add DNS entries for every DHCP host if there's a DHCP range
        # DHCP addresses are of the form dhcp###.extension, where ### is the
        # last octet of the address
        if name == "prefix_dhcp":
            for ip in netaddr.IPNetwork(value["dhcp_range"]).iter_hosts():
                a_recs["dhcp%03d" % (ip.words[3])] = str(ip)

            continue

        # require DNS names to only use ASCII characters (alphanumeric, lowercase, with dash/period)
        # _'s are used in SRV/TXT records, but in general use aren't recommended
        # (flags passed by keyword: positional count/flags is deprecated)
        dns_name = re.sub("[^a-z0-9.-]", "-", name.lower(), flags=re.ASCII)

        # Add as an A record (and inverse, PTR record), only if it's a new name
        if dns_name in a_recs:
            # most likely a data entry error
            logger.warning(
                "Duplicate DNS name '%s' for devices at IP: '%s' and '%s', ignoring",
                dns_name,
                a_recs[dns_name],
                value["ip4"],
            )
            continue

        a_recs[dns_name] = value["ip4"]
        target = "%s.%s." % (dns_name, extension)

        # if a DNS name is given as a part of the IP address, it's viewed as a CNAME
        ip_dns_name = value["dns_name"]
        if ip_dns_name:
            dns_cname = None

            if ip_dns_name.endswith(zone_suffix) and len(ip_dns_name) > len(
                zone_suffix
            ):
                # strip off the extension, and add as a CNAME
                dns_cname = ip_dns_name[: -len(zone_suffix)]

            elif "." in ip_dns_name:
                # only the CNAME is ignored; the device's services below are
                # still processed
                logger.warning(
                    "Device '%s' has a IP assigned DNS name '%s' outside the prefix extension: '%s', ignoring",
                    name,
                    ip_dns_name,
                    extension,
                )

            else:
                dns_cname = ip_dns_name

            if dns_cname == dns_name:
                logger.warning(
                    "DNS Name field '%s' is identical to device name '%s', ignoring",
                    ip_dns_name,
                    dns_name,
                )
            elif dns_cname is not None:
                cname_recs[dns_cname] = target

        # Add services as cnames, and possibly ns records
        for svc in value["services"]:

            # only add service if it uses the IP of the host
            if value["ip4"] in svc["ip4s"]:
                cname_recs[svc["name"]] = target

            # a udp/53 service marks this host as a name server; list it once
            if (
                svc["port"] == 53
                and svc["protocol"] == "udp"
                and target not in ns_recs
            ):
                ns_recs.append(target)

    return {
        "a": a_recs,
        "cname": cname_recs,
        "ns": ns_recs,
        "srv": srv_recs,
        "txt": txt_recs,
    }


def create_dhcp_subnet(devs):
    """
    Build the DHCP static host table for a subnet.

    :param devs: mapping of device name -> device dict; only entries that
        have a truthy "macaddr" are used (so the "prefix_dhcp" entry, which
        has no "macaddr", is skipped)
    :returns: dict of IP string -> {"name": device name, "macaddr": MAC};
        if two devices share an IP, the later one wins
    """

    return {
        entry["ip4"]: {"name": dev_name, "macaddr": entry["macaddr"]}
        for dev_name, entry in devs.items()
        if entry.get("macaddr")
    }


def get_device_services(device_id, filters=""):
    """
    Fetch the NetBox services defined on a device.

    :param device_id: NetBox device id
    :param filters: extra query string appended verbatim (e.g. "&name=x")
    :returns: list of dicts with keys "name", "description", "port",
        "protocol" (the protocol's "value" field) and "ip4s" (service IP
        addresses as strings, prefix length stripped)
    """

    query = "api/ipam/services/?device_id=%s%s" % (device_id, filters)
    url = "%s%s" % (settings["api_endpoint"], query)

    raw_svcs = json_api_get(url, headers, validate_certs=settings["validate_certs"])

    return [
        {
            "name": rsvc["name"],
            "description": rsvc["description"],
            "port": rsvc["port"],
            "protocol": rsvc["protocol"]["value"],
            "ip4s": [
                str(netaddr.IPNetwork(addr["address"]).ip)
                for addr in rsvc["ipaddresses"]
            ],
        }
        for rsvc in raw_svcs["results"]
    ]


def get_interface_mac_addr(interface_id):
    """
    Look up one NetBox interface and return its MAC address.

    :param interface_id: NetBox interface id
    :returns: the interface's "mac_address" value, or None when that value
        is empty/falsy
    """

    iface_url = "%s%s" % (
        settings["api_endpoint"],
        "api/dcim/interfaces/%s/" % interface_id,
    )

    iface = json_api_get(iface_url, headers, validate_certs=settings["validate_certs"])

    # normalise any falsy value (e.g. "" or None) to None
    return iface["mac_address"] or None


def get_device_interfaces(device_id, filters=""):
    """
    Fetch a device's interfaces from NetBox.

    :param device_id: NetBox device id
    :param filters: extra query string appended verbatim
        (e.g. "&mgmt_only=true")
    :returns: list of dicts with keys "name", "macaddr", "mgmt_only",
        "description" and, only for interfaces whose "count_ipaddresses" is
        truthy, "ip4" (the first address returned for that interface, prefix
        length stripped). NOTE(review): no address-family filter is applied,
        so "ip4" could hold an IPv6 address; confirm.
    """

    url = "%s%s" % (
        settings["api_endpoint"],
        "api/dcim/interfaces/?device_id=%s%s" % (device_id, filters),
    )

    logger.debug("raw_ifaces_url: %s", url)

    raw_ifaces = json_api_get(url, headers, validate_certs=settings["validate_certs"])

    logger.debug("raw_ifaces: %s", raw_ifaces)

    result = []

    for raw in raw_ifaces["results"]:
        entry = {
            "name": raw["name"],
            "macaddr": raw["mac_address"],
            "mgmt_only": raw["mgmt_only"],
            "description": raw["description"],
        }

        # one extra request per interface that actually has addresses
        if raw["count_ipaddresses"]:
            ip_url = "%s%s" % (
                settings["api_endpoint"],
                "api/ipam/ip-addresses/?interface_id=%s" % raw["id"],
            )
            raw_ip = json_api_get(
                ip_url, headers, validate_certs=settings["validate_certs"]
            )
            first_address = raw_ip["results"][0]["address"]
            entry["ip4"] = str(netaddr.IPNetwork(first_address).ip)

        result.append(entry)

    return result


def get_prefix_devices(prefix, filters=""):
    """
    Collect the IP addresses inside a prefix, keyed by a device-derived name.

    :param prefix: prefix in CIDR form, passed as the NetBox "parent" filter
    :param filters: extra query string appended verbatim to the first request
    :returns: dict of name -> device dict with keys "ip4", "macaddr",
        "dns_name" and "services". Addresses on one of the device's
        management-only interfaces are keyed "<device>-<interface>" and get
        no DNS name or services; other addresses are keyed by the device
        name. An address with status "dhcp" is stored under "prefix_dhcp" as
        ``{"dhcp_range": <address>}``. Addresses not assigned to a device
        interface are skipped with a warning.
    """

    # get all devices in a prefix
    url = "%s%s" % (
        settings["api_endpoint"],
        "api/ipam/ip-addresses/?parent=%s%s" % (prefix, filters),
    )

    devs = {}

    # NetBox list endpoints are paginated; follow "next" until it is empty so
    # prefixes larger than one page are not silently truncated.
    # NOTE(review): assumes the "next" URL NetBox returns is reachable from
    # here (it is built from the server's view of its own base URL).
    while url:
        raw_ips = json_api_get(url, headers, validate_certs=settings["validate_certs"])

        logger.debug("raw_ips: %s", raw_ips)

        for ip in raw_ips["results"]:

            logger.debug("ip: %s", ip)

            # if it's a DHCP range, add that range to the dev list as prefix_dhcp
            if ip["status"]["value"] == "dhcp":
                devs["prefix_dhcp"] = {"dhcp_range": ip["address"]}
                continue

            assigned = ip.get("assigned_object")
            if not assigned or "device" not in assigned:
                logger.warning(
                    "IP address '%s' is not assigned to a device interface, ignoring",
                    ip["address"],
                )
                continue

            device = assigned["device"]

            dev = {}

            dev["ip4"] = str(netaddr.IPNetwork(ip["address"]).ip)
            dev["macaddr"] = get_interface_mac_addr(assigned["id"])

            mgmt_ifaces = get_device_interfaces(device["id"], "&mgmt_only=true")

            # the address is a mgmt IP if any of the device's mgmt-only
            # interfaces carries it; interfaces without an address have no
            # "ip4" key
            mgmt_iface = next(
                (i for i in mgmt_ifaces if i.get("ip4") == dev["ip4"]), None
            )

            if mgmt_iface is not None:  # this is a mgmt IP
                devname = "%s-%s" % (device["name"], mgmt_iface["name"])
                dev["dns_name"] = ""
                dev["services"] = []

            else:  # this is a primary IP
                devname = device["name"]
                # normalise a missing/None DNS name to "" so it can never be
                # mistaken for a real name downstream
                dev["dns_name"] = ip.get("dns_name") or ""
                dev["services"] = get_device_services(device["id"])

            devs[devname] = dev

        url = raw_ips.get("next")

    return devs


def get_prefix_data(prefix):
    """
    Fetch the NetBox record for a prefix.

    :param prefix: prefix in CIDR form, passed as the NetBox "prefix" filter
    :returns: the first entry of the response's "results" list (an
        IndexError propagates if there is none)
    """

    query = "api/ipam/prefixes/?prefix=%s" % prefix

    raw_prefix = json_api_get(
        "%s%s" % (settings["api_endpoint"], query),
        headers,
        validate_certs=settings["validate_certs"],
    )

    logger.debug("raw_prefix: %s", raw_prefix)

    matches = raw_prefix["results"]
    return matches[0]


# main function that calls other functions
if __name__ == "__main__":

    args = parse_nb_args()

    # only print informational/debug log messages when debugging; warnings
    # (e.g. duplicate DNS names) are still shown
    if args.debug:
        logger.setLevel(logging.DEBUG)
    else:
        logger.setLevel(logging.WARNING)

    # load settings from yaml file, closing it afterwards
    # NOTE(review): yaml.safe_load/safe_dump are ruamel.yaml's legacy API,
    # which newer ruamel.yaml releases no longer provide; confirm the pinned
    # version or migrate to yaml.YAML(typ="safe").
    with args.settings as settings_file:
        settings = yaml.safe_load(settings_file.read())

    # never write the API token to the log
    logger.debug(
        "settings: %s",
        {key: ("<redacted>" if key == "token" else val) for key, val in settings.items()},
    )

    # global, so this isn't run multiple times
    headers = [
        ("Authorization", "Token %s" % settings["token"]),
    ]

    # create structure from extracted data

    dns_global = {}
    dns_zones = {}
    dhcp_global = {}
    dhcp_subnets = {}

    # "devs" and "prefix_data" in the output hold data for the last prefix
    # processed; initialise them so an empty prefix list doesn't raise
    # NameError
    devs = {}
    prefix_data = {}

    for prefix in settings["dns_prefixes"]:

        prefix_data = get_prefix_data(prefix)

        # the prefix description is used as the zone's domain extension
        prefix_domain_extension = prefix_data["description"]

        devs = get_prefix_devices(prefix)

        zone = create_dns_zone(prefix_domain_extension, devs)
        zone["ip_range"] = prefix
        dns_zones[prefix_domain_extension] = zone

        dhcp_subnets[prefix] = create_dhcp_subnet(devs)

    yaml_out = {
        "dns_global": dns_global,
        "dns_zones": dns_zones,
        "dhcp_global": dhcp_global,
        "dhcp_subnets": dhcp_subnets,
        "devs": devs,
        "prefix_data": prefix_data,
    }

    print(yaml.safe_dump(yaml_out, indent=2))
