Merge branch 'kenc'
diff --git a/src/python/oftest/controller.py b/src/python/oftest/controller.py
index 29a605e..cfe2eda 100644
--- a/src/python/oftest/controller.py
+++ b/src/python/oftest/controller.py
@@ -218,8 +218,7 @@
                                   % (hdr.version, OFP_VERSION))
                 print "Version %d does not match OFTest version %d" % \
                     (hdr.version, OFP_VERSION)
-                self.active = False
-                self.switch_socket = None
+                self.disconnect()
                 return
 
             msg = of_message_parse(rawmsg)
@@ -290,7 +289,11 @@
                 sock.close()
                 return 0
 
-            (sock, addr) = self.listen_socket.accept()
+            try:
+                (sock, addr) = self.listen_socket.accept()
+            except:
+                self.logger.warning("Error on listen socket accept")
+                return -1
             self.socs.append(sock)
             self.logger.info("Incoming connection from %s" % str(addr))
 
@@ -357,25 +360,23 @@
         self.logger.info("Waiting for switch connection")
         self.socs = [self.listen_socket]
         self.dbg_state = "running"
+
         while self.active:
             try:
                 sel_in, sel_out, sel_err = \
                     select.select(self.socs, [], self.socs, 1)
             except:
                 print sys.exc_info()
-                self.logger.error("Select error, exiting")
-                self.active = False
-                break
+                self.logger.error("Select error, disconnecting")
+                self.disconnect()
 
             for s in sel_err:
-                self.logger.error("Got socket error on: " + str(s))
-                self.active = False
-                break
+                self.logger.error("Got socket error on: " + str(s) + ", disconnecting")
+                self.disconnect()
 
             for s in sel_in:
                 if self._socket_ready_handle(s) == -1:
-                    self.active = False
-                    break
+                    self.disconnect()
 
         # End of main loop
         self.dbg_state = "closing"
@@ -394,6 +395,30 @@
             timed_wait(self.connect_cv, lambda: self.switch_socket, timeout=timeout)
         return self.switch_socket is not None
         
+    def disconnect(self, timeout=-1):
+        """
+        If connected to a switch, disconnect. The timeout argument is unused.
+        """
+        sock = self.switch_socket
+        if sock:
+            if sock in self.socs:  # may be gone already if two calls race
+                self.socs.remove(sock)
+            sock.close()
+            self.switch_socket = self.switch_addr = None
+            with self.connect_cv:
+                self.connect_cv.notifyAll()  # wake threads waiting on the cv
+
+    def wait_disconnected(self, timeout=-1):
+        """
+        @param timeout Block for up to timeout seconds. Pass -1 for the default.
+        @return Boolean, True if disconnected
+        """
+        with self.connect_cv:
+            timed_wait(self.connect_cv,
+                       lambda: True if not self.switch_socket else None,
+                       timeout=timeout)
+        return self.switch_socket is None
+
     def kill(self):
         """
         Force the controller thread to quit
diff --git a/tests/cxn.py b/tests/cxn.py
index da15c9e..5524434 100644
--- a/tests/cxn.py
+++ b/tests/cxn.py
@@ -25,53 +25,32 @@
     """
 
     priority = -1
+    controllers = []
+    default_timeout = 2
 
     def controllerSetup(self, host, port):
-        self.controller = controller.Controller(host=host,port=port)
+        con = controller.Controller(host=host,port=port)
 
         # clean_shutdown should be set to False to force quit app
         self.clean_shutdown = True
         # disable initial hello so hello is under control of test
-        self.controller.initial_hello = False
+        con.initial_hello = False
 
-        self.controller.start()
-        #@todo Add an option to wait for a pkt transaction to ensure version
-        # compatibilty?
-        self.controller.connect(timeout=10)
-        self.assertTrue(self.controller.active,
-                        "Controller startup failed, not active")
-        self.assertTrue(self.controller.switch_addr is not None,
-                        "Controller startup failed, no switch addr")
+        con.start()
+        self.controllers.append(con)
 
     def setUp(self):
         logging.info("** START TEST CASE " + str(self))
 
-        self.test_timeout = test_param_get('handshake_timeout', default=60)
+        self.controllers = []  # per-test list; the class attribute is shared
+        self.default_timeout = test_param_get('default_timeout', default=2)
 
-    def inheritSetup(self, parent):
-        """
-        Inherit the setup of a parent
-
-        This allows running at test from within another test.  Do the
-        following:
-
-        sub_test = SomeTestClass()  # Create an instance of the test class
-        sub_test.inheritSetup(self) # Inherit setup of parent
-        sub_test.runTest()          # Run the test
-
-        Normally, only the parent's setUp and tearDown are called and
-        the state after the sub_test is run must be taken into account
-        by subsequent operations.
-        """
-        logging.info("** Setup " + str(self) + 
-                                    " inheriting from " + str(parent))
-        self.controller = parent.controller
-        
     def tearDown(self):
         logging.info("** END TEST CASE " + str(self))
