VOL-1482: Fix Scapy definition for OMCI GetResponse message

Change-Id: I155ff3f5914b81f9a09aede97c2a7cafc1b088fe
diff --git a/tests/utests/voltha/extensions/omci/test_omci.py b/tests/utests/voltha/extensions/omci/test_omci.py
index 4270399..0818dd6 100644
--- a/tests/utests/voltha/extensions/omci/test_omci.py
+++ b/tests/utests/voltha/extensions/omci/test_omci.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 from unittest import TestCase, main
+from binascii import unhexlify
 
 from voltha.extensions.omci.omci import *
 
@@ -1136,7 +1137,7 @@
             message_type=OmciReboot.message_id,
             omci_message=OmciReboot(
                 entity_class=OntG.class_id,
-                 entity_id=0
+                entity_id=0
             )
         )
         self.assertGeneratedFrameEquals(frame, ref)
@@ -1157,6 +1158,128 @@
             self.assertTrue(AA.SBC not in mei_attr.access or
                             mei_attr.field.name == 'managed_entity_id')
 
+    def test_get_response_without_error_but_too_big(self):
+        # This test is related to a bug that I believe is in the BroadCom
+        # ONU stack software, or at least it was seen on both an Alpha and
+        # an T&W BCM-based onu.  The IEEE 802.1p Mapper Service Profile ME
+        # (#130) sent by the ONUs have a payload of 27 octets based on the
+        # Attribute Mask in the encoding.  However, get-response baseline
+        # messages have the last 4 octets reserved for failed/errored attribute
+        # masks so only 25 octets should be allowed.  Of course the 4 octets
+        # are only valid if the status code == 9, but they still should
+        # be reserved.
+        #
+        # This test verifies that we can still parse the 27 octet payload
+        # since the first rule of interoperability is to be lenient with
+        # what you receive and strict with what you transmit.
+        #
+        ref = '017d290a008280020000780000000000000000000000' +\
+              '0000000000000000000000000000' +\
+              '01' +\
+              '02' +\
+              '0000' +\
+              '00000028'
+        zeros_24 = '000000000000000000000000000000000000000000000000'
+        bytes_24 = unhexlify(zeros_24)
+        attributes = {
+            "unmarked_frame_option": 0,         # 1 octet
+            "dscp_to_p_bit_mapping": bytes_24,  # 24 octets
+            "default_p_bit_marking": 1,         # 1 octet   - This is too much
+            "tp_type": 2,                       # 1 octet
+        }
+        frame = OmciFrame(
+            transaction_id=0x017d,
+            message_type=OmciGetResponse.message_id,
+            omci_message=OmciGetResponse(
+                entity_class=Ieee8021pMapperServiceProfile.class_id,
+                success_code=0,
+                entity_id=0x8002,
+                attributes_mask=Ieee8021pMapperServiceProfile.mask_for(*attributes.keys()),
+                data=attributes
+            )
+        )
+        self.assertGeneratedFrameEquals(frame, ref)
+
+    def test_get_response_with_errors_max_data(self):
+        # First a frame with maximum data used up. This aligns the fields up perfectly
+        # with the simplest definition of a Get Response
+        ref = '017d290a008280020900600000000000000000000000' +\
+              '0000000000000000000000000000' +\
+              '0010' +\
+              '0008' +\
+              '00000028'
+        zeros_24 = '000000000000000000000000000000000000000000000000'
+        bytes_24 = unhexlify(zeros_24)
+        good_attributes = {
+            "unmarked_frame_option": 0,         # 1 octet
+            "dscp_to_p_bit_mapping": bytes_24,  # 24 octets
+        }
+        unsupported_attributes = ["default_p_bit_marking"]
+        failed_attributes_mask = ["tp_type"]
+
+        the_class = Ieee8021pMapperServiceProfile
+        frame = OmciFrame(
+            transaction_id=0x017d,
+            message_type=OmciGetResponse.message_id,
+            omci_message=OmciGetResponse(
+                entity_class=the_class.class_id,
+                success_code=9,
+                entity_id=0x8002,
+                attributes_mask=the_class.mask_for(*good_attributes.keys()),
+                unsupported_attributes_mask=the_class.mask_for(*unsupported_attributes),
+                failed_attributes_mask=the_class.mask_for(*failed_attributes_mask),
+                data=good_attributes
+            )
+        )
+        self.assertGeneratedFrameEquals(frame, ref)
+
+    def test_get_response_with_errors_min_data(self):
+        # Next a frame with only a little data used up. This aligns will require
+        # the encoder and decoder to skip to the last 8 octets of the data field
+        # and encode the failed masks there
+        ref = '017d290a00828002090040' +\
+              '01' + '00000000000000000000' +\
+              '0000000000000000000000000000' +\
+              '0010' +\
+              '0028' +\
+              '00000028'
+
+        good_attributes = {
+            "unmarked_frame_option": 1,         # 1 octet
+        }
+        unsupported_attributes = ["default_p_bit_marking"]
+        failed_attributes_mask = ["dscp_to_p_bit_mapping", "tp_type"]
+
+        the_class = Ieee8021pMapperServiceProfile
+        frame = OmciFrame(
+            transaction_id=0x017d,
+            message_type=OmciGetResponse.message_id,
+            omci_message=OmciGetResponse(
+                entity_class=the_class.class_id,
+                success_code=9,
+                entity_id=0x8002,
+                attributes_mask=the_class.mask_for(*good_attributes.keys()),
+                unsupported_attributes_mask=the_class.mask_for(*unsupported_attributes),
+                failed_attributes_mask=the_class.mask_for(*failed_attributes_mask),
+                data=good_attributes
+            )
+        )
+        self.assertGeneratedFrameEquals(frame, ref)
+
+        # Now test decode of the packet
+        decoded = OmciFrame(unhexlify(ref))
+
+        orig_fields = frame.fields['omci_message'].fields
+        omci_fields = decoded.fields['omci_message'].fields
+
+        for field in ['entity_class', 'entity_id', 'attributes_mask',
+                      'success_code', 'unsupported_attributes_mask',
+                      'failed_attributes_mask']:
+            self.assertEqual(omci_fields[field], orig_fields[field])
+
+        self.assertEqual(omci_fields['data']['unmarked_frame_option'],
+                         orig_fields['data']['unmarked_frame_option'])
+
 
 if __name__ == '__main__':
     main()
