CORD-2754 workaround for gRPC connectivity callback sometimes not returning failure

Change-Id: I28e84349b408e66775661e4403dc8167716777e1
diff --git a/grpc_client/grpc_client.py b/grpc_client/grpc_client.py
index 730887c..2528eae 100644
--- a/grpc_client/grpc_client.py
+++ b/grpc_client/grpc_client.py
@@ -25,6 +25,7 @@
 from random import randint
 from zlib import decompress
 
+import functools
 import grpc
 from consul import Consul
 from grpc._channel import _Rendezvous
@@ -87,7 +88,7 @@
         self.reconnect_callback = reconnect_callback
         return self
 
-    def connectivity_callback(self, connectivity):
+    def connectivity_callback(self, client, connectivity):
         if (self.was_connected) and (connectivity in [connectivity.TRANSIENT_FAILURE, connectivity.FATAL_FAILURE, connectivity.SHUTDOWN]):
             log.info("connectivity lost -- restarting")
             os.execv(sys.executable, ['python'] + sys.argv)
@@ -95,6 +96,13 @@
         if (connectivity == connectivity.READY):
             self.was_connected = True
 
+        # Sometimes gRPC transitions from READY to IDLE, skipping TRANSIENT_FAILURE even though a socket is
+        # disconnected. So on idle, force a connectivity check.
+        if (connectivity == connectivity.IDLE) and (self.was_connected):
+            connectivity = client.channel._channel.check_connectivity_state(True)
+            # The result will probably show IDLE, but passing in True has the side effect of reconnecting if the
+            # connection has been lost, which will trigger the TRANSIENT_FALURE we were looking for.
+
     @inlineCallbacks
     def connect(self):
         """
@@ -119,7 +127,8 @@
                 self.channel = grpc.insecure_channel(_endpoint)
 
             if self.restart_on_disconnect:
-                self.channel.subscribe(self.connectivity_callback)
+                connectivity_callback = functools.partial(self.connectivity_callback, self)
+                self.channel.subscribe(connectivity_callback)
 
             swagger_from = self._retrieve_schema()
             self._compile_proto_files(swagger_from)