Add test classes, a test builder with state machine generators, and test cases for EAP-TLS
diff --git a/src/test/utils/CordTestBase.py b/src/test/utils/CordTestBase.py
new file mode 100644
index 0000000..6444c5a
--- /dev/null
+++ b/src/test/utils/CordTestBase.py
@@ -0,0 +1,24 @@
class CordTester(object):
    """Drive a table-driven finite state machine until it reaches a stop state.

    ``fsmTable`` maps ``(state, event)`` tuples to ``(actions, nextState)``
    pairs, where ``actions`` is an iterable of zero-argument callables (or a
    falsy value meaning "no actions").  After the actions of a transition run,
    ``self.nextEvent`` becomes the current event, so actions are expected to
    set it; setting it to ``None`` ends the run.
    """

    def __init__(self, fsmTable, stopState, stateTable = None, eventTable = None):
        self.fsmTable = fsmTable
        self.stopState = stopState
        # Optional enumerations, used only to pretty-print states and events.
        self.stateTable = stateTable
        self.eventTable = eventTable
        self.currentState = None
        self.currentEvent = None
        self.nextState = None
        self.nextEvent = None

    def runTest(self):
        """Run transitions until ``stopState`` is reached or no event is pending.

        Raises:
            KeyError: if ``fsmTable`` has no entry for the current
                ``(state, event)`` pair.
        """
        while self.currentState != self.stopState and self.currentEvent is not None:
            if self.stateTable and self.eventTable:
                print('Current state: %s, Current event: %s' % (
                    self.stateTable.toStr(self.currentState),
                    self.eventTable.toStr(self.currentEvent)))
            key = (self.currentState, self.currentEvent)
            try:
                actions, nextState = self.fsmTable[key]
            except KeyError:
                # Same exception type as before, but say which transition is missing.
                raise KeyError('No FSM transition defined for (state, event) = %r' % (key,))
            for action in actions or ():
                action()
            self.currentState = nextState
            self.currentEvent = self.nextEvent
