CORD-2357 manytomany support
Change-Id: I54debd4eab66df003dc5079890c3fc87ee0d3e80
diff --git a/lib/xos-genx/xosgenx/targets/protoapi.xtarget b/lib/xos-genx/xosgenx/targets/protoapi.xtarget
index f375dae..c580764 100644
--- a/lib/xos-genx/xosgenx/targets/protoapi.xtarget
+++ b/lib/xos-genx/xosgenx/targets/protoapi.xtarget
@@ -21,9 +21,13 @@
{%- endif %}
{%- set id_field = {'type':'int32', 'name':'id', 'options':{}} -%}
{%- for field in (xproto_base_fields(object, proto.message_table) + object.fields + [id_field]) | sort(attribute='name')%}
+ {%- if field.options.type == "link" and field.options.link_type == "manytomany" %}
+ repeated int32 {{ field.name }}_ids = {{ loop.index }} [(manyToManyForeignKey).modelName = "{{ field.options.model }}"];
+ {%- else %}
oneof {{ field.name }}_present {
{{ xproto_api_type(field) }} {{ field.name }}{% if field.link -%}_id{% endif %} = {{ loop.index }}{{ xproto_api_opts(field) }};
}
+ {%- endif -%}
{%- endfor -%}
{%- for ref in xproto_base_rlinks(object, proto.message_table) + object.rlinks | sort(attribute='src_port') %}
diff --git a/xos/coreapi/apihelper.py b/xos/coreapi/apihelper.py
index f8ea699..768520c 100644
--- a/xos/coreapi/apihelper.py
+++ b/xos/coreapi/apihelper.py
@@ -251,6 +251,26 @@
continue
getattr(p_obj, related_name + "_ids").append(rel_obj.id)
+ # Go through any many-to-many relations. This is almost the same as the related_objects loop above, but slightly
+ # different due to how django handles m2m.
+
+ for m2m in obj._meta.many_to_many:
+ related_name = m2m.name
+ if not related_name:
+ continue
+ if "+" in related_name: # duplicated logic from related_objects; not sure if necessary
+ continue
+
+ rel_objs = getattr(obj, related_name)
+
+ if not hasattr(rel_objs, "all"):
+ continue
+
+ for rel_obj in rel_objs.all():
+ if not hasattr(p_obj, related_name + "_ids"):
+ continue
+ getattr(p_obj, related_name + "_ids").append(rel_obj.id)
+
# Generate a list of class names for the object. This includes its
# ancestors. Anything that is a descendant of XOSBase or User
# counts.
@@ -309,6 +329,56 @@
return args
+ def handle_m2m(self, djangoClass, message, update_fields):
+ # fix for possible django bug?
+ # Unless we refresh the object, django will ignore every other m2m save
+
+ #djangoClass = djangoClass.__class__.objects.get(id=djangoClass.id)
+ djangoClass.refresh_from_db()
+
+ fmap={}
+ for m2m in djangoClass._meta.many_to_many:
+ related_name = m2m.name
+ if not related_name:
+ continue
+ if "+" in related_name: # duplicated logic from related_objects; not sure if necessary
+ continue
+
+ fmap[m2m.name + "_ids"] = m2m
+
+ fields_changed = []
+ for (fieldDesc, val) in message.ListFields():
+ if fieldDesc.name in fmap:
+ m2m = getattr(djangoClass,fmap[fieldDesc.name].name)
+
+ # remove items that are in the django object, but not in the proto object
+ for item in list(m2m.all()):
+ if (not item.id in val):
+ m2m.remove(item.id)
+ fields_changed.append(fieldDesc.name)
+
+ # add items are are in the proto object, but not in the django object
+ django_ids = [x.id for x in m2m.all()]
+
+ for item in val:
+ if item not in django_ids:
+ m2m.add(item)
+ fields_changed.append(fieldDesc.name)
+
+ # gRPC doesn't give us a convenient way to differentiate between an empty list and an omitted list. So what
+ # we'll do is check and see if the user specified a fieldname in `update_fields`. If the user did, and that
+ # field is an m2m that we didn't encounter, then it must have been an empty list that the user wants
+ # to set.
+
+ for name in update_fields:
+ if (name in fmap) and (not name in fields_changed):
+ m2m = getattr(djangoClass, fmap[name].name)
+ m2m.clear()
+ fields_changed.append(name)
+
+ if fields_changed:
+ djangoClass.save()
+
def querysetToProto(self, djangoClass, queryset):
objs = queryset
p_objs = self.getPluralProtoClass(djangoClass)()
@@ -393,6 +463,9 @@
self.xos_security_gate(new_obj, user, write_access=True)
new_obj.save()
+
+ self.handle_m2m(new_obj, request, [])
+
return self.objToProto(new_obj)
def update(self, djangoClass, user, id, message, context):
@@ -405,16 +478,28 @@
for (k, v) in args.iteritems():
setattr(obj, k, v)
+ m2m_field_names = [x.name+"_ids" for x in djangoClass._meta.many_to_many]
+
+ update_fields = []
+ m2m_update_fields = []
save_kwargs = {}
for (k, v) in context.invocation_metadata():
if k == "update_fields":
- save_kwargs["update_fields"] = v.split(",")
+ for field_name in v.split(","):
+ if field_name in m2m_field_names:
+ m2m_update_fields.append(field_name)
+ else:
+ update_fields.append(field_name)
+ save_kwargs["update_fields"] = update_fields
elif k == "caller_kind":
save_kwargs["caller_kind"] = v
elif k == "always_update_timestamp":
save_kwargs["always_update_timestamp"] = True
obj.save(**save_kwargs)
+
+ self.handle_m2m(obj, message, m2m_update_fields)
+
return self.objToProto(obj)
def delete(self, djangoClass, user, id):
diff --git a/xos/coreapi/protos/xosoptions.proto b/xos/coreapi/protos/xosoptions.proto
index 7602c95..008f3de 100644
--- a/xos/coreapi/protos/xosoptions.proto
+++ b/xos/coreapi/protos/xosoptions.proto
@@ -18,10 +18,15 @@
string modelName = 1;
}
+message ManyToManyForeignKeyRule {
+ string modelName = 1;
+}
+
extend google.protobuf.FieldOptions {
ValRule val = 1001;
ForeignKeyRule foreignKey = 1002;
ReverseForeignKeyRule reverseForeignKey = 1003;
+ ManyToManyForeignKeyRule manyToManyForeignKey = 1004;
}
extend google.protobuf.MessageOptions {
diff --git a/xos/xos_client/tests/orm_nodelabel.py b/xos/xos_client/tests/orm_nodelabel.py
new file mode 100644
index 0000000..89dd025
--- /dev/null
+++ b/xos/xos_client/tests/orm_nodelabel.py
@@ -0,0 +1,163 @@
+
+# Copyright 2017-present Open Networking Foundation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# These are functional tests of ManyToMany relations. These tests need to be conducted end-to-end with a real
+# API to verify that the client and server ends of the API are working with each other.
+
+import random
+import string
+import sys
+import unittest
+
+orm = None
+
+from xosapi import xos_grpc_client
+
+TEST_NODE_LABEL_1_NAME = "test_node_label_1"
+
+class TestORM(unittest.TestCase):
+ def setUp(self):
+ self.test_node_label_1_name = TEST_NODE_LABEL_1_NAME + "_" + ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10))
+
+ nodes1 = orm.Node.objects.filter(name="test_node_1")
+ if nodes1:
+ self.node1=nodes1[0]
+ else:
+ self.node1 = orm.Node(name="test_node_1", site_deployment=orm.SiteDeployment.objects.first())
+ self.node1.save()
+
+ nodes2 = orm.Node.objects.filter(name="test_node_2")
+ if nodes2:
+ self.node2=nodes2[0]
+ else:
+ self.node2 = orm.Node(name="test_node_2", site_deployment=orm.SiteDeployment.objects.first())
+ self.node2.save()
+
+ def tearDown(self):
+ # TODO: Deleting NodeLabel seems to be broken -- appears to be a cascade failure
+ # attaching a nodelabel to a node causes deleting the node to also be broken.
+
+
+ #node_labels1 = orm.NodeLabel.objects.filter(name=self.test_node_label_1_name)
+ #for node_label in node_labels1:
+ # node_label.delete()
+
+ #nodes1 = orm.Node.objects.filter(name="test_node_1")
+ #for node in nodes1:
+ # node.delete()
+
+ #nodes2 = orm.Node.objects.filter(name="test_node_2")
+ #for node in nodes2:
+ # node.delete()
+
+ pass
+
+ def test_create_empty_node_label(self):
+ n = orm.NodeLabel(name = self.test_node_label_1_name)
+ n.save()
+
+ labels = orm.NodeLabel.objects.filter(name = self.test_node_label_1_name)
+ self.assertEqual(len(labels),1)
+
+ n=labels[0]
+ self.assertNotEqual(n, None)
+ self.assertEqual(len(n.node.all()), 0)
+
+ def test_create_node_label_one_node(self):
+ n = orm.NodeLabel(name = self.test_node_label_1_name)
+ n.node.add(self.node1)
+ n.save()
+
+ labels = orm.NodeLabel.objects.filter(name = self.test_node_label_1_name)
+ self.assertEqual(len(labels),1)
+
+ n=labels[0]
+ self.assertNotEqual(n, None)
+ self.assertEqual(len(n.node.all()), 1)
+
+ def test_create_node_label_two_nodes(self):
+ n = orm.NodeLabel(name = self.test_node_label_1_name)
+ n.node.add(self.node1)
+ n.node.add(self.node2)
+ n.save()
+
+ labels = orm.NodeLabel.objects.filter(name = self.test_node_label_1_name)
+ self.assertEqual(len(labels),1)
+
+ n=labels[0]
+ self.assertNotEqual(n, None)
+ self.assertEqual(len(n.node.all()), 2)
+
+ def test_add_node_to_label(self):
+ n = orm.NodeLabel(name = self.test_node_label_1_name)
+ n.save()
+
+ labels = orm.NodeLabel.objects.filter(name = self.test_node_label_1_name)
+ self.assertEqual(len(labels), 1)
+ n=labels[0]
+ n.node.add(self.node1)
+ n.save()
+
+ labels = orm.NodeLabel.objects.filter(name = self.test_node_label_1_name)
+ self.assertEqual(len(labels), 1)
+ n=labels[0]
+ self.assertEqual(len(n.node.all()), 1)
+
+ def test_remove_node_from_label(self):
+ n = orm.NodeLabel(name=self.test_node_label_1_name)
+ n.node.add(self.node1)
+ n.node.add(self.node2)
+ n.save()
+
+ labels = orm.NodeLabel.objects.filter(name=self.test_node_label_1_name)
+ self.assertEqual(len(labels), 1)
+ n = labels[0]
+ self.assertEqual(len(n.node.all()), 2)
+ n.node.remove(self.node1)
+ n.save()
+
+ labels = orm.NodeLabel.objects.filter(name=self.test_node_label_1_name)
+ self.assertEqual(len(labels), 1)
+ n = labels[0]
+ self.assertEqual(len(n.node.all()), 1)
+
+ def test_remove_last_node_from_label(self):
+ n = orm.NodeLabel(name=self.test_node_label_1_name)
+ n.node.add(self.node1)
+ n.save()
+
+ labels = orm.NodeLabel.objects.filter(name=self.test_node_label_1_name)
+ self.assertEqual(len(labels), 1)
+ n = labels[0]
+ self.assertEqual(len(n.node.all()), 1)
+ n.node.remove(self.node1)
+ n.save(update_fields=["node_ids"])
+
+ labels = orm.NodeLabel.objects.filter(name=self.test_node_label_1_name)
+ self.assertEqual(len(labels), 1)
+ n = labels[0]
+ self.assertEqual(len(n.node.all()), 0)
+
+
+def test_callback():
+ global orm
+
+ orm = xos_grpc_client.coreclient.xos_orm
+
+ sys.argv=sys.argv[:1] # unittest gets mad about the orm command line arguments
+ unittest.main()
+
+xos_grpc_client.start_api_parseargs(test_callback)
+
diff --git a/xos/xos_client/xosapi/orm.py b/xos/xos_client/xosapi/orm.py
index ed84f6e..77d3e7c 100644
--- a/xos/xos_client/xosapi/orm.py
+++ b/xos/xos_client/xosapi/orm.py
@@ -100,8 +100,14 @@
if name.endswith("_ids"):
reverseForeignKey = field.GetOptions().Extensions._FindExtensionByName("xos.reverseForeignKey")
fk = field.GetOptions().Extensions[reverseForeignKey]
- if fk:
- reverse_fkmap[name[:-4]] = {"src_fieldName": name, "modelName": fk.modelName}
+ if fk and fk.modelName:
+ reverse_fkmap[name[:-4]] = {"src_fieldName": name, "modelName": fk.modelName, "writeable": False}
+ else:
+ manyToManyForeignKey = field.GetOptions().Extensions._FindExtensionByName("xos.manyToManyForeignKey")
+ fk = field.GetOptions().Extensions[manyToManyForeignKey]
+ if fk and fk.modelName:
+ reverse_fkmap[name[:-4]] = {"src_fieldName": name, "modelName": fk.modelName, "writeable": True}
+
return reverse_fkmap
@@ -133,7 +139,7 @@
def reverse_fk_resolve(self, name):
if name not in self.reverse_cache:
fk_entry = self._reverse_fkmap[name]
- self.cache[name] = ORMLocalObjectManager(self.stub, fk_entry["modelName"], getattr(self, fk_entry["src_fieldName"]))
+ self.cache[name] = ORMLocalObjectManager(self.stub, fk_entry["modelName"], getattr(self, fk_entry["src_fieldName"]), fk_entry["writeable"])
return self.cache[name]
@@ -284,10 +290,11 @@
class ORMLocalObjectManager(object):
""" Manages a local list of objects """
- def __init__(self, stub, modelName, idList):
+ def __init__(self, stub, modelName, idList, writeable):
self._stub = stub
self._modelName = modelName
self._idList = idList
+ self._writeable = writeable
self._cache = None
def resolve_queryset(self):
@@ -319,6 +326,32 @@
else:
return None
+ def add(self, model):
+ if not self._writeable:
+ raise Exception("Only ManyToMany lists are writeable")
+
+ if isinstance(model, int):
+ id = model
+ else:
+ if not model.id:
+ raise Exception("Model %s has no id" % model)
+ id = model.id
+
+ self._idList.append(id)
+
+ def remove(self, model):
+ if not self._writeable:
+ raise Exception("Only ManyToMany lists are writeable")
+
+ if isinstance(model, int):
+ id = model
+ else:
+ if not model.id:
+ raise Exception("Model %s has no id" % model)
+ id = model.id
+
+ self._idList.remove(id)
+
class ORMObjectManager(object):
""" Manages a remote list of objects """