CORD-2359 eliminate foreignkey is poisoned errors

Change-Id: I36a7523ffbd29a504269bc15bf9ffd418e2b2183
diff --git a/lib/xos-genx/xosgenx/jinja2_extensions/base.py b/lib/xos-genx/xosgenx/jinja2_extensions/base.py
index 0849ab9..5943e66 100644
--- a/lib/xos-genx/xosgenx/jinja2_extensions/base.py
+++ b/lib/xos-genx/xosgenx/jinja2_extensions/base.py
@@ -231,6 +231,8 @@
 
     if 'link' in field and 'model' in field['options']:
         options.append('(foreignKey).modelName = "%s"'%field['options']['model'])
+        if ("options" in field) and ("port" in field["options"]):
+            options.append('(foreignKey).reverseFieldName = "%s"' % field['options']['port'])
 
     if options:
         options_str = '[' + ', '.join(options) + ']'
diff --git a/xos/coreapi/protos/xosoptions.proto b/xos/coreapi/protos/xosoptions.proto
index 008f3de..0494ada 100644
--- a/xos/coreapi/protos/xosoptions.proto
+++ b/xos/coreapi/protos/xosoptions.proto
@@ -12,6 +12,7 @@
 
 message ForeignKeyRule {
   string modelName = 1;
+  string reverseFieldName = 2;
 }
 
 message ReverseForeignKeyRule {
diff --git a/xos/xos_client/xosapi/fake_stub.py b/xos/xos_client/xosapi/fake_stub.py
index edccf21..29a8962 100644
--- a/xos/xos_client/xosapi/fake_stub.py
+++ b/xos/xos_client/xosapi/fake_stub.py
@@ -71,8 +71,9 @@
         return default
 
 class FakeFieldOption(object):
-    def __init__(self, modelName=None):
+    def __init__(self, modelName=None, reverseFieldName=None):
         self.modelName = modelName
+        self.reverseFieldName = reverseFieldName
 
 class FakeField(object):
     def __init__(self, field):
@@ -80,7 +81,8 @@
 
         fk_model = field.get("fk_model", None)
         if fk_model:
-            extensions["xos.foreignKey"] = FakeFieldOption(modelName=fk_model)
+            reverseFieldName = field.get("fk_reverseFieldName", None)
+            extensions["xos.foreignKey"] = FakeFieldOption(modelName=fk_model, reverseFieldName=reverseFieldName)
 
         fk_reverse = field.get("fk_reverse", None)
         if fk_reverse:
@@ -150,7 +152,7 @@
 class Slice(FakeObj):
     FIELDS = ( {"name": "id", "default": 0},
                {"name": "name", "default": ""},
-               {"name": "site_id", "default": 0, "fk_model": "Site"},
+               {"name": "site_id", "default": 0, "fk_model": "Site", "fk_reverseFieldName": "slices"},
                {"name": "service_id", "default": 0, "fk_model": "Service"},
                {"name": "creator_id", "default": 0, "fk_model": "User"},
                {"name": "networks_ids", "default": [], "fk_reverse": "Network"},
@@ -165,7 +167,7 @@
 class Site(FakeObj):
     FIELDS = ( {"name": "id", "default": 0},
                {"name": "name", "default": ""},
-               {"name": "slice_ids", "default": [], "fk_reverse": "Slice"},
+               {"name": "slices_ids", "default": [], "fk_reverse": "Slice"},
                {"name": "leaf_model_name", "default": "Site"})
 
     def __init__(self, **kwargs):
diff --git a/xos/xos_client/xosapi/orm.py b/xos/xos_client/xosapi/orm.py
index 77d3e7c..ae3278d 100644
--- a/xos/xos_client/xosapi/orm.py
+++ b/xos/xos_client/xosapi/orm.py
@@ -49,8 +49,8 @@
         super(ORMWrapper, self).__setattr__("cache", {})
         super(ORMWrapper, self).__setattr__("reverse_cache", {})
         super(ORMWrapper, self).__setattr__("synchronizer_step", None)
-        super(ORMWrapper, self).__setattr__("poisoned", {})
         super(ORMWrapper, self).__setattr__("is_new", is_new)
+        super(ORMWrapper, self).__setattr__("post_save_fixups", [])
         fkmap=self.gen_fkmap()
         super(ORMWrapper, self).__setattr__("_fkmap", fkmap)
         reverse_fkmap=self.gen_reverse_fkmap()
@@ -80,7 +80,10 @@
                foreignKey = field.GetOptions().Extensions._FindExtensionByName("xos.foreignKey")
                fk = field.GetOptions().Extensions[foreignKey]
                if fk and fk.modelName:
-                   fkmap[name[:-3]] = {"src_fieldName": name, "modelName": fk.modelName, "kind": "fk"}
+                   fkdict = {"src_fieldName": name, "modelName": fk.modelName, "kind": "fk"}
+                   if fk.reverseFieldName:
+                       fkdict["reverse_fieldName"] = fk.reverseFieldName
+                   fkmap[name[:-3]] = fkdict
                else:
                    # If there's a corresponding _type_id field, then see if this
                    # is a generic foreign key.
@@ -113,7 +116,7 @@
 
     def fk_resolve(self, name):
         if name in self.cache:
-            return make_ORMWrapper(self.cache[name], self.stub)
+            return self.cache[name]
 
         fk_entry = self._fkmap[name]
         fk_kind = fk_entry["kind"]
@@ -132,9 +135,10 @@
         else:
             raise Exception("unknown fk_kind")
 
+        dest_model = make_ORMWrapper(dest_model, self.stub)
         self.cache[name] = dest_model
 
-        return make_ORMWrapper(dest_model, self.stub)
+        return dest_model
 
     def reverse_fk_resolve(self, name):
         if name not in self.reverse_cache:
@@ -155,15 +159,50 @@
         if fk_kind=="generic_fk":
             setattr(self._wrapped_class, fk_entry["ct_fieldName"], model.self_content_type_id)
 
-        # XXX setting the cache here is a problematic, since the cached object's
-        # reverse foreign key pointers will not include the reference back
-        # to this object. Instead of setting the cache, let's poison the name
-        # and throw an exception if someone tries to get it.
+        if name in self.cache:
+            old_model = self.cache[name]
+            if fk_entry.get("reverse_fieldName"):
+                # Note this fk change so that we can update the destination model after we save.
+                self.post_save_fixups.append({"src_fieldName": fk_entry["src_fieldName"],
+                                              "dest_id": id,
+                                              "dest_model": old_model,
+                                              "remove": True,
+                                              "reverse_fieldName": fk_entry.get("reverse_fieldName")})
+            del self.cache[name]
 
-        # To work around this, explicitly call reset_cache(fieldName) and
-        # the ORM will reload the object.
+        if model:
+            self.cache[name] = model
+            if fk_entry.get("reverse_fieldName"):
+                # Note this fk change so that we can update the destination model after we save.
+                self.post_save_fixups.append({"src_fieldName": fk_entry["src_fieldName"],
+                                              "dest_id": id,
+                                              "dest_model": model,
+                                              "remove": False,
+                                              "reverse_fieldName": fk_entry.get("reverse_fieldName")})
+        elif name in self.cache:
+            del self.cache[name]
 
-        self.poisoned[name] = True
+    def do_post_save_fixups(self):
+        # Perform post-save foreign key fixups.
+        # Fixup the models that we've set a foreign key to so that their in-memory representation has the correct
+        # reverse foreign key back to us. We can only do this after a save, because self.id isn't known until
+        # after save.
+        # See unit test test_foreign_key_set_without_invalidate
+        for fixup in self.post_save_fixups:
+            model = fixup["dest_model"]
+            reverse_fieldName_ids = fixup["reverse_fieldName"] + "_ids"
+            if not hasattr(model, reverse_fieldName_ids):
+                continue
+            if fixup["remove"]:
+                reverse_ids = getattr(model, reverse_fieldName_ids)
+                if self.id in reverse_ids:
+                    reverse_ids.remove(self.id)
+            else:
+                reverse_ids = getattr(model, reverse_fieldName_ids)
+                if self.id not in reverse_ids:
+                    reverse_ids.append(self.id)
+            model.invalidate_cache(fixup["reverse_fieldName"])
+        self.post_save_fixups = []
 
     def __getattr__(self, name, *args, **kwargs):
         # note: getattr is only called for attributes that do not exist in
@@ -173,10 +212,6 @@
         if (name == "pk"):
             name = "id"
 
-        if name in self.poisoned.keys():
-            # see explanation in fk_set()
-            raise Exception("foreign key was poisoned")
-
         if name in self._fkmap.keys():
             return self.fk_resolve(name)
 
@@ -223,12 +258,9 @@
                 del self.cache[name]
             if name in self.reverse_cache:
                 del self.reverse_cache[name]
-            if name in self.poisoned:
-                del self.poisoned[name]
         else:
             self.cache.clear()
             self.reverse_cache.clear()
-            self.poisoned.clear()
 
     def save(self, update_fields=None, always_update_timestamp=False):
         if self.is_new:
@@ -242,6 +274,7 @@
            if always_update_timestamp:
                metadata.append( ("always_update_timestamp", "1") )
            self.stub.invoke("Update%s" % self._wrapped_class.__class__.__name__, self._wrapped_class, metadata=metadata)
+        self.do_post_save_fixups()
 
     def delete(self):
         id = self.stub.make_ID(id=self._wrapped_class.id)
diff --git a/xos/xos_client/xosapi/test_orm.py b/xos/xos_client/xosapi/test_orm.py
index 17b6636..7b46308 100644
--- a/xos/xos_client/xosapi/test_orm.py
+++ b/xos/xos_client/xosapi/test_orm.py
@@ -143,7 +143,7 @@
         self.assertNotEqual(slice.site, None)
         self.assertEqual(slice.site.id, site.id)
 
-    def test_foreign_key_set(self):
+    def test_foreign_key_set_with_invalidate(self):
         orm = self.make_coreapi()
         site = orm.Site(name="mysite")
         site.save()
@@ -157,6 +157,124 @@
         self.assertTrue(slice.id > 0)
         self.assertNotEqual(slice.site, None)
         self.assertEqual(slice.site.id, site.id)
+        if not USE_FAKE_STUB:
+            self.assertTrue(slice.id in slice.site.slices_ids)
+
+    def test_foreign_key_set_without_invalidate(self):
+        orm = self.make_coreapi()
+        site = orm.Site(name="mysite")
+        site.save()
+        self.assertTrue(site.id > 0)
+        user = orm.User(email="fake_" + ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10)), site_id=site.id)
+        user.save()
+        self.assertTrue(user.id > 0)
+        slice = orm.Slice(name="mysite_foo", site = site, creator_id=user.id)
+        slice.save()
+        self.assertTrue(slice.id > 0)
+        self.assertNotEqual(slice.site, None)
+        self.assertEqual(slice.site.id, site.id)
+        if not USE_FAKE_STUB:
+            self.assertTrue(slice.id in slice.site.slices_ids)
+            ids_from_models = [x.id for x in slice.site.slices.all()]
+            self.assertTrue(slice.id in ids_from_models)
+
+    def test_foreign_key_reset(self):
+        orm = self.make_coreapi()
+        site = orm.Site(name="mysite")
+        site.save()
+        self.assertTrue(site.id > 0)
+        user = orm.User(email="fake_" + ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10)), site_id=site.id)
+        user.save()
+        self.assertTrue(user.id > 0)
+        slice = orm.Slice(name="mysite_foo", site = site, creator_id=user.id)
+        slice.save()
+        self.assertTrue(slice.id > 0)
+        self.assertNotEqual(slice.site, None)
+        self.assertEqual(slice.site.id, site.id)
+        if not USE_FAKE_STUB:
+            self.assertTrue(slice.id in site.slices_ids)
+            self.assertTrue(slice.id in slice.site.slices_ids)
+
+        site2 = orm.Site(name="mysite2")
+        site2.save()
+        slice.name = "mysite2_foo"
+        slice.site = site2
+        slice.save()
+        self.assertNotEqual(slice.site, None)
+        self.assertEqual(slice.site.id, site2.id)
+        if not USE_FAKE_STUB:
+            self.assertTrue(slice.id not in site.slices_ids)
+            self.assertTrue(slice.id in site2.slices_ids)
+            self.assertTrue(slice.id in slice.site.slices_ids)
+            ids_from_models1 = [x.id for x in site.slices.all()]
+            self.assertTrue(slice.id not in ids_from_models1)
+            ids_from_models2 = [x.id for x in site2.slices.all()]
+            self.assertTrue(slice.id in ids_from_models2)
+
+    def test_foreign_key_back_and_forth_even(self):
+        orm = self.make_coreapi()
+        site = orm.Site(name="mysite")
+        site.save()
+        self.assertTrue(site.id > 0)
+        user = orm.User(email="fake_" + ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10)), site_id=site.id)
+        user.save()
+        self.assertTrue(user.id > 0)
+        slice = orm.Slice(name="mysite_foo", site = site, creator_id=user.id)
+        slice.save()
+        self.assertTrue(slice.id > 0)
+        self.assertNotEqual(slice.site, None)
+        self.assertEqual(slice.site.id, site.id)
+        if not USE_FAKE_STUB:
+            self.assertTrue(slice.id in site.slices_ids)
+            self.assertTrue(slice.id in slice.site.slices_ids)
+
+        site2 = orm.Site(name="mysite2")
+        site2.save()
+        slice.name = "mysite2_foo"
+        slice.site = site2
+        slice.site = site
+        slice.site = site2
+        slice.site = site
+        slice.save()
+        self.assertNotEqual(slice.site, None)
+        self.assertEqual(slice.site.id, site.id)
+        if not USE_FAKE_STUB:
+            self.assertTrue(slice.id not in site2.slices_ids)
+            self.assertTrue(slice.id in site.slices_ids)
+            self.assertTrue(slice.id in slice.site.slices_ids)
+
+    def test_foreign_key_back_and_forth_odd(self):
+        orm = self.make_coreapi()
+        site = orm.Site(name="mysite")
+        site.save()
+        self.assertTrue(site.id > 0)
+        user = orm.User(email="fake_" + ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10)), site_id=site.id)
+        user.save()
+        self.assertTrue(user.id > 0)
+        slice = orm.Slice(name="mysite_foo", site = site, creator_id=user.id)
+        slice.save()
+        self.assertTrue(slice.id > 0)
+        self.assertNotEqual(slice.site, None)
+        self.assertEqual(slice.site.id, site.id)
+        if not USE_FAKE_STUB:
+            self.assertTrue(slice.id in site.slices_ids)
+            self.assertTrue(slice.id in slice.site.slices_ids)
+
+        site2 = orm.Site(name="mysite2")
+        site2.save()
+        slice.name = "mysite2_foo"
+        slice.site = site2
+        slice.site = site
+        slice.site = site2
+        slice.site = site
+        slice.site = site2
+        slice.save()
+        self.assertNotEqual(slice.site, None)
+        self.assertEqual(slice.site.id, site2.id)
+        if not USE_FAKE_STUB:
+            self.assertTrue(slice.id not in site.slices_ids)
+            self.assertTrue(slice.id in site2.slices_ids)
+            self.assertTrue(slice.id in slice.site.slices_ids)
 
     def test_foreign_key_create_null(self):
         orm = self.make_coreapi()