diff --git a/src/test/utils/EapTLS.py b/src/test/utils/EapTLS.py
new file mode 100644
index 0000000..575fb20
--- /dev/null
+++ b/src/test/utils/EapTLS.py
@@ -0,0 +1,92 @@
+import sys, os
+cord_root = os.getenv('CORD_TEST_ROOT') or './'
+CORD_TEST_FSM = 'fsm'
+sys.path.append(cord_root + CORD_TEST_FSM)
+from EapolAAA import *
+from enum import *
+import noseTlsAuthHolder as tlsAuthHolder
+from scapy_ssl_tls.ssl_tls import *
+from socket import *
+from struct import *
+import scapy
+from nose.tools import *
+from CordTestBase import CordTester
+
class TLSAuthTest(EapolPacket, CordTester):
    """EAP-TLS authentication test driven by the CordTester state machine.

    The (state, event) -> (actions, nextState) table is built by
    ``noseTlsAuthHolder.initTlsAuthHolderFsmTable`` from ``self``; the
    ``_eap*`` methods below are the protocol steps (presumably wired up as the
    FSM actions by that table -- verify there).  Each step sets
    ``self.nextEvent`` to the event that drives the next transition, and the
    last step sets it to ``None`` to end the run.
    """

    tlsStateTable = Enumeration("TLSStateTable", ("ST_EAP_SETUP",
                                                  "ST_EAP_START",
                                                  "ST_EAP_ID_REQ",
                                                  "ST_EAP_TLS_HELLO_REQ",
                                                  "ST_EAP_TLS_CERT_REQ",
                                                  "ST_EAP_TLS_DONE"))
    tlsEventTable = Enumeration("TLSEventTable", ("EVT_EAP_SETUP",
                                                  "EVT_EAP_START",
                                                  "EVT_EAP_ID_REQ",
                                                  "EVT_EAP_TLS_HELLO_REQ",
                                                  "EVT_EAP_TLS_CERT_REQ",
                                                  "EVT_EAP_TLS_DONE"))

    def __init__(self, intf = 'veth0'):
        """Build the FSM table and start in ST_EAP_SETUP with EVT_EAP_SETUP pending.

        Args:
            intf: interface the EAPOL raw socket is bound to by ``setup()``.
        """
        self.fsmTable = tlsAuthHolder.initTlsAuthHolderFsmTable(self, self.tlsStateTable, self.tlsEventTable)
        EapolPacket.__init__(self, intf)
        # The state/event tables are not passed on, so CordTester.runTest()
        # does not print a trace line per transition.
        # CordTester.__init__ also resets nextState/nextEvent to None.
        CordTester.__init__(self, self.fsmTable, self.tlsStateTable.ST_EAP_TLS_DONE)
        self.currentState = self.tlsStateTable.ST_EAP_SETUP
        self.currentEvent = self.tlsEventTable.EVT_EAP_SETUP

    def _recvEapRequest(self, expectedType):
        """Receive one EAP packet, check it is a Request of ``expectedType``.

        Returns:
            The EAP identifier, to be echoed back in the response.
        """
        p = self.eapol_recv()
        code, pkt_id, eaplen = unpack("!BBH", p[0:4])
        print("Code %d, id %d, len %d" % (code, pkt_id, eaplen))
        assert_equal(code, EAP_REQUEST)
        reqtype = unpack("!B", p[4:5])[0]
        assert_equal(reqtype, expectedType)
        return pkt_id

    def _eapSetup(self):
        """Open the EAPOL socket; next event: EVT_EAP_START."""
        print('Inside EAP Setup')
        self.setup()
        self.nextEvent = self.tlsEventTable.EVT_EAP_START

    def _eapStart(self):
        """Send EAPOL-Start; next event: EVT_EAP_ID_REQ."""
        print('Inside EAP Start')
        self.eapol_start()
        self.nextEvent = self.tlsEventTable.EVT_EAP_ID_REQ

    def _eapIdReq(self):
        """Answer the EAP Identity request with USER; next event: EVT_EAP_TLS_HELLO_REQ."""
        print('Inside EAP ID Req')
        pkt_id = self._recvEapRequest(EAP_TYPE_ID)
        print("<====== Send EAP Response with identity = %s ================>" % USER)
        self.eapol_id_req(pkt_id, USER)
        self.nextEvent = self.tlsEventTable.EVT_EAP_TLS_HELLO_REQ

    def _eapTlsHelloReq(self):
        """Answer the EAP-TLS request with a TLS 1.0 ClientHello; next event: EVT_EAP_TLS_CERT_REQ."""
        print('Inside EAP TLS Hello Req')
        pkt_id = self._recvEapRequest(EAP_TYPE_TLS)
        clientHello = TLSRecord(version="TLS_1_0")/TLSHandshake()/TLSClientHello(version="TLS_1_0",
                                                                               gmt_unix_time=1234,
                                                                               random_bytes="A" * 28,
                                                                               session_id='',
                                                                               # A real one-element list; bare
                                                                               # parentheses do not make a tuple.
                                                                               compression_methods=[TLSCompressionMethod.NULL],
                                                                               cipher_suites=[TLSCipherSuite.RSA_WITH_AES_128_CBC_SHA])
        print("------> Sending Client Hello TLS payload of len %d ----------->" % len(clientHello))
        eap_payload = self.eapTLS(EAP_RESPONSE, pkt_id, TLS_LENGTH_INCLUDED, str(clientHello))
        self.eapol_send(EAPOL_EAPPACKET, eap_payload)
        self.nextEvent = self.tlsEventTable.EVT_EAP_TLS_CERT_REQ

    def _eapTlsCertReq(self):
        """Receive the next EAP packet and end the FSM run.

        The payload is only logged by length; it is not parsed or validated.
        """
        print('Inside EAP TLS Cert Req')
        p = self.eapol_recv()
        print('Got TLS Cert Req with payload len: %d' % len(p))
        self.nextEvent = None
