CORD-2357 manytomany support

Change-Id: I54debd4eab66df003dc5079890c3fc87ee0d3e80
diff --git a/lib/xos-genx/xosgenx/targets/protoapi.xtarget b/lib/xos-genx/xosgenx/targets/protoapi.xtarget
index f375dae..c580764 100644
--- a/lib/xos-genx/xosgenx/targets/protoapi.xtarget
+++ b/lib/xos-genx/xosgenx/targets/protoapi.xtarget
@@ -21,9 +21,13 @@
     {%- endif %}
     {%- set id_field = {'type':'int32', 'name':'id', 'options':{}} -%}
   {%- for field in (xproto_base_fields(object, proto.message_table) + object.fields + [id_field]) | sort(attribute='name')%}
+  {%- if field.options.type == "link" and field.options.link_type == "manytomany" %}
+    repeated int32 {{ field.name }}_ids = {{ loop.index }} [(manyToManyForeignKey).modelName = "{{ field.options.model }}"];
+  {%- else %}
     oneof {{ field.name }}_present {
       {{ xproto_api_type(field) }} {{ field.name }}{% if field.link -%}_id{% endif %} = {{ loop.index }}{{ xproto_api_opts(field) }};
     }
+  {%- endif -%}
   {%- endfor -%}
 
   {%- for ref in xproto_base_rlinks(object, proto.message_table) + object.rlinks | sort(attribute='src_port') %}
diff --git a/xos/coreapi/apihelper.py b/xos/coreapi/apihelper.py
index f8ea699..768520c 100644
--- a/xos/coreapi/apihelper.py
+++ b/xos/coreapi/apihelper.py
@@ -251,6 +251,26 @@
                     continue
                 getattr(p_obj, related_name + "_ids").append(rel_obj.id)
 
+        # Go through any many-to-many relations. This is almost the same as the related_objects loop above, but slightly
+        # different due to how django handles m2m.
+
+        for m2m in obj._meta.many_to_many:
+            related_name = m2m.name
+            if not related_name:
+                continue
+            if "+" in related_name:   # duplicated logic from related_objects; not sure if necessary
+                continue
+
+            rel_objs = getattr(obj, related_name)
+
+            if not hasattr(rel_objs, "all"):
+                continue
+
+            for rel_obj in rel_objs.all():
+                if not hasattr(p_obj, related_name + "_ids"):
+                    continue
+                getattr(p_obj, related_name + "_ids").append(rel_obj.id)
+
         # Generate a list of class names for the object. This includes its
         # ancestors. Anything that is a descendant of XOSBase or User
         # counts.
@@ -309,6 +329,56 @@
 
         return args
 
+    def handle_m2m(self, djangoClass, message, update_fields):
+        # fix for possible django bug?
+        # Unless we refresh the object, django will ignore every other m2m save
+
+        #djangoClass = djangoClass.__class__.objects.get(id=djangoClass.id)
+        djangoClass.refresh_from_db()
+
+        fmap={}
+        for m2m in djangoClass._meta.many_to_many:
+            related_name = m2m.name
+            if not related_name:
+                continue
+            if "+" in related_name:   # duplicated logic from related_objects; not sure if necessary
+                continue
+
+            fmap[m2m.name + "_ids"] = m2m
+
+        fields_changed = []
+        for (fieldDesc, val) in message.ListFields():
+            if fieldDesc.name in fmap:
+                m2m = getattr(djangoClass,fmap[fieldDesc.name].name)
+
+                # remove items that are in the django object, but not in the proto object
+                for item in list(m2m.all()):
+                    if (not item.id in val):
+                        m2m.remove(item.id)
+                        fields_changed.append(fieldDesc.name)
+
+                # add items are are in the proto object, but not in the django object
+                django_ids = [x.id for x in m2m.all()]
+
+                for item in val:
+                    if item not in django_ids:
+                        m2m.add(item)
+                        fields_changed.append(fieldDesc.name)
+
+        # gRPC doesn't give us a convenient way to differentiate between an empty list and an omitted list. So what
+        # we'll do is check and see if the user specified a fieldname in `update_fields`. If the user did, and that
+        # field is an m2m that we didn't encounter, then it must have been an empty list that the user wants
+        # to set.
+
+        for name in update_fields:
+            if (name in fmap) and (not name in fields_changed):
+                m2m = getattr(djangoClass, fmap[name].name)
+                m2m.clear()
+                fields_changed.append(name)
+
+        if fields_changed:
+            djangoClass.save()
+
     def querysetToProto(self, djangoClass, queryset):
         objs = queryset
         p_objs = self.getPluralProtoClass(djangoClass)()
