Merge "Test : Adding CERT based TLS test cases."
diff --git a/src/test/utils/EapTLS.py b/src/test/utils/EapTLS.py
index b5aad78..04c2918 100644
--- a/src/test/utils/EapTLS.py
+++ b/src/test/utils/EapTLS.py
@@ -1,12 +1,12 @@
-# 
+#
 # Copyright 2016-present Ciena Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-# 
+#
 # http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -105,7 +105,7 @@
         if server_hello_done[-4:] == self.server_hello_done_signature:
             self.server_hello_done_received = True
 
-    def __init__(self, intf = 'veth0'):
+    def __init__(self, intf = 'veth0', client_cert = None, client_priv_key = None, fail_cb = None):
         self.fsmTable = tlsAuthHolder.initTlsAuthHolderFsmTable(self, self.tlsStateTable, self.tlsEventTable)
         EapolPacket.__init__(self, intf)
         CordTester.__init__(self, self.fsmTable, self.tlsStateTable.ST_EAP_TLS_DONE)
@@ -127,6 +127,10 @@
                          self.SERVER_UNKNOWN: ['', '', lambda pkt: pkt ]
                        }
         self.tls_ctx = TLSSessionCtx(client = True)
+        self.client_cert = self.CLIENT_CERT if client_cert is None else client_cert
+        self.client_priv_key = self.CLIENT_PRIV_KEY if client_priv_key is None else client_priv_key
+        self.failTest = False
+        self.fail_cb = fail_cb
 
     def load_tls_record(self, data, pkt_type = ''):
         #if pkt_type not in [ self.SERVER_HELLO_DONE, self.SERVER_UNKNOWN ]:
@@ -154,6 +158,11 @@
             self.pkt_map[pkt_type][self.HDR_IDX] = ''
             self.pkt_map[pkt_type][self.DATA_IDX] = ''
 
+    def tlsFail(self):
+        # Force a failure: queue the FINISHED event and set failTest so _eapTlsFinished takes its failure path
+        self.nextEvent = self.tlsEventTable.EVT_EAP_TLS_FINISHED
+        self.failTest = True
+
     def eapol_server_hello_cb(self, pkt):
         '''Reassemble and send response for server hello/certificate fragments'''
         r = str(pkt)
@@ -212,7 +221,7 @@
     def _eapSetup(self):
         self.setup()
         self.nextEvent = self.tlsEventTable.EVT_EAP_START
-        
+
     def _eapStart(self):
         self.eapol_start()
         self.nextEvent = self.tlsEventTable.EVT_EAP_ID_REQ
@@ -225,10 +234,14 @@
                 log.info("<====== Send EAP Response with identity = %s ================>" % USER)
                 self.eapol_id_req(pkt[EAP].id, USER)
 
-        self.eapol_scapy_recv(cb = eapol_cb,
-                              lfilter =
-                              lambda pkt: EAP in pkt and pkt[EAP].type == EAP.TYPE_ID and pkt[EAP].code == EAP.REQUEST)
-        self.nextEvent = self.tlsEventTable.EVT_EAP_TLS_HELLO_REQ
+        r = self.eapol_scapy_recv(cb = eapol_cb,
+                                  lfilter =
+                                  lambda pkt: EAP in pkt and pkt[EAP].type == EAP.TYPE_ID and pkt[EAP].code == EAP.REQUEST)
+        if len(r) > 0:
+            self.nextEvent = self.tlsEventTable.EVT_EAP_TLS_HELLO_REQ
+        else:
+            self.tlsFail()
+            return r
 
     def _eapTlsHelloReq(self):
 
@@ -249,14 +262,22 @@
                 eap_payload = self.eapTLS(EAP_RESPONSE, pkt[EAP].id, TLS_LENGTH_INCLUDED, str(reqdata))
                 self.eapol_send(EAPOL_EAPPACKET, eap_payload)
 
