Refined OMCI parsing and serialization

Change-Id: I52f5a1fff997c41de51022fc6f8d4293b191994d
diff --git a/tests/utests/voltha/adapters/microsemi/test_chat.py b/tests/utests/voltha/adapters/microsemi/test_chat.py
index d0f2afc..94eb805 100644
--- a/tests/utests/voltha/adapters/microsemi/test_chat.py
+++ b/tests/utests/voltha/adapters/microsemi/test_chat.py
@@ -46,7 +46,7 @@
             reference_msg.show()
             hexdump(str(reference_msg))
             self.fail("Decoded message did not match! "
-                      "(inspect above printouts")
+                      "(inspect above printouts)")
         self.assertEqual(pas5211_msg, reference_msg)
 
     # ~~~~~~~~~~~~~~~~~~~~~~~~ test_get_protocol_version
@@ -489,7 +489,6 @@
 
     # ~~~~~~~~~~~~~~~~~~~~~~~~ test_frame_received_event
 
-    '''
     def test_frame_received_event(self):
         self.check_parsed(
             '\x90\xe2\xba\x82\xf9w\x00\x0c\xd5\x00\x01\x01\x00Z\x01\x00\x01'
@@ -497,7 +496,7 @@
             '\x00\x01\x00\x00\x000\x00\x00\x00\x00\x00\x01\x00\x15\x00 \x00'
             '\x13\x00\x00 \x00\x00)\n\x00\x06\x01\x01\x00\x08\x00PMCS\x00\x00'
             '\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
-            '\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00(\r\xc5\x0c\xb6',
+            '\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00(',
             PAS5211EventFrameReceived(
                 length=48,
                 management_frame=PON_TRUE,
@@ -517,12 +516,13 @@
                         data=dict(
                             vendor_id="PMCS"
                         )
-                    )
+                    ),
+                    # omci_trailer=0x28
                 )
             ),
             channel_id=0, onu_id=0, onu_session_id=1
         )
-    '''
+
 
 if __name__ == '__main__':
     main()
diff --git a/tests/utests/voltha/extensions/omci/test_omci.py b/tests/utests/voltha/extensions/omci/test_omci.py
index 83d79db..c045d76 100644
--- a/tests/utests/voltha/extensions/omci/test_omci.py
+++ b/tests/utests/voltha/extensions/omci/test_omci.py
@@ -1,8 +1,26 @@
 from unittest import TestCase, main
-from voltha.extensions.omci.omci import CirtcuitPackEntity, bitpos_from_mask
+
+from hexdump import hexdump
+
+from voltha.extensions.omci.omci import CircuitPackEntity, bitpos_from_mask, \
+    OmciUninitializedFieldError, OMCIGetResponse, OMCIFrame, OMCIGetRequest
 from voltha.extensions.omci.omci import EntityClass
 
 
+def hexify(buffer):
+    """Return a hexadecimal string encoding of input buffer"""
+    return ''.join('%02x' % ord(c) for c in buffer)
+
+
+def chunk(indexable, chunk_size):
+    for i in range(0, len(indexable), chunk_size):
+        yield indexable[i : i + chunk_size]
+
+
+def hex2raw(hex_string):
+    return ''.join(chr(int(byte, 16)) for byte in chunk(hex_string, 2))
+
+
 class TestOmci(TestCase):
 
     def test_bitpos_from_mask(self):
@@ -32,15 +50,108 @@
 
     def test_entity_attribute_serialization(self):
 
-        e = CirtcuitPackEntity(vendor_id='F')
+        e = CircuitPackEntity(vendor_id='F')
         self.assertEqual(e.serialize(), 'F\x00\x00\x00')
 
-        e = CirtcuitPackEntity(vendor_id='FOOX')
+        e = CircuitPackEntity(vendor_id='FOOX')
         self.assertEqual(e.serialize(), 'FOOX')
 
-        e = CirtcuitPackEntity(vendor_id='FOOX', number_of_ports=16)
+        e = CircuitPackEntity(vendor_id='FOOX', number_of_ports=16)
         self.assertEqual(e.serialize(), '\x10FOOX')
 