@@ -393,6 +463,9 @@
         self.xos_security_gate(new_obj, user, write_access=True)
 
         new_obj.save()
+
+        self.handle_m2m(new_obj, request, [])
+
         return self.objToProto(new_obj)
 
     def update(self, djangoClass, user, id, message, context):
@@ -405,16 +478,28 @@
         for (k, v) in args.iteritems():
             setattr(obj, k, v)
 
+        m2m_field_names = [x.name+"_ids" for x in djangoClass._meta.many_to_many]
+
+        update_fields = []
+        m2m_update_fields = []
         save_kwargs = {}
         for (k, v) in context.invocation_metadata():
             if k == "update_fields":
-                save_kwargs["update_fields"] = v.split(",")
+                for field_name in v.split(","):
+                    if field_name in m2m_field_names:
+                        m2m_update_fields.append(field_name)
+                    else:
+                        update_fields.append(field_name)
+                save_kwargs["update_fields"] = update_fields
             elif k == "caller_kind":
                 save_kwargs["caller_kind"] = v
             elif k == "always_update_timestamp":
                 save_kwargs["always_update_timestamp"] = True
 
         obj.save(**save_kwargs)
+
+        self.handle_m2m(obj, message, m2m_update_fields)
+
         return self.objToProto(obj)
 
     def delete(self, djangoClass, user, id):
diff --git a/xos/coreapi/protos/xosoptions.proto b/xos/coreapi/protos/xosoptions.proto
index 7602c95..008f3de 100644
--- a/xos/coreapi/protos/xosoptions.proto
+++ b/xos/coreapi/protos/xosoptions.proto
@@ -18,10 +18,15 @@
   string modelName = 1;
 }
 