-        self.eapol_scapy_recv(cb = eapol_cb,
-                              lfilter =
-                              lambda pkt: EAP in pkt and pkt[EAP].type == EAP_TYPE_TLS and pkt[EAP].code == EAP.REQUEST)
-
-        for i in range(2):
-            self.eapol_scapy_recv(cb = self.eapol_server_hello_cb,
+        r = self.eapol_scapy_recv(cb = eapol_cb,
                                   lfilter =
                                   lambda pkt: EAP in pkt and pkt[EAP].type == EAP_TYPE_TLS and pkt[EAP].code == EAP.REQUEST)
+
+        if len(r) == 0:
+            self.tlsFail()
+            return r
+
+        for i in range(2):
+            r = self.eapol_scapy_recv(cb = self.eapol_server_hello_cb,
+                                      lfilter =
+                                      lambda pkt: EAP in pkt and pkt[EAP].type == EAP_TYPE_TLS and pkt[EAP].code == EAP.REQUEST)
+            if len(r) == 0:
+                self.tlsFail()
+                return r
+
         ##send cert request when we receive the last server hello fragment
         self.nextEvent = self.tlsEventTable.EVT_EAP_TLS_CERT_REQ
 
@@ -293,18 +314,22 @@
                 self.eapol_server_hello_cb(pkt)
                 assert self.server_hello_done_received == True
                 rex_pem = re.compile(r'\-+BEGIN[^\-]+\-+(.*?)\-+END[^\-]+\-+', re.DOTALL)
-                der_cert = rex_pem.findall(self.CLIENT_CERT)[0].decode("base64")
-                client_certificate_list = TLSHandshake()/TLSCertificateList(
-                    certificates=[TLSCertificate(data=x509.X509Cert(der_cert))])
+                if self.client_cert:
+                    der_cert = rex_pem.findall(self.client_cert)[0].decode("base64")
+                    client_certificate_list = TLSHandshake()/TLSCertificateList(
+                        certificates=[TLSCertificate(data=x509.X509Cert(der_cert))])
+                else:
+                    client_certificate_list = TLSHandshake()/TLSCertificateList(certificates=[])
                 client_certificate = TLSRecord(version="TLS_1_0")/client_certificate_list
                 kex_data = self.tls_ctx.get_client_kex_data()
                 client_key_ex_data = TLSHandshake()/kex_data
                 client_key_ex = TLSRecord()/client_key_ex_data
-                self.load_tls_record(str(client_certificate))
+                if self.client_cert:
+                    self.load_tls_record(str(client_certificate))
+                    self.pkt_history.append(str(client_certificate_list))
                 self.load_tls_record(str(client_key_ex))
-                self.pkt_history.append(str(client_certificate_list))
                 self.pkt_history.append(str(client_key_ex_data))
-                verify_signature = self.get_verify_signature(self.CLIENT_PRIV_KEY)
+                verify_signature = self.get_verify_signature(self.client_priv_key)
                 client_cert_verify = TLSHandshake(type=TLSHandshakeType.CERTIFICATE_VERIFY)/verify_signature
                 client_cert_record = TLSRecord(content_type=TLSContentType.HANDSHAKE)/client_cert_verify
                 self.pkt_history.append(str(client_cert_verify))
@@ -318,10 +343,14 @@
                 eap_payload = self.eapTLS(EAP_RESPONSE, pkt[EAP].id, TLS_LENGTH_INCLUDED, reqdata)
                 self.eapol_send(EAPOL_EAPPACKET, eap_payload)
 
-        self.eapol_scapy_recv(cb = eapol_cb,
-                              lfilter =
-                              lambda pkt: EAP in pkt and pkt[EAP].type == EAP_TYPE_TLS and pkt[EAP].code == EAP.REQUEST)
-        self.nextEvent = self.tlsEventTable.EVT_EAP_TLS_CHANGE_CIPHER_SPEC
+        r = self.eapol_scapy_recv(cb = eapol_cb,
+                                  lfilter =
+                                  lambda pkt: EAP in pkt and pkt[EAP].type == EAP_TYPE_TLS and pkt[EAP].code == EAP.REQUEST)
+        if len(r) > 0:
+            self.nextEvent = self.tlsEventTable.EVT_EAP_TLS_CHANGE_CIPHER_SPEC
+        else:
+            self.tlsFail()
+            return r
 
     def _eapTlsChangeCipherSpec(self):
         def eapol_cb(pkt):
@@ -333,18 +362,30 @@
             eap_payload = self.eapTLS(EAP_RESPONSE, pkt[EAP].id, 0, '')
             self.eapol_send(EAPOL_EAPPACKET, eap_payload)
 
-        self.eapol_scapy_recv(cb = eapol_cb,
-                              lfilter =
-                              lambda pkt: EAP in pkt and pkt[EAP].type == EAP_TYPE_TLS and pkt[EAP].code == EAP.REQUEST)
-        self.nextEvent = self.tlsEventTable.EVT_EAP_TLS_FINISHED
+        r = self.eapol_scapy_recv(cb = eapol_cb,
+                                  lfilter =
+                                  lambda pkt: EAP in pkt and pkt[EAP].type == EAP_TYPE_TLS and pkt[EAP].code == EAP.REQUEST)
+        if len(r) > 0:
+            self.nextEvent = self.tlsEventTable.EVT_EAP_TLS_FINISHED
+        else:
+            self.tlsFail()
+            return r
 
     def _eapTlsFinished(self):
 
         def eapol_cb(pkt):
             log.info('Server authentication successfull')
 
+        timeout = 5
+        if self.failTest is True:
+            if self.fail_cb is not None:
+                self.fail_cb()
+                return
+            timeout = None # No fail_cb registered: wait forever (no sniff timeout) so the testcase itself times out
+
         self.eapol_scapy_recv(cb = eapol_cb,
                               lfilter =
-                              lambda pkt: EAP in pkt and pkt[EAP].code == EAP.SUCCESS)
+                              lambda pkt: EAP in pkt and pkt[EAP].code == EAP.SUCCESS,
+                              timeout = timeout)
         self.eapol_logoff()
         self.nextEvent = None
