# Copyright 2017-present Open Networking Foundation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import print_function
import plyxproto.model as m
from plyxproto.helpers import Visitor
import argparse
import plyxproto.parser as plyxproto
import traceback
import sys
import jinja2
import os
import copy
import pdb


class MissingPolicyException(Exception):
    """Raised when a policy body invokes a policy that has not been defined."""


def find_missing_policy_calls(name, policies, policy):
    """Raise if *policy* invokes a policy that has not been defined.

    Walks a policy body recursively:

    * A dict node must have exactly one key. When that key is ``"policy"``,
      its value is a sequence whose first element is the name of the invoked
      policy, and that name must be a key of *policies*. For any other key,
      each element of the value is checked recursively.
    * A list node has each of its elements checked recursively.
    * Anything else is ignored.

    Args:
        name: name of the policy being checked; used in the error message.
        policies: mapping whose keys are the known policy names.
        policy: the policy body (or sub-expression) to walk.

    Raises:
        MissingPolicyException: if an invoked policy is not in *policies*.
        ValueError: if a dict node does not have exactly one key.
    """
    if isinstance(policy, dict):
        (k, lst), = policy.items()
        if k == "policy":
            policy_name = lst[0]
            if policy_name not in policies:
                raise MissingPolicyException(
                    "Policy %s invoked missing policy %s" % (name, policy_name)
                )
        else:
            for p in lst:
                find_missing_policy_calls(name, policies, p)
    elif isinstance(policy, list):
        # Walk the list itself; ``lst`` is only bound in the dict branch.
        for p in policy:
            find_missing_policy_calls(name, policies, p)


def dotname_to_fqn(dotname):
    """Split a dotted name into its short name, package and qualified name.

    *dotname* is a sequence of tokens exposing ``pval``. Every token except
    the last forms the package; the last one is the short name.

    Returns:
        dict with keys ``"name"``, ``"fqn"`` and ``"package"``. ``"package"``
        is ``""`` when there is no prefix, in which case ``"fqn"`` equals
        ``"name"``.
    """
    pieces = [token.pval for token in dotname]
    package, short_name = ".".join(pieces[:-1]), pieces[-1]
    qualified = package + "." + short_name if package else short_name
    return {"name": short_name, "fqn": qualified, "package": package}


def dotname_to_name(dotname):
    """Join the ``pval`` of every token in *dotname* with dots."""
    return ".".join(token.pval for token in dotname)


def count_messages(body):
    """Return how many entries of *body* are ``MessageDefinition`` nodes."""
    return sum(1 for node in body if isinstance(node, m.MessageDefinition))


def count_fields(body):
    """Count entries of *body* that are field or link nodes.

    The check is on the exact type (not ``isinstance``), so subclasses of
    ``LinkDefinition``, ``FieldDefinition`` or ``LinkSpec`` are not counted.
    """
    return sum(
        1
        for node in body
        if type(node) in (m.LinkDefinition, m.FieldDefinition, m.LinkSpec)
    )


def name_to_value(obj):
    """Extract the raw value of a directive-like node.

    Tries, in order, ``obj.value.value.pval``, ``obj.value.value`` and
    ``obj.value.pval``, returning the first one that resolves. An
    AttributeError from the last attempt propagates to the caller.
    """
    try:
        return obj.value.value.pval
    except AttributeError:
        pass

    try:
        return obj.value.value
    except AttributeError:
        return obj.value.pval


class Stack(list):
    """A ``list`` used as a LIFO stack; ``pop()`` is inherited from list."""

    def push(self, x):
        """Put *x* on top of the stack (an alias for ``append``)."""
        self.append(x)