+    def test_entity_attribute_serialization_mask_based(self):
+
+        e = CircuitPackEntity(
+            number_of_ports=4,
+            serial_number='123-123A',
+            version='a1c12fba91de',
+            vendor_id='BCM',
+            total_tcont_buffer_number=128
+        )
+
+        # Full object
+        self.assertEqual(e.serialize(),
+                         '\x04123-123Aa1c12fba91de\x00\x00BCM\x00\x80')
+
+        # Explicit mask with valid values
+        self.assertEqual(e.serialize(0x800), 'BCM\x00')
+        self.assertEqual(e.serialize(0x6800), '\x04123-123ABCM\x00')
+
+        # Referring to an unfilled field is regarded as error
+        self.assertRaises(OmciUninitializedFieldError, e.serialize, 0xc00)
+
+    def test_omci_mask_value_gen(self):
+        cls = CircuitPackEntity
+        self.assertEqual(cls.mask_for('vendor_id'), 0x800)
+        self.assertEqual(
+            cls.mask_for('vendor_id', 'bridged_or_ip_ind'), 0x900)
+
+    reference_get_request_hex = (
+        '00 00 49 0a'
+        '00 06 01 01'
+        '08 00 00 00'
+        '00 00 00 00'
+        '00 00 00 00'
+        '00 00 00 00'
+        '00 00 00 00'
+        '00 00 00 00'
+        '00 00 00 00'
+        '00 00 00 00'
+        '00 00 00 28'.replace(' ', '')
+    )
+    reference_get_request_raw = hex2raw(reference_get_request_hex)
+
+    reference_get_response_hex = (
+        '00 00 29 0a'
+        '00 06 01 01'
+        '00 08 00 50'
+        '4d 43 53 00'
+        '00 00 00 00'
+        '00 00 00 00'
+        '00 00 00 00'
+        '00 00 00 00'
+        '00 00 00 00'
+        '00 00 00 00'
+        '00 00 00 28'.replace(' ', '')
+    )
+    reference_get_response_raw = hex2raw(reference_get_response_hex)
+
+    def test_omci_frame_serialization(self):
+
+        frame = OMCIFrame(
+            transaction_id=0,
+            message_type=0x49,
+            omci_message=OMCIGetRequest(
+                entity_class=CircuitPackEntity.class_id,
+                entity_id=0x101,
+                attributes_mask=CircuitPackEntity.mask_for('vendor_id')
+            )
+        )
+        self.assertEqual(hexify(str(frame)), self.reference_get_request_hex)
+
+    def test_omci_frame_deserialization_no_data(self):
+        frame = OMCIFrame(self.reference_get_request_raw)
+        self.assertEqual(frame.transaction_id, 0)
+        self.assertEqual(frame.message_type, 0x49)
+        self.assertEqual(frame.omci, 10)
+        self.assertEqual(frame.omci_message.entity_class, 0x6)
+        self.assertEqual(frame.omci_message.entity_id, 0x101)
+        self.assertEqual(frame.omci_message.attributes_mask, 0x800)
+        self.assertEqual(frame.omci_trailer, 0x28)
+
+    def test_omci_frame_deserialization_with_data(self):
+        frame = OMCIFrame(self.reference_get_response_raw)
+        self.assertEqual(frame.transaction_id, 0)
+        self.assertEqual(frame.message_type, 0x29)
+        self.assertEqual(frame.omci, 10)
+        self.assertEqual(frame.omci_message.success_code, 0x0)
+        self.assertEqual(frame.omci_message.entity_class, 0x6)
+        self.assertEqual(frame.omci_message.entity_id, 0x101)
+        self.assertEqual(frame.omci_message.attributes_mask, 0x800)
+        self.assertEqual(frame.omci_trailer, 0x28)
+
+    def test_entity_attribute_deserialization(self):
+        pass
 
 if __name__ == '__main__':
     main()
diff --git a/voltha/extensions/omci/omci.py b/voltha/extensions/omci/omci.py
index aa14b92..8288308 100644
--- a/voltha/extensions/omci/omci.py
+++ b/voltha/extensions/omci/omci.py
@@ -1,12 +1,29 @@
 import inspect
 import sys
 from enum import Enum