diff --git a/src/test/utils/EapolAAA.py b/src/test/utils/EapolAAA.py
new file mode 100644
index 0000000..ea26164
--- /dev/null
+++ b/src/test/utils/EapolAAA.py
@@ -0,0 +1,92 @@
+#### Authentication parameters
+from socket import *
+from struct import *
+import scapy
+import sys
+from nose.tools import assert_equal, assert_not_equal, assert_raises, assert_true
+
# --- RADIUS credentials used by the authentication tests --------------------
USER = "raduser"
PASS = "radpass"
WRONG_USER = "XXXX"
WRONG_PASS = "XXXX"
NO_USER = ""
NO_PASS = ""

# Interface name; not referenced elsewhere in this module.
DEV = "tap0"

# --- Ethernet ----------------------------------------------------------------
ETHERTYPE_PAE = 0x888e  # EtherType of EAPOL frames (also the raw-socket protocol)
# Destination MAC for outgoing EAPOL frames.
# NOTE(review): this is the all-ones broadcast address, not the IEEE 802.1X
# PAE group address 01:80:C2:00:00:03 that the name suggests -- confirm intended.
PAE_GROUP_ADDR = "\xff\xff\xff\xff\xff\xff"

# --- EAPOL header: protocol version and packet types -------------------------
EAPOL_VERSION = 1
EAPOL_EAPPACKET, EAPOL_START, EAPOL_LOGOFF, EAPOL_KEY, EAPOL_ASF = 0, 1, 2, 3, 4

# --- EAP packet codes ----------------------------------------------------------
EAP_REQUEST, EAP_RESPONSE, EAP_SUCCESS, EAP_FAILURE = 1, 2, 3, 4

# --- EAP method types ----------------------------------------------------------
EAP_TYPE_ID = 1
EAP_TYPE_MD5 = 4
EAP_TYPE_TLS = 13
EAP_TYPE_MSCHAP = 26

# --- TLS helpers ---------------------------------------------------------------
# Raw handshake bytes: type 0x0b, 3-byte length 3, then a 3-byte zero length
# (apparently a TLS Certificate message with an empty certificate list).
cCertMsg = '\x0b\x00\x00\x03\x00\x00\x00'
# EAP-TLS flags bit: when set, eapTLS() adds a 4-byte TLS message length.
TLS_LENGTH_INCLUDED = 0x80
+
def ethernet_header(src, dst, req_type):
    """Return an Ethernet II header: ``dst`` + ``src`` + 2-byte big-endian EtherType.

    Note that the arguments are taken as (src, dst) while the wire order is
    destination first.
    """
    header = dst + src
    header += pack("!H", req_type)
    return header
