Use differential packet count in port_stats + cleanup
diff --git a/tests/port_stats.py b/tests/port_stats.py
index facef3e..087e725 100644
--- a/tests/port_stats.py
+++ b/tests/port_stats.py
@@ -86,6 +86,59 @@
     obj.assertEqual(str(pkt), str(rcv_pkt),
                     'Response packet does not match send packet')
 
+def getStats(obj, port):
+    """Request stats for `port` and return its (tx_packets, rx_packets)."""
+    stat_req = message.port_stats_request()
+    stat_req.port_no = port
+
+    fs_logger.info("Sending stats request")
+    response, pkt = obj.controller.transact(stat_req, timeout=2)
+    obj.assertTrue(response is not None, "No response to stats request")
+    obj.assertTrue(len(response.stats) == 1,
+                   "Did not receive port stats reply")
+    # The assertion above guarantees exactly one entry, so unpack it directly.
+    (item,) = response.stats
+    fs_logger.info("Sent " + str(item.tx_packets) + " packets")
+    packet_sent, packet_recv = item.tx_packets, item.rx_packets
+    fs_logger.info("Port %d stats count: tx %d rx %d" % (port, packet_sent, packet_recv))
+    return packet_sent, packet_recv
+
+def verifyStats(obj, port, test_timeout, packet_sent, packet_recv):
+    """Poll stats for `port` until tx/rx equal the expected counts, then assert.
+
+    Polls at most test_timeout times, calling sleep(1) between polls.  Both
+    counters are judged on the same (latest) sample; earlier matches don't count.
+    """
+    stat_req = message.port_stats_request()
+    stat_req.port_no = port
+
+    all_packets_received = all_packets_sent = False
+    sent = recv = 0
+    for i in range(0,test_timeout):
+        fs_logger.info("Sending stats request")
+        response, pkt = obj.controller.transact(stat_req,
+                                                timeout=test_timeout)
+        obj.assertTrue(response is not None, "No response to stats request")
+        obj.assertTrue(len(response.stats) == 1,
+                       "Did not receive port stats reply")
+        for item in response.stats:
+            sent, recv = item.tx_packets, item.rx_packets
+            fs_logger.info("Sent " + str(sent) + " packets")
+            fs_logger.info("Received " + str(recv) + " packets")
+        # Recompute from the latest sample so a stale match cannot latch.
+        all_packets_sent = sent == packet_sent
+        all_packets_received = recv == packet_recv
+        if all_packets_received and all_packets_sent:
+            break
+        sleep(1)
+
+    fs_logger.info("Expected port %d stats count: tx %d rx %d" % (port, packet_sent, packet_recv))
+    fs_logger.info("Actual port %d stats count: tx %d rx %d" % (port, sent, recv))
+    obj.assertTrue(all_packets_sent,
+                   "Packet sent does not match number sent")
+    obj.assertTrue(all_packets_received,
+                   "Packet received does not match number sent")
+
 class SingleFlowStats(basic.SimpleDataPlane):
     """
     Verify flow stats are properly retrieved.
@@ -97,37 +150,6 @@
     Verify that the packet counter has incremented
     """
 
-    def verifyStats(self, port, test_timeout, packet_sent, packet_recv):
-        stat_req = message.port_stats_request()
-        stat_req.port_no = port
-
-        all_packets_received = 0
-        all_packets_sent = 0
-        for i in range(0,test_timeout):
-            fs_logger.info("Sending stats request")
-            response, pkt = self.controller.transact(stat_req,
-                                                     timeout=test_timeout)
-            self.assertTrue(response is not None, 
-                            "No response to stats request")
-            self.assertTrue(len(response.stats) == 1,
-                            "Did not receive port stats reply")
-            for obj in response.stats:
-                fs_logger.info("Sent " + str(obj.tx_packets) + " packets")
-                if obj.tx_packets == packet_sent:
-                    all_packets_sent = 1
-                fs_logger.info("Received " + str(obj.rx_packets) + " packets")
-                if obj.rx_packets == packet_recv:
-                    all_packets_received = 1
-
-            if all_packets_received and all_packets_sent:
-                break
-            sleep(1)
-
-        self.assertTrue(all_packets_sent,
-                        "Packet sent does not match number sent")
-        self.assertTrue(all_packets_received,
-                        "Packet received does not match number sent")
-
     def runTest(self):
         global fs_port_map
 