-# from scapy.all import StrFixedLenField, ByteField, ShortField, ConditionalField, \
-#     PacketField, PadField, IntField, Field, Packet
 from scapy.fields import ByteField, Field, ShortField, PacketField, PadField, \
     ConditionalField
 from scapy.fields import StrFixedLenField, IntField
-from scapy.packet import Packet
+from scapy.packet import Packet, Raw
+
+
+class OmciUninitializedFieldError(Exception): pass
+
+
+class FixedLenField(PadField):
+    """
+    This Pad field limits parsing of its content to its size
+    """
+    def __init__(self, fld, align, padwith='\x00'):
+        super(FixedLenField, self).__init__(fld, align, padwith)
+
+    def getfield(self, pkt, s):
+        remain, val = self._fld.getfield(pkt, s[:self._align])
+        if isinstance(val.payload, Raw) and \
+                not val.payload.load.replace(self._padwith, ''):
+            # raw payload is just padding
+            val.remove_payload()
+        return remain + s[self._align:], val
 
 
 def bitpos_from_mask(mask, lsb_pos=0, increment=1):
@@ -48,34 +65,41 @@
     Test = 11
 
 
-class EntityClassAttribute:
+class EntityClassAttribute(object):
 
     def __init__(self, fld, access=set(), optional=False):
         self._fld = fld
         self._access = access
         self._optional = optional
 
-class EntityClass:
+class EntityClassMeta(type):
+    """
+    Metaclass for EntityClass to generate secondary class attributes
+    for class attributes of the derived classes.
+    """
+    def __init__(cls, name, bases, dct):
+        super(EntityClassMeta, cls).__init__(name, bases, dct)
+
+        # initialize attribute_name_to_index_map
+        cls.attribute_name_to_index_map = dict(
+            (a._fld.name, idx) for idx, a in enumerate(cls.attributes))
+
+
+class EntityClass(object):
+
     class_id = 'to be filled by subclass'
     attributes = []
     mandatory_operations = {}
     optional_operations = {}
 
-    # will be map of attr_name -> index in attributes
+    # will be map of attr_name -> index in attributes, initialized by metaclass
     attribute_name_to_index_map = None
+    __metaclass__ = EntityClassMeta
 
     def __init__(self, **kw):
-
         assert(isinstance(kw, dict))
-
-        # verify that all keys provided are valid in the entity
-        if self.attribute_name_to_index_map is None:
-            self.__class__.attribute_name_to_index_map = dict(
-                (a._fld.name, idx) for idx, a in enumerate(self.attributes))
-
         for k, v in kw.iteritems():
             assert(k in self.attribute_name_to_index_map)
-
         self._data = kw
 
     def serialize(self, mask=None, operation=None):
@@ -87,14 +111,18 @@
         # also taking into account the type of operation in hand
         if mask is not None:
             attribute_indices = EntityClass.attribute_indices_from_mask(mask)
-            print attribute_indices
         else:
             attribute_indices = self.attribute_indices_from_data()
 
         # Serialize each indexed field (ignoring entity id)
         for index in attribute_indices:
             field = self.attributes[index]._fld
-            bytes = field.addfield(None, bytes, self._data[field.name])
+            try:
+                value = self._data[field.name]
+            except KeyError:
+                raise OmciUninitializedFieldError(
+                    'Entity field "{}" not set'.format(field.name) )
+            bytes = field.addfield(None, bytes, value)
 
         return bytes
 
@@ -115,6 +143,19 @@
             cls.byte1_mask_to_attr_indices[(mask >> 8) & 0xff] + \
             cls.byte2_mask_to_attr_indices[(mask & 0xff)]
 
+    @classmethod
+    def mask_for(cls, *attr_names):
+        """
+        Return mask value corresponding to given attributes names
+        :param attr_names: Attribute names
+        :return: integer mask value
+        """
+        mask = 0
+        for attr_name in attr_names:
+            index = cls.attribute_name_to_index_map[attr_name]
+            mask |= (1 << (16 - index))
+        return mask
+
 
 # abbreviations
 ECA = EntityClassAttribute
@@ -122,7 +163,7 @@
 OP = EntityOperations
 
 