+message ManyToManyForeignKeyRule {
+  string modelName = 1;
+}
+
 extend google.protobuf.FieldOptions {
   ValRule val = 1001;
   ForeignKeyRule foreignKey = 1002;
   ReverseForeignKeyRule reverseForeignKey = 1003;
+  ManyToManyForeignKeyRule manyToManyForeignKey = 1004;
 }
 
 extend google.protobuf.MessageOptions {
diff --git a/xos/xos_client/tests/orm_nodelabel.py b/xos/xos_client/tests/orm_nodelabel.py
new file mode 100644
index 0000000..89dd025
--- /dev/null
+++ b/xos/xos_client/tests/orm_nodelabel.py
@@ -0,0 +1,163 @@
+
+# Copyright 2017-present Open Networking Foundation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# These are functional tests of ManyToMany relations. These tests need to be conducted end-to-end with a real
+# API to verify that the client and server ends of the API are working with each other.
+
+import random
+import string
+import sys
+import unittest
+
+orm = None
+
+from xosapi import xos_grpc_client
+
+TEST_NODE_LABEL_1_NAME = "test_node_label_1"
+
+class TestORM(unittest.TestCase):
+    def setUp(self):
+        self.test_node_label_1_name = TEST_NODE_LABEL_1_NAME + "_" + ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10))
+
+        nodes1 = orm.Node.objects.filter(name="test_node_1")
+        if nodes1:
+            self.node1=nodes1[0]
+        else:
+            self.node1 = orm.Node(name="test_node_1", site_deployment=orm.SiteDeployment.objects.first())
+            self.node1.save()
+
+        nodes2 = orm.Node.objects.filter(name="test_node_2")
+        if nodes2:
+            self.node2=nodes2[0]
+        else:
+            self.node2 = orm.Node(name="test_node_2", site_deployment=orm.SiteDeployment.objects.first())
+            self.node2.save()
+
+    def tearDown(self):
+        # TODO: Deleting NodeLabel seems to be broken -- appears to be a cascade failure
+        # attaching a nodelabel to a node causes deleting the node to also be broken.
+
+
+        #node_labels1 = orm.NodeLabel.objects.filter(name=self.test_node_label_1_name)
+        #for node_label in node_labels1:
+        #    node_label.delete()
+
+        #nodes1 = orm.Node.objects.filter(name="test_node_1")
+        #for node in nodes1:
+        #    node.delete()
+
+        #nodes2 = orm.Node.objects.filter(name="test_node_2")
+        #for node in nodes2:
+        #    node.delete()
+
+        pass
+
+    def test_create_empty_node_label(self):
+        n = orm.NodeLabel(name = self.test_node_label_1_name)
+        n.save()
+
+        labels = orm.NodeLabel.objects.filter(name = self.test_node_label_1_name)
+        self.assertEqual(len(labels),1)
+
+        n=labels[0]
+        self.assertNotEqual(n, None)
+        self.assertEqual(len(n.node.all()), 0)
+
+    def test_create_node_label_one_node(self):
+        n = orm.NodeLabel(name = self.test_node_label_1_name)
+        n.node.add(self.node1)
+        n.save()
+
+        labels = orm.NodeLabel.objects.filter(name = self.test_node_label_1_name)
+        self.assertEqual(len(labels),1)
+
+        n=labels[0]
+        self.assertNotEqual(n, None)
+        self.assertEqual(len(n.node.all()), 1)
+
+    def test_create_node_label_two_nodes(self):
+        n = orm.NodeLabel(name = self.test_node_label_1_name)
+        n.node.add(self.node1)
+        n.node.add(self.node2)
+        n.save()
+
+        labels = orm.NodeLabel.objects.filter(name = self.test_node_label_1_name)
+        self.assertEqual(len(labels),1)
+
+        n=labels[0]
+        self.assertNotEqual(n, None)
+        self.assertEqual(len(n.node.all()), 2)
+
+    def test_add_node_to_label(self):
+        n = orm.NodeLabel(name = self.test_node_label_1_name)
+        n.save()
+
+        labels = orm.NodeLabel.objects.filter(name = self.test_node_label_1_name)
+        self.assertEqual(len(labels), 1)
+        n=labels[0]
+        n.node.add(self.node1)
+        n.save()
+
+        labels = orm.NodeLabel.objects.filter(name = self.test_node_label_1_name)
+        self.assertEqual(len(labels), 1)
+        n=labels[0]
+        self.assertEqual(len(n.node.all()), 1)
+
+    def test_remove_node_from_label(self):
+        n = orm.NodeLabel(name=self.test_node_label_1_name)
+        n.node.add(self.node1)
+        n.node.add(self.node2)
+        n.save()
+
+        labels = orm.NodeLabel.objects.filter(name=self.test_node_label_1_name)
+        self.assertEqual(len(labels), 1)
+        n = labels[0]
+        self.assertEqual(len(n.node.all()), 2)
+        n.node.remove(self.node1)
+        n.save()
+
+        labels = orm.NodeLabel.objects.filter(name=self.test_node_label_1_name)
+        self.assertEqual(len(labels), 1)
+        n = labels[0]
+        self.assertEqual(len(n.node.all()), 1)
+
+    def test_remove_last_node_from_label(self):
+        n = orm.NodeLabel(name=self.test_node_label_1_name)
+        n.node.add(self.node1)
+        n.save()
+
+        labels = orm.NodeLabel.objects.filter(name=self.test_node_label_1_name)
+        self.assertEqual(len(labels), 1)
+        n = labels[0]
+        self.assertEqual(len(n.node.all()), 1)
+        n.node.remove(self.node1)
+        n.save(update_fields=["node_ids"])
+
+        labels = orm.NodeLabel.objects.filter(name=self.test_node_label_1_name)
+        self.assertEqual(len(labels), 1)
+        n = labels[0]
+        self.assertEqual(len(n.node.all()), 0)
+
+
+def test_callback():
+    global orm
+
+    orm = xos_grpc_client.coreclient.xos_orm
+
+    sys.argv=sys.argv[:1]  # unittest gets mad about the orm command line arguments
+    unittest.main()
+
+xos_grpc_client.start_api_parseargs(test_callback)
+
diff --git a/xos/xos_client/xosapi/orm.py b/xos/xos_client/xosapi/orm.py
index ed84f6e..77d3e7c 100644
--- a/xos/xos_client/xosapi/orm.py
+++ b/xos/xos_client/xosapi/orm.py
@@ -100,8 +100,14 @@
            if name.endswith("_ids"):
                reverseForeignKey = field.GetOptions().Extensions._FindExtensionByName("xos.reverseForeignKey")
                fk = field.GetOptions().Extensions[reverseForeignKey]