+
class EapolPacket(object):
    """Minimal EAPOL supplicant over a raw AF_PACKET socket.

    ``eap``/``eapTLS`` build EAP packets, ``eapol`` wraps a payload in an
    EAPOL header, and ``eapol_send`` prepends the Ethernet header computed
    in ``setup()``.  All packets are byte strings.
    """

    def __init__(self, intf = 'veth0'):
        self.intf = intf
        self.s = None
        # Maximum number of bytes read per recv(), Ethernet header included.
        self.max_payload_size = 1600

    def setup(self):
        """Open a raw socket for EAPOL frames on ``self.intf`` and build the Ethernet header."""
        # Calling setup() again must not leak the previously opened socket.
        self.cleanup()
        self.s = socket(AF_PACKET, SOCK_RAW, htons(ETHERTYPE_PAE))
        self.s.bind((self.intf, ETHERTYPE_PAE))
        # For AF_PACKET, getsockname() is (ifname, proto, pkttype, hatype, addr).
        self.mymac = self.s.getsockname()[4]
        self.llheader = ethernet_header(self.mymac, PAE_GROUP_ADDR, ETHERTYPE_PAE)

    def cleanup(self):
        """Close the socket if open; safe to call repeatedly."""
        if self.s is not None:
            self.s.close()
            self.s = None

    def eapol(self, req_type, payload=b""):
        """Return an EAPOL frame body: version, packet type, body length, then ``payload``."""
        return pack("!BBH", EAPOL_VERSION, req_type, len(payload)) + payload

    def eap(self, code, pkt_id, req_type=0, data=b""):
        """Return an EAP packet.

        Success/Failure packets are header-only (length 4); otherwise the
        type octet and ``data`` follow the header.
        """
        if code in (EAP_SUCCESS, EAP_FAILURE):
            return pack("!BBH", code, pkt_id, 4)
        return pack("!BBHB", code, pkt_id, 5 + len(data), req_type) + data

    def eapTLS(self, code, pkt_id, flags = TLS_LENGTH_INCLUDED, data=b""):
        """Return an EAP-TLS packet carrying ``data`` (raw TLS records).

        When ``flags`` has TLS_LENGTH_INCLUDED set, a 4-byte TLS message
        length follows the flags octet.
        """
        if code in (EAP_SUCCESS, EAP_FAILURE):
            return pack("!BBH", code, pkt_id, 4)
        if flags & TLS_LENGTH_INCLUDED:
            tls_header = pack("!BL", flags, len(data))
        else:
            tls_header = pack("!B", flags)
        eap_len = 5 + len(tls_header) + len(data)
        return pack("!BBHB", code, pkt_id, eap_len, EAP_TYPE_TLS) + tls_header + data

    def eapol_send(self, eapol_type, eap_payload):
        """Send one EAPOL frame of ``eapol_type``; returns the byte count from send()."""
        return self.s.send(self.llheader + self.eapol(eapol_type, eap_payload))

    def eapol_recv(self):
        """Receive one frame and return its EAPOL body (the EAP packet).

        Fails the nose ``assert_equal`` check unless the frame is an EAPOL
        EAP-Packet.
        """
        p = self.s.recv(self.max_payload_size)[14:]  # drop the 14-byte Ethernet header
        vers, pkt_type, eapollen = unpack("!BBH", p[:4])
        print("Version %d, type %d, len %d" % (vers, pkt_type, eapollen))
        assert_equal(pkt_type, EAPOL_EAPPACKET)
        # Trim Ethernet padding that may follow the declared EAPOL body.
        return p[4:4 + eapollen]

    def eapol_start(self):
        """Send an EAPOL-Start frame.

        EAPOL-Start has an empty body.  EAPOL_START is an EAPOL packet type,
        not an EAP code, so it must not be fed to ``eap()``; the previous
        version did that and sent a bogus 5-byte EAP packet as the body.
        """
        return self.eapol_send(EAPOL_START, b"")

    def eapol_id_req(self, pkt_id = 0, user = USER):
        """Send an EAP Response/Identity for ``user`` echoing identifier ``pkt_id``."""
        eap_payload = self.eap(EAP_RESPONSE, pkt_id, EAP_TYPE_ID, user)
        return self.eapol_send(EAPOL_EAPPACKET, eap_payload)
diff --git a/src/test/utils/enum.py b/src/test/utils/enum.py
new file mode 100644
index 0000000..9157a2a
--- /dev/null
+++ b/src/test/utils/enum.py
@@ -0,0 +1,114 @@
+#!python
+import copy
+import pprint
+pf = pprint.pformat
+
class EnumException(Exception):
    """Raised by Enumeration when given a malformed list of names/values."""