@@ -170,9 +192,9 @@
         self.assertTrue(rv != -1, "Error installing flow mod")
         self.assertEqual(do_barrier(self.controller), 0, "Barrier failed")
 
-        # no packets sent, so zero packet count
-        self.verifyStats(ingress_port, test_timeout, 0, 0)
-        self.verifyStats(egress_port, test_timeout, 0, 0)
+        # get initial port stats count
+        initTxInPort, initRxInPort = getStats(self, ingress_port)
+        initTxOutPort, initRxOutPort = getStats(self, egress_port)
 
         # send packet N times
         num_sends = random.randint(10,20)
@@ -181,8 +203,8 @@
             sendPacket(self, pkt, ingress_port, egress_port,
                        test_timeout)
 
-        self.verifyStats(ingress_port, test_timeout, 0, num_sends)
-        self.verifyStats(egress_port, test_timeout, num_sends, 0)
+        verifyStats(self, ingress_port, test_timeout, initTxInPort, initRxInPort + num_sends)
+        verifyStats(self, egress_port, test_timeout, initTxOutPort + num_sends, initRxOutPort)
 
 
 class MultiFlowStats(basic.SimpleDataPlane):
@@ -217,37 +239,6 @@
 
         return flow_mod_msg
 
-    def verifyStats(self, port, test_timeout, packet_sent, packet_recv):
-        stat_req = message.port_stats_request()
-        stat_req.port_no = port
-
-        all_packets_received = 0
-        all_packets_sent = 0
-        for i in range(0,test_timeout):
-            fs_logger.info("Sending stats request")
-            response, pkt = self.controller.transact(stat_req,
-                                                     timeout=test_timeout)
-            self.assertTrue(response is not None,
-                            "No response to stats request")
-            self.assertTrue(len(response.stats) == 1,
-                            "Did not receive port stats reply")
-            for obj in response.stats:
-                fs_logger.info("Sent " + str(obj.tx_packets) + " packets")
-                if obj.tx_packets == packet_sent:
-                    all_packets_sent = 1
-                fs_logger.info("Received " + str(obj.rx_packets) + " packets")
-                if obj.rx_packets == packet_recv:
-                    all_packets_received = 1
-
-            if all_packets_received and all_packets_sent:
-                break
-            sleep(1)
-
-        self.assertTrue(all_packets_sent,
-                        "Packet sent does not match number sent")
-        self.assertTrue(all_packets_received,
-                        "Packet received does not match number sent")
-
     def runTest(self):
         global fs_port_map
 
@@ -278,6 +269,11 @@
         self.assertTrue(rv != -1, "Error installing flow mod")
         self.assertEqual(do_barrier(self.controller), 0, "Barrier failed")
 
+        # get initial port stats count
+        initTxInPort, initRxInPort = getStats(self, ingress_port)
+        initTxOutPort1, initRxOutPort1 = getStats(self, egress_port1)
+        initTxOutPort2, initRxOutPort2 = getStats(self, egress_port2)
+
         num_pkt1s = random.randint(10,30)
         fs_logger.info("Sending " + str(num_pkt1s) + " pkt1s")
         num_pkt2s = random.randint(10,30)
@@ -286,7 +282,10 @@
             sendPacket(self, pkt1, ingress_port, egress_port1, test_timeout)
         for i in range(0,num_pkt2s):
             sendPacket(self, pkt2, ingress_port, egress_port2, test_timeout)
-            
-        self.verifyStats(ingress_port, test_timeout, 0, num_pkt1s + num_pkt2s)
-        self.verifyStats(egress_port1, test_timeout, num_pkt1s, 0)
-        self.verifyStats(egress_port2, test_timeout, num_pkt2s, 0)
+
+        verifyStats(self, ingress_port, test_timeout,
+                    initTxInPort, initRxInPort + num_pkt1s + num_pkt2s)
+        verifyStats(self, egress_port1, test_timeout,
+                    initTxOutPort1 + num_pkt1s, initRxOutPort1)
+        verifyStats(self, egress_port2, test_timeout,
+                    initTxOutPort2 + num_pkt2s, initRxOutPort2)