diff --git a/src/test/utils/EapolAAA.py b/src/test/utils/EapolAAA.py
index a897833..0a2f8bd 100644
--- a/src/test/utils/EapolAAA.py
+++ b/src/test/utils/EapolAAA.py
@@ -1,12 +1,12 @@
-# 
+#
 # Copyright 2016-present Ciena Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-# 
+#
 # http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -47,7 +47,7 @@
 TLS_LENGTH_INCLUDED = 0x80
 
 class EapolPacket(object):
-    
+
     def __init__(self, intf = 'veth0'):
         self.intf = intf
         self.s = None
@@ -64,7 +64,7 @@
         if self.s is not None:
             self.s.close()
             self.s = None
-            
+
     def eapol(self, req_type, payload=""):
         return EAPOL(version = EAPOL_VERSION, type = req_type)/payload
 
@@ -92,11 +92,11 @@
         assert_equal(pkt_type, EAPOL_EAPPACKET)
         return p[4:]
 
-    def eapol_scapy_recv(self, cb = None, lfilter = None, count = 1):
+    def eapol_scapy_recv(self, cb = None, lfilter = None, count = 1, timeout = 5):
         def eapol_default_cb(pkt): pass
         if cb is None:
             cb = eapol_default_cb
-        sniff(prn = cb, lfilter = lfilter, count = count, opened_socket = self.recv_sock)
+        return sniff(prn = cb, lfilter = lfilter, count = count, timeout = timeout, opened_socket = self.recv_sock)
 
     def eapol_start(self):
         eap_payload = self.eap(EAPOL_START, 2)