class Enumeration(object):
    """A named set of string constants mapped to values (consecutive ints by default).

    Names are readable as attributes: ``Enumeration("E", ("A", "B")).B == 1``.
    An entry may be a ``(name, value)`` tuple to set a value explicitly; when
    that value is an int, following plain names continue from ``value + 1``.
    """

    def __init__(self, name, enumList, valuesAreUnique=False, startValue=0):
        self.__doc__ = name
        self.uniqueVals = valuesAreUnique
        self.lookup = {}          # name -> value
        self.reverseLookup = {}   # value -> name (a later name wins on a repeated value)

        self._addEnums(enumList, startValue)

    def _addEnums(self, enumList, startValue):
        """Add every entry of ``enumList``, numbering from ``startValue``.

        Raises:
            EnumException: for a tuple entry without exactly 2 items, a
                non-string name, a duplicate name, or (when ``uniqueVals``)
                a duplicate value.
        """
        i = startValue
        for x in enumList:
            if type(x) is tuple:
                try:
                    x, i = x
                except ValueError:
                    raise EnumException("tuple doesn't have 2 items: %r" % (x,))
            if type(x) is not str:
                raise EnumException("enum name is not a string: %r" % (x,))
            if x in self.lookup:
                raise EnumException("enum name is not unique: %r" % (x,))
            if self.uniqueVals and i in self.reverseLookup:
                raise EnumException("enum value %r not unique for %r" % (i, x))
            self.lookup[x] = i
            self.reverseLookup[i] = x

            # Only int values auto-increment; other values are reused as-is.
            if type(i) is int:
                i = i + 1

        values = self.lookup.values()
        self.first_int = min(values)
        self.last_int = max(values)
        self.first_name = self.reverseLookup[self.first_int]
        self.last_name = self.reverseLookup[self.last_int]

    def __str__(self):
        return pf(self.lookup)

    def __repr__(self):
        return pf(self.lookup)

    def __eq__(self, other):
        """Equal when ``other`` is an Enumeration with the same name and name->value map."""
        # The old code read ``other.self.__doc__``, which always raised AttributeError.
        return (isinstance(other, Enumeration)
                and self.__doc__ == other.__doc__
                and self.lookup == other.lookup)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        return not self.__eq__(other)

    def extend(self, enumList):
        '''
        Extend an existing enumeration with additional values.
        '''
        startValue = self.last_int + 1
        self._addEnums(enumList, startValue)

    def __getattr__(self, attr):
        """Resolve an enum name to its value; unknown names raise AttributeError."""
        # Go through __dict__ so an instance whose ``lookup`` is not set yet
        # (e.g. one being rebuilt by copy/pickle) raises AttributeError instead
        # of recursing into __getattr__ forever.
        lookup = self.__dict__.get('lookup', {})
        try:
            return lookup[attr]
        except KeyError:
            raise AttributeError(attr)

    def whatis(self, value):
        """Return the name for ``value``; raises KeyError if undefined."""
        return self.reverseLookup[value]

    def toInt(self, strval):
        """Return the value for name ``strval``, or None if undefined."""
        return self.lookup.get(strval)

    def toStr(self, value):
        """Return the name for ``value``, or a 'Value undefined' message."""
        return self.reverseLookup.get(value, "Value undefined: %s" % str(value))

    def range(self):
        """Return the defined values as a sorted list."""
        return sorted(self.reverseLookup)

    def valid(self, value):
        """True when ``value`` is a defined enum value."""
        return value in self.reverseLookup

    def invalid(self, value):
        """True when ``value`` is not a defined enum value."""
        return value not in self.reverseLookup

    def vrange(self):
        ''' returns a list of the enumeration names '''
        return list(self.lookup.keys())

    def first_asInt(self):
        return self.first_int

    def last_asInt(self):
        return self.last_int

    def first_asName(self):
        return self.first_name

    def last_asName(self):
        return self.last_name
+
if __name__ == '__main__':
    # Quick self-check of Enumeration.
    testEnum0 = Enumeration("EnumName0",
                            ("Value0", "Value1", "Value2", "Value3",
                             "Value4", "Value5", "Value6"))

    print(testEnum0.Value6)

    if testEnum0.__getattr__("Value6") == testEnum0.Value6:
        print("Looks good")

    # Bad case: a non-string name in the list must cause an EnumException.
    try:
        Enumeration("EnumName1",
                    ("Value0", "Value1", "Value2", 1, "Value3", "Value4", "Value5", "Value6"))
    except EnumException:
        print("Non-string enum name rejected as expected")
    else:
        print("ERROR: non-string enum name was accepted")