diff --git a/voltha/extensions/omci/omci_fields.py b/voltha/extensions/omci/omci_fields.py
index 09cf465..8fc8a4c 100644
--- a/voltha/extensions/omci/omci_fields.py
+++ b/voltha/extensions/omci/omci_fields.py
@@ -250,4 +250,31 @@
         for k, v in sorted(key_value_pairs.iteritems()):
             table.append(v)
 
-        return table
\ No newline at end of file
+        return table
+
+
+class OmciVariableLenZeroPadField(Field):
+    __slots__ = ["_pad_to", "_omci_hdr_len"]
+
+    def __init__(self, name, pad_to):
+        Field.__init__(self, name, 0, 'B')
+        self._pad_to = pad_to
+        self._omci_hdr_len = 4
+
+    def addfield(self, pkt, s, _val):
+        count = self._pad_to - self._omci_hdr_len - len(s)
+        if count < 0:
+            from scapy.error import Scapy_Exception
+            raise Scapy_Exception("%s: Already past pad_to offset" %
+                                  self.__class__.__name__)
+        padding = bytearray(count)
+        import struct
+        return s + struct.pack("%iB" % count, *padding)
+
+    def getfield(self, pkt, s):
+        count = len(s) - self._omci_hdr_len
+        if count < 0:
+            from scapy.error import Scapy_Exception
+            raise Scapy_Exception("%s: Already past pad_to offset" %
+                                  self.__class__.__name__)
+        return s[count:], s[-count:]
diff --git a/voltha/extensions/omci/omci_messages.py b/voltha/extensions/omci/omci_messages.py
index 5da57bf..70a5704 100644
--- a/voltha/extensions/omci/omci_messages.py
+++ b/voltha/extensions/omci/omci_messages.py
@@ -19,7 +19,7 @@
 from scapy.packet import Packet
 
 from voltha.extensions.omci.omci_defs import AttributeAccess, OmciSectionDataSize