-               if fk:
-                   reverse_fkmap[name[:-4]] = {"src_fieldName": name, "modelName": fk.modelName}
+               if fk and fk.modelName:
+                   reverse_fkmap[name[:-4]] = {"src_fieldName": name, "modelName": fk.modelName, "writeable": False}
+               else:
+                   manyToManyForeignKey = field.GetOptions().Extensions._FindExtensionByName("xos.manyToManyForeignKey")
+                   fk = field.GetOptions().Extensions[manyToManyForeignKey]
+                   if fk and fk.modelName:
+                       reverse_fkmap[name[:-4]] = {"src_fieldName": name, "modelName": fk.modelName, "writeable": True}
+
 
         return reverse_fkmap
 
@@ -133,7 +139,7 @@
     def reverse_fk_resolve(self, name):
         if name not in self.reverse_cache:
             fk_entry = self._reverse_fkmap[name]
-            self.cache[name] = ORMLocalObjectManager(self.stub, fk_entry["modelName"], getattr(self, fk_entry["src_fieldName"]))
+            self.cache[name] = ORMLocalObjectManager(self.stub, fk_entry["modelName"], getattr(self, fk_entry["src_fieldName"]), fk_entry["writeable"])
 
         return self.cache[name]
 
@@ -284,10 +290,11 @@
 class ORMLocalObjectManager(object):
     """ Manages a local list of objects """
 
-    def __init__(self, stub, modelName, idList):
+    def __init__(self, stub, modelName, idList, writeable):
         self._stub = stub
         self._modelName = modelName
         self._idList = idList
+        self._writeable = writeable
         self._cache = None
 
     def resolve_queryset(self):
@@ -319,6 +326,32 @@
         else:
             return None
 
+    def add(self, model):
+        if not self._writeable:
+            raise Exception("Only ManyToMany lists are writeable")
+
+        if isinstance(model, int):
+            id = model
+        else:
+            if not model.id:
+                raise Exception("Model %s has no id" % model)
+            id = model.id
+
+        self._idList.append(id)
+
+    def remove(self, model):
+        if not self._writeable:
+            raise Exception("Only ManyToMany lists are writeable")
+
+        if isinstance(model, int):
+            id = model
+        else:
+            if not model.id:
+                raise Exception("Model %s has no id" % model)
+            id = model.id
+
+        self._idList.remove(id)
+
 class ORMObjectManager(object):
     """ Manages a remote list of objects """