-class CirtcuitPackEntity(EntityClass):
+class CircuitPackEntity(EntityClass):
     class_id = 6
     attributes = [
         ECA(StrFixedLenField("managed_entity_id", None, 22), {AA.R, AA.SBC}),
@@ -167,20 +208,39 @@
         self._entity_class = entity_class
         self._attributes_mask = attributes_mask
 
-    def i2m(self, pkt, x):
+    def addfield(self, pkt, s, val):
         class_id = getattr(pkt, self._entity_class)
         attribute_mask = getattr(pkt, self._attributes_mask)
         entity_class = entity_id_to_class_map.get(class_id)
-        return entity_class(**x).serialize(attribute_mask)
+        indices = entity_class.attribute_indices_from_mask(attribute_mask)
+        for index in indices:
+            fld = entity_class.attributes[index]._fld
+            s = fld.addfield(pkt, s, val[fld.name])
+        return s
+
+    def getfield(self, pkt, s):
+        """Extract an internal value from a string"""
+        class_id = getattr(pkt, self._entity_class)
+        attribute_mask = getattr(pkt, self._attributes_mask)
+        entity_class = entity_id_to_class_map.get(class_id)
+        indices = entity_class.attribute_indices_from_mask(attribute_mask)
+        data = {}
+        for index in indices:
+            fld = entity_class.attributes[index]._fld
+            s, value = fld.getfield(pkt, s)
+            data[fld.name] = value
+        return  s, data
 
 
 class OMCIMessage(Packet):
     name = "OMCIMessage"
+    message_id = None  # OMCI message_type value, filled by derived classes
     fields_desc = []
 
 
 class OMCIGetRequest(OMCIMessage):
     name = "OMCIGetRequest"
+    message_id = 0x49
     fields_desc = [
         ShortField("entity_class", None),
         ShortField("entity_id", 0),
@@ -190,13 +250,16 @@
 
 class OMCIGetResponse(OMCIMessage):
     name = "OMCIGetResponse"
+    message_id = 0x29
     fields_desc = [
         ShortField("entity_class", None),
         ShortField("entity_id", 0),
         ByteField("success_code", 0),
         ShortField("attributes_mask", None),
-        OMCIData("data", entity_class="entity_class",
-                 attributes_mask="attributes_mask")
+        ConditionalField(
+            OMCIData("data", entity_class="entity_class",
+                     attributes_mask="attributes_mask"),
+            lambda pkt: pkt.success_code == 0)
     ]
 
 
@@ -206,12 +269,36 @@
         ShortField("transaction_id", 0),
         ByteField("message_type", None),
         ByteField("omci", 0x0a),
-        ConditionalField(PadField(PacketField("omci_message", None,
+        ConditionalField(FixedLenField(PacketField("omci_message", None,
                                               OMCIGetRequest), align=36),
                          lambda pkt: pkt.message_type == 0x49),
-        ConditionalField(PadField(PacketField("omci_message", None,
+        ConditionalField(FixedLenField(PacketField("omci_message", None,
                                               OMCIGetResponse), align=36),
                          lambda pkt: pkt.message_type == 0x29),
         # TODO add additional message types here as padded conditionals...
+
         IntField("omci_trailer", 0x00000028)
     ]
+
+    # We needed to patch the do_dissect(...) method of Packet, because
+    # it wiped out already dissected conditional fields with None if they
+    # referred to the same field name. We marked the only new line of code
+    # with "Extra condition added".
+    def do_dissect(self, s):
+        raw = s
+        self.raw_packet_cache_fields = {}
+        for f in self.fields_desc:
+            if not s:
+                break
+            s, fval = f.getfield(self, s)
+            # We need to track fields with mutable values to discard
+            # .raw_packet_cache when needed.
+            if f.islist or f.holds_packets:
+                self.raw_packet_cache_fields[f.name] = f.do_copy(fval)
+            # Extra condition added
+            if fval is not None or f.name not in self.fields:
+                self.fields[f.name] = fval
+        assert(raw.endswith(s))
+        self.raw_packet_cache = raw[:-len(s)] if s else raw
+        self.explicit = 1
+        return s