-        self.controller.shutdown()
-        if self.clean_shutdown:
-            self.controller.join()
+        for con in self.controllers:
+            con.shutdown()
+            if self.clean_shutdown:
+                con.join()
 
     def runTest(self):
         # do nothing in the base case
@@ -90,18 +69,15 @@
     def runTest(self):
         self.controllerSetup(config["controller_host"],
                              config["controller_port"])
+        self.assertTrue(self.controllers[0].connect(self.default_timeout))
 
         logging.info("TCP Connected " + 
-                        str(self.controller.switch_addr))
+                     str(self.controllers[0].switch_addr))
         logging.info("Hello not sent, waiting for timeout")
 
         # wait for controller to die
-        count = 0
-        while self.controller.active and count < self.test_timeout:
-            time.sleep(1)
-            count = count + 1
-        self.assertTrue(not self.controller.active, 
-                        "Expected controller disconnect, but still active")
+        self.assertTrue(self.controllers[0].wait_disconnected(timeout=10),
+                        "Not notified of controller disconnect")
 
 class HandshakeNoFeaturesRequest(BaseHandshake):
     """
@@ -111,21 +87,18 @@
     def runTest(self):
         self.controllerSetup(config["controller_host"],
                              config["controller_port"])
+        self.assertTrue(self.controllers[0].connect(self.default_timeout))
 
         logging.info("TCP Connected " + 
-                                    str(self.controller.switch_addr))
+                     str(self.controllers[0].switch_addr))
         logging.info("Sending hello")
-        self.controller.message_send(message.hello())
+        self.controllers[0].message_send(message.hello())
 
         logging.info("Features request not sent, waiting for timeout")
 
         # wait for controller to die
-        count = 0
-        while self.controller.active and count < self.test_timeout:
-            time.sleep(1)
-            count = count + 1
-        self.assertTrue(not self.controller.active, 
-                        "Expected controller disconnect, but still active")
+        self.assertTrue(self.controllers[0].wait_disconnected(timeout=10),
+                        "Not notified of controller disconnect")
 
 class HandshakeAndKeepalive(BaseHandshake):
     """
@@ -136,27 +109,58 @@
     priority = -1
 
     def runTest(self):
-        self.controllerSetup(config["controller_host"],
-                             config["controller_port"])
+        self.num_controllers = test_param_get('num_controllers', default=1)
+        self.controller_timeout = test_param_get('controller_timeout',
+                                                 default=-1)
 
-        logging.info("TCP Connected " + 
-                                    str(self.controller.switch_addr))
-        logging.info("Sending hello")
-        self.controller.message_send(message.hello())
+        for i in range(self.num_controllers):
+            self.controllerSetup(config["controller_host"],
+                                 config["controller_port"]+i)
+        for con in self.controllers:
+            con.handshake_done = False
 
-        request = message.features_request()
-        reply, pkt = self.controller.transact(request, timeout=20)
-        self.assertTrue(reply is not None,
-                        "Did not complete features_request for handshake")
-        logging.info("Handshake complete with " + 
-                        str(self.controller.switch_addr))
-
-        self.controller.keep_alive = True
-
-        # keep controller up forever
-        while self.controller.active:
-            time.sleep(1)
-
-        self.assertTrue(not self.controller.active, 
-                        "Expected controller disconnect, but still active")
+        # try to maintain switch connections for specified timeout
+        # -1 means forever
+        while True:
+            for con in self.controllers:
+                if con.switch_socket and con.handshake_done:
+                    if (self.controller_timeout < 0 or
+                        con.count < self.controller_timeout):
+                        logging.info(con.host + ":" + str(con.port) + 
+                                     ": maintaining connection to " +
+                                     str(con.switch_addr))
+                        con.count = con.count + 1
+                    else:
+                        logging.info(con.host + ":" + str(con.port) + 
+                                     ": disconnecting from " +
+                                     str(con.switch_addr))
+                        con.disconnect()
+                        con.handshake_done = False
+                        con.count = 0
+                    time.sleep(1)
+                else:
+                    #@todo Add an option to wait for a pkt transaction to 
+                    # ensure version compatibility?
+                    con.connect(self.default_timeout)
+                    if not con.switch_socket:
+                        logging.info("Did not connect to switch")
+                        continue
+                    logging.info("TCP Connected " + str(con.switch_addr))
+                    logging.info("Sending hello")
+                    con.message_send(message.hello())
+                    request = message.features_request()
+                    reply, pkt = con.transact(request, 
+                                              timeout=self.default_timeout)
+                    if reply:
+                        logging.info("Handshake complete with " + 
+                                    str(con.switch_addr))
+                        con.handshake_done = True
+                        con.keep_alive = True
+                        con.count = 0
+                    else:
+                        logging.info("Did not complete features_request " +
+                                     "for handshake")
+                        con.disconnect()
+                        con.handshake_done = False
+                        con.count = 0