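CORD-2359 eliminate "foreign key was poisoned" errors

Instead of poisoning a foreign key after it is set, cache the newly
assigned model and record post-save fixups. After save(), when this
object's id is known, the destination models' reverse foreign key id
lists are updated and their reverse caches are invalidated.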

Change-Id: I36a7523ffbd29a504269bc15bf9ffd418e2b2183
diff --git a/xos/xos_client/xosapi/orm.py b/xos/xos_client/xosapi/orm.py
index 77d3e7c..ae3278d 100644
--- a/xos/xos_client/xosapi/orm.py
+++ b/xos/xos_client/xosapi/orm.py
@@ -49,8 +49,8 @@
         super(ORMWrapper, self).__setattr__("cache", {})
         super(ORMWrapper, self).__setattr__("reverse_cache", {})
         super(ORMWrapper, self).__setattr__("synchronizer_step", None)
-        super(ORMWrapper, self).__setattr__("poisoned", {})
         super(ORMWrapper, self).__setattr__("is_new", is_new)
+        super(ORMWrapper, self).__setattr__("post_save_fixups", [])  # pending fk changes; do_post_save_fixups() applies them after save()
         fkmap=self.gen_fkmap()
         super(ORMWrapper, self).__setattr__("_fkmap", fkmap)
         reverse_fkmap=self.gen_reverse_fkmap()
@@ -80,7 +80,10 @@
                foreignKey = field.GetOptions().Extensions._FindExtensionByName("xos.foreignKey")
                fk = field.GetOptions().Extensions[foreignKey]
                if fk and fk.modelName:
-                   fkmap[name[:-3]] = {"src_fieldName": name, "modelName": fk.modelName, "kind": "fk"}
+                   fkdict = {"src_fieldName": name, "modelName": fk.modelName, "kind": "fk"}
+                   if fk.reverseFieldName:  # fk_set() only records post-save reverse fixups for fks that have this
+                       fkdict["reverse_fieldName"] = fk.reverseFieldName
+                   fkmap[name[:-3]] = fkdict
                else:
                    # If there's a corresponding _type_id field, then see if this
                    # is a generic foreign key.
@@ -113,7 +116,7 @@
 
     def fk_resolve(self, name):
         if name in self.cache:
-            return make_ORMWrapper(self.cache[name], self.stub)
+            return self.cache[name]  # fk_resolve wraps before caching, so its entries are returned as-is
 
         fk_entry = self._fkmap[name]
         fk_kind = fk_entry["kind"]
@@ -132,9 +135,10 @@
         else:
             raise Exception("unknown fk_kind")
 
+        dest_model = make_ORMWrapper(dest_model, self.stub)  # wrap once so the cached and returned objects are the same
         self.cache[name] = dest_model
 
-        return make_ORMWrapper(dest_model, self.stub)
+        return dest_model
 
     def reverse_fk_resolve(self, name):
         if name not in self.reverse_cache:
@@ -155,15 +159,50 @@
         if fk_kind=="generic_fk":
             setattr(self._wrapped_class, fk_entry["ct_fieldName"], model.self_content_type_id)
 
-        # XXX setting the cache here is a problematic, since the cached object's
-        # reverse foreign key pointers will not include the reference back
-        # to this object. Instead of setting the cache, let's poison the name
-        # and throw an exception if someone tries to get it.
+        reverse_fieldName = fk_entry.get("reverse_fieldName")
+
+        if name in self.cache:
+            old_model = self.cache.pop(name)
+            if reverse_fieldName:
+                # The old destination model loses its reverse reference to us. Record the change so that it can
+                # be applied after we save; dest_id is the id of the model being unlinked.
+                self.post_save_fixups.append({"src_fieldName": fk_entry["src_fieldName"],
+                                              "dest_id": old_model.id,
+                                              "dest_model": old_model,
+                                              "remove": True,
+                                              "reverse_fieldName": reverse_fieldName})
 
-        # To work around this, explicitly call reset_cache(fieldName) and
-        # the ORM will reload the object.
+        if model:
+            self.cache[name] = model
+            if reverse_fieldName:
+                # The new destination model gains a reverse reference to us. Record the change so that it can
+                # be applied after we save.
+                self.post_save_fixups.append({"src_fieldName": fk_entry["src_fieldName"],
+                                              "dest_id": id,
+                                              "dest_model": model,
+                                              "remove": False,
+                                              "reverse_fieldName": reverse_fieldName})
 
-        self.poisoned[name] = True
+    def do_post_save_fixups(self):
+        # Perform post-save foreign key fixups.
+        # Fix up the models that we've set a foreign key to so that their in-memory representation has the correct
+        # reverse foreign key back to us. We can only do this after a save, because self.id isn't known until
+        # after save.
+        # See unit test test_foreign_key_set_without_invalidate
+        for fixup in self.post_save_fixups:
+            model = fixup["dest_model"]
+            reverse_fieldName_ids = fixup["reverse_fieldName"] + "_ids"
+            if not hasattr(model, reverse_fieldName_ids):
+                continue
+            reverse_ids = getattr(model, reverse_fieldName_ids)
+            if fixup["remove"]:
+                if self.id in reverse_ids:
+                    reverse_ids.remove(self.id)
+            elif self.id not in reverse_ids:
+                reverse_ids.append(self.id)
+            model.invalidate_cache(fixup["reverse_fieldName"])
+        # Clear in place; __init__ created this list via super().__setattr__, so don't rebind it here.
+        del self.post_save_fixups[:]
 
     def __getattr__(self, name, *args, **kwargs):
         # note: getattr is only called for attributes that do not exist in
@@ -173,10 +212,6 @@
         if (name == "pk"):
             name = "id"
 
-        if name in self.poisoned.keys():
-            # see explanation in fk_set()
-            raise Exception("foreign key was poisoned")
-
         if name in self._fkmap.keys():
             return self.fk_resolve(name)
 
@@ -223,12 +258,9 @@
                 del self.cache[name]
             if name in self.reverse_cache:
                 del self.reverse_cache[name]
-            if name in self.poisoned:
-                del self.poisoned[name]
         else:
             self.cache.clear()
             self.reverse_cache.clear()
-            self.poisoned.clear()
 
     def save(self, update_fields=None, always_update_timestamp=False):
         if self.is_new:
@@ -242,6 +274,7 @@
            if always_update_timestamp:
                metadata.append( ("always_update_timestamp", "1") )
            self.stub.invoke("Update%s" % self._wrapped_class.__class__.__name__, self._wrapped_class, metadata=metadata)
+        self.do_post_save_fixups()  # after both create and update, so self.id is available for the reverse fk fixups
 
     def delete(self):
         id = self.stub.make_ID(id=self._wrapped_class.id)