-from voltha.extensions.omci.omci_fields import OmciTableField
+from voltha.extensions.omci.omci_fields import OmciTableField, OmciVariableLenZeroPadField
 import voltha.extensions.omci.omci_entities as omci_entities
 
 
@@ -198,15 +198,16 @@
         ShortField("entity_id", 0),
         ByteField("success_code", 0),
         ShortField("attributes_mask", None),
-        ConditionalField(
-            ShortField("unsupported_attributes_mask", 0),
-            lambda pkt: pkt.success_code == 9),
-        ConditionalField(
-            ShortField("failed_attributes_mask", 0),
-            lambda pkt: pkt.success_code == 9),
-        ConditionalField(
-            OmciMaskedData("data"),
-            lambda pkt: pkt.success_code == 0 or pkt.success_code == 9)
+        ConditionalField(OmciMaskedData("data"),
+                         lambda pkt: pkt.success_code in (0, 9)),
+        ConditionalField(OmciVariableLenZeroPadField("zero_padding", 36),
+                         lambda pkt: pkt.success_code == 9),
+
+        # These fields are only valid if attribute error (status == 9)
+        ConditionalField(ShortField("unsupported_attributes_mask", 0),
+                         lambda pkt: pkt.success_code == 9),
+        ConditionalField(ShortField("failed_attributes_mask", 0),
+                         lambda pkt: pkt.success_code == 9)
     ]
 
 
@@ -443,6 +444,7 @@
             OmciMaskedData("data"), lambda pkt: pkt.success_code == 0)
     ]
 
+
 class OmciStartSoftwareDownload(OmciMessage):
     name = "OmciStartSoftwareDownload"
     message_id = 0x53
@@ -455,6 +457,7 @@
         ShortField("instance_id", None) # should be same as "entity_id"        
     ]
 
+
 class OmciStartSoftwareDownloadResponse(OmciMessage):
     name = "OmciStartSoftwareDownloadResponse"
     message_id = 0x33
@@ -467,6 +470,7 @@
         ShortField("instance_id", None) # should be same as "entity_id"        
     ]
 
+
 class OmciEndSoftwareDownload(OmciMessage):
     name = "OmciEndSoftwareDownload"
     message_id = 0x55
@@ -479,6 +483,7 @@
         ShortField("instance_id", None),# should be same as "entity_id"
     ]
 
+
 class OmciEndSoftwareDownloadResponse(OmciMessage):
     name = "OmciEndSoftwareDownload"
     message_id = 0x35
@@ -491,6 +496,7 @@
         ByteField("result0", 0)         # same as result 
     ]
 
+
 class OmciDownloadSection(OmciMessage):
     name = "OmciDownloadSection"
     message_id = 0x14
@@ -501,6 +507,7 @@
         StrFixedLenField("data", 0, length=OmciSectionDataSize) # section data
     ]
 
+
 class OmciDownloadSectionLast(OmciMessage):
     name = "OmciDownloadSection"
     message_id = 0x54
@@ -511,6 +518,7 @@
         StrFixedLenField("data", 0, length=OmciSectionDataSize) # section data
     ]
 
+
 class OmciDownloadSectionResponse(OmciMessage):
     name = "OmciDownloadSectionResponse"
     message_id = 0x34
@@ -521,6 +529,7 @@
         ByteField("section_number", 0),  # Always only 1 in parallel
     ]
 
+
 class OmciActivateImage(OmciMessage):
     name = "OmciActivateImage"
     message_id = 0x56
@@ -530,6 +539,7 @@
         ByteField("activate_flag", 0)    # Activate image unconditionally
     ]
 
+
 class OmciActivateImageResponse(OmciMessage):
     name = "OmciActivateImageResponse"
     message_id = 0x36
@@ -539,6 +549,7 @@
         ByteField("result", 0)           # Activate image unconditionally
     ]
 
+
 class OmciCommitImage(OmciMessage):
     name = "OmciCommitImage"
     message_id = 0x57
@@ -547,6 +558,7 @@
         ShortField("entity_id", None),
     ]
 
+
 class OmciCommitImageResponse(OmciMessage):
     name = "OmciCommitImageResponse"
     message_id = 0x37
@@ -555,4 +567,3 @@
         ShortField("entity_id", None),
         ByteField("result", 0)           # Activate image unconditionally
     ]
-