class XOS2Jinja(Visitor):
    """Visitor that turns a plyxproto AST into plain dicts for Jinja templates.

    XOS2Jinja overrides the underlying visitor pattern to transform the tree
    in addition to traversing it. The ``visit_<Node>`` hooks appear to run
    before a node's children and the ``visit_<Node>_post`` hooks after them
    (assumed from plyxproto's Visitor dispatch; TODO confirm).

    Intermediate results are handed from child hooks to parent hooks through
    ``self.stack``. ``self.count_stack`` records how many stack entries each
    enclosing node must collect in its post hook.

    After a ``Proto`` has been visited:

    * ``self.messages`` holds the model dicts in declaration order.
    * ``self.models`` maps each model's package-qualified name to those same
      dicts.
    """

    def __init__(self, args):
        super(XOS2Jinja, self).__init__()

        self.stack = Stack()  # intermediate results passed between hooks
        self.models = {}  # package-qualified model name -> model dict
        self.options = {}  # file-level options
        self.package = None  # package name, once a package statement is seen
        self.message_options = {}  # options of the message being visited
        self.count_stack = Stack()  # expected stack entries per open node
        self.policies = {}  # package-qualified policy name -> policy body
        self.content = ""
        self.offset = 0
        self.current_message_name = None
        self.verbose = 0
        self.first_field = True
        self.first_method = True
        self.args = args  # read for ``args.verbosity`` in compute_rlinks
        self.messages = []  # filled in by visit_Proto_post

    @staticmethod
    def _pval_or_raw(value):
        """Return ``value.pval`` when *value* has one, else *value* itself."""
        return getattr(value, "pval", value)

    def _qualified_name(self, name):
        """Prefix *name* with the current package, if one was declared."""
        if self.package:
            return ".".join([self.package, name])
        return name

    def visit_PolicyDefinition(self, obj):
        """Record a policy body and check that the policies it calls exist.

        Raises:
            MissingPolicyException: if the body invokes an undefined policy.
        """
        pname = self._qualified_name(obj.name.value.pval)

        self.policies[pname] = obj.body
        find_missing_policy_calls(pname, self.policies, obj.body)

        return True

    def visit_PackageStatement(self, obj):
        """Remember the dotted package name; later names are qualified by it."""
        dotlist = obj.name.value
        self.package = ".".join([f.pval for f in dotlist])
        return True

    def visit_ImportStatement(self, obj):
        """Ignore"""
        return True

    def visit_OptionStatement(self, obj):
        """Record an option, skipping those flagged ``mark_for_deletion``.

        Inside a message the option goes to ``self.message_options``;
        otherwise it goes to the file-level ``self.options``.
        """
        if not hasattr(obj, "mark_for_deletion"):
            if self.current_message_name:
                self.message_options[obj.name.value.pval] = obj.value.value.pval
            else:
                self.options[obj.name.value.pval] = obj.value.value.pval

        return True

    def visit_LU(self, obj):
        return True

    def visit_default(self, obj):
        return True

    def visit_FieldDirective(self, obj):
        return True

    def visit_FieldDirective_post(self, obj):
        """Push a ``[name, value]`` pair for one field directive.

        visit_FieldDefinition_post pops these pairs into the field's options.
        """
        try:
            name = obj.name.value.pval
        except AttributeError:
            name = obj.name.value

        if isinstance(obj.value, list):
            value = dotname_to_name(obj.value)
        else:
            value = name_to_value(obj)

        self.stack.push([name, value])
        return True

    def visit_FieldType(self, obj):
        """Field type, if type is name, then it may need refactoring
        consistent with refactoring rules according to the table"""
        return True

    def visit_LinkDefinition(self, obj):
        """Push a dict (``_type == "link"``) describing a link declaration.

        The field declared together with the link is pushed right after this
        entry (see visit_FieldDefinition_post). visit_MessageDefinition_post
        later pairs the two.
        """
        s = {}

        s["link_type"] = self._pval_or_raw(obj.link_type)
        s["src_port"] = obj.src_port.value.pval
        s["name"] = obj.src_port.value.pval
        try:
            s["policy"] = obj.policy.pval
        except AttributeError:
            s["policy"] = None

        try:
            s["dst_port"] = obj.dst_port.value.pval
        except AttributeError:
            s["dst_port"] = obj.dst_port

        # ``through`` and the peer ``name`` may arrive as dotted-name token
        # lists (turned into name/package/fqn dicts) or as single values.
        if isinstance(obj.through, list):
            s["through"] = dotname_to_fqn(obj.through)
        else:
            s["through"] = self._pval_or_raw(obj.through)

        if isinstance(obj.name, list):
            s["peer"] = dotname_to_fqn(obj.name)
        else:
            s["peer"] = self._pval_or_raw(obj.name)

        s["reverse_id"] = self._pval_or_raw(obj.reverse_id)

        s["_type"] = "link"
        s["options"] = {"modifier": "optional"}

        self.stack.push(s)
        return True

    def visit_FieldDefinition(self, obj):
        """Remember how many directive pairs this field's children will push."""
        self.count_stack.push(len(obj.fieldDirective))
        return True

    def visit_FieldDefinition_post(self, obj):
        """Build the field dict (``_type == "field"``) and push it.

        The field's directive pairs are popped into ``options``, alongside its
        ``modifier``. If the entry below them is a link dict, the field is
        marked with ``link = True``.
        """
        s = {}

        if isinstance(obj.ftype, m.Name):
            s["type"] = obj.ftype.value
        else:
            s["type"] = obj.ftype.name.pval

        s["name"] = obj.name.value.pval

        try:
            s["policy"] = obj.policy.pval
        except AttributeError:
            s["policy"] = None

        s["modifier"] = obj.field_modifier.pval
        s["id"] = obj.fieldId.pval

        opts = {"modifier": s["modifier"]}
        n = self.count_stack.pop()
        for _ in range(n):
            k, v = self.stack.pop()
            # FIXME: surrounding double quotes on option values are kept here
            # and stripped by the targets instead.
            opts[k] = v

        s["options"] = opts

        # The stack may hold link dicts, model dicts (without "_type") or
        # [name, value] directive pairs. Only a link dict marks this field.
        previous = self.stack[-1] if self.stack else None
        if isinstance(previous, dict) and previous.get("_type") == "link":
            s["link"] = True
        s["_type"] = "field"

        self.stack.push(s)
        return True

    def visit_EnumFieldDefinition(self, obj):
        if self.verbose > 4:
            print("\tEnumField: name=%s, %s" % (obj.name, obj))

        return True

    def visit_EnumDefinition(self, obj):
        """New enum definition, refactor name"""
        if self.verbose > 3:
            print("Enum, [%s] body=%s\n\n" % (obj.name, obj.body))

        return True

    def visit_MessageDefinition(self, obj):
        """Start a message: reset its options and expect its field entries."""
        self.current_message_name = obj.name.value.pval
        self.message_options = {}
        self.count_stack.push(count_fields(obj.body))
        return True

    def visit_MessageDefinition_post(self, obj):
        """Assemble the model dict for a message and push/register it.

        Pops the entries pushed by the message's fields and links. Each link
        is paired with the field declared alongside it, and the finished model
        dict is pushed (for visit_Proto_post) and stored in ``self.models``.
        """
        stack_num = self.count_stack.pop()
        fields = []
        links = []
        try:
            # list(): under Python 3 map() is a one-shot iterator, but the
            # bases are stored in the model dict and may be read repeatedly.
            obj.bases = list(map(dotname_to_fqn, obj.bases))
        except AttributeError:
            pass

        # Entries come off the stack newest-first, so a link is popped right
        # after the field that was pushed on top of it.
        last_field = {}
        for _ in range(stack_num):
            f = self.stack.pop()
            if f["_type"] == "link":
                # Merge option dicts; the field's options win on conflicts.
                merged = dict(f["options"])
                merged.update(last_field["options"])
                f["options"] = merged
                assert last_field == fields[0]
                fields[0].setdefault("options", {})["link_type"] = f["link_type"]
                links.insert(0, f)
            else:
                fields.insert(0, f)
                last_field = f

        model_name = self._qualified_name(obj.name.value.pval)

        model_def = {
            "name": obj.name.value.pval,
            "fields": fields,
            "links": links,
            "bases": obj.bases,
            "options": self.message_options,
            "package": self.package,
            "fqn": model_name,
            "rlinks": [],
        }
        try:
            model_def["policy"] = obj.policy.pval
        except AttributeError:
            model_def["policy"] = None

        self.stack.push(model_def)

        self.models[model_name] = model_def

        # Inherit file-level options the message did not set itself. This
        # updates the same dict object stored in model_def["options"].
        for k, v in self.options.items():
            self.message_options.setdefault(k, v)

        self.current_message_name = None
        return True

    def visit_MessageExtension(self, obj):
        return True

    def visit_MethodDefinition(self, obj):
        return True

    def visit_ServiceDefinition(self, obj):
        return True

    def visit_ExtensionsDirective(self, obj):
        return True

    def visit_Literal(self, obj):
        return True

    def visit_Name(self, obj):
        return True

    def visit_DotName(self, obj):
        return True

    def visit_Proto(self, obj):
        """Expect one model dict on the stack per top-level message."""
        self.count_stack.push(count_messages(obj.body))
        return True

    def visit_Proto_post(self, obj):
        """Collect the model dicts, compute reverse links, store messages."""
        count = self.count_stack.pop()
        messages = []
        for _ in range(count):
            if not self.stack:
                # Fewer entries than expected: stop rather than re-use a
                # previously popped message.
                break
            messages.insert(0, self.stack.pop())

        self.compute_rlinks(messages, self.models)

        self.messages = messages
        return True

    def visit_LinkSpec(self, obj):
        """Expect one extra stack entry from this link spec.

        count_fields counted the spec once, but it appears to push two entries
        (a link dict and its field dict; confirm against the parser).
        """
        count = self.count_stack.pop()
        self.count_stack.push(count + 1)
        return True

    def compute_rlinks(self, messages, message_dict):
        """Derive the implicit reverse links ("rlinks") of every message.

        For every declared link on model A pointing at peer B, an rlink is
        built with the ports swapped, the link type inverted and A as its
        peer. It is filed under the peer's ``fqn`` and then attached to the
        message whose short ``name`` matches.

        Args:
            messages: model dicts as built by visit_MessageDefinition_post.
            message_dict: mapping of model names to model dicts (self.models).

        Raises:
            Exception: if a reverse_id lies outside 1000..1899.
            KeyError: if a link_type has no entry in ``link_opposite``.
        """
        rev_links = {}

        link_opposite = {
            "manytomany": "manytomany",
            "manytoone": "onetomany",
            "onetoone": "onetoone",
            "onetomany": "manytoone",
        }

        for msg in messages:
            for link in msg["links"]:
                rlink = copy.deepcopy(link)

                rlink["_type"] = "rlink"  # An implicit link, not declared in the model
                rlink["src_port"] = link["dst_port"]
                rlink["dst_port"] = link["src_port"]
                rlink["peer"] = {
                    "name": msg["name"],
                    "package": msg["package"],
                    "fqn": msg["fqn"],
                }
                rlink["link_type"] = link_opposite[link["link_type"]]
                rlink["reverse_id"] = link["reverse_id"]

                if (not link["reverse_id"]) and (self.args.verbosity >= 1):
                    print(
                        "WARNING: Field %s in model %s has no reverse_id"
                        % (link["src_port"], msg["name"]),
                        file=sys.stderr,
                    )

                if link["reverse_id"] and (
                    (int(link["reverse_id"]) < 1000)
                    or (int(link["reverse_id"]) >= 1900)
                ):
                    raise Exception(
                        "reverse id for field %s in model %s should be between 1000 and 1899"
                        % (link["src_port"], msg["name"])
                    )

                try:
                    rev_links.setdefault(link["peer"]["fqn"], []).append(rlink)
                except TypeError:
                    # The peer is not a name/package/fqn dict (e.g. a plain
                    # value), so there is no key to file the rlink under.
                    pass

        for msg in messages:
            # NOTE(review): rev_links is keyed by the peer's fqn, but it is
            # looked up here by short name (and message_dict is keyed by the
            # qualified name). These only match for peers referenced without a
            # package. Confirm intent before changing.
            try:
                msg["rlinks"] = rev_links[msg["name"]]
                message_dict[msg["name"]]["rlinks"] = msg["rlinks"]
            except KeyError:
                pass