blob: 56e985b2e90d2b17d29b3c276adee720482b98ba [file] [log] [blame]
#
# Copyright 2017 the original author or authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import binascii
import json
from scapy.fields import Field, StrFixedLenField, PadField, IntField, FieldListField, ByteField, StrField, \
StrFixedLenField, PacketField
from scapy.packet import Raw
class FixedLenField(PadField):
    """
    A PadField that lets its wrapped field parse only the first `align` bytes
    of the input.

    If the decoded value's payload is a Raw layer consisting solely of pad
    characters, that payload is removed. Whatever the wrapped field left
    unparsed is returned followed by the bytes beyond the `align` window.
    """

    def __init__(self, fld, align, padwith='\x00'):
        super(FixedLenField, self).__init__(fld, align, padwith)

    def getfield(self, pkt, s):
        window, beyond = s[:self._align], s[self._align:]
        leftover, value = self._fld.getfield(pkt, window)
        trailer = value.payload
        if isinstance(trailer, Raw) and not trailer.load.replace(self._padwith, ''):
            # Only padding followed the decoded content: drop it.
            value.remove_payload()
        return leftover + beyond, value
class StrCompoundField(Field):
    """
    A string field formed by concatenating several simple member fields.

    Dissection runs each member in turn over the input and joins the results
    (non-str member values are converted with the member's i2repr). Building
    splits the string back into per-member portions. Members must not hold
    packets; this is asserted in __init__.
    """
    __slots__ = ['flds']

    def __init__(self, name, flds):
        super(StrCompoundField, self).__init__(name=name, default=None, fmt='s')
        self.flds = flds
        for member in self.flds:
            assert not member.holds_packets, 'compound field cannot have packet field members'

    def addfield(self, pkt, s, val):
        remaining = val
        for member in self.flds:
            # Round-trip the remaining value through this member on its own to
            # learn how many characters of it the member consumes.
            _, consumed = member.getfield(pkt, member.addfield(pkt, '', remaining))
            width = len(consumed)
            s = member.addfield(pkt, s, remaining[:width])
            remaining = remaining[width:]
        return s

    def getfield(self, pkt, s):
        pieces = []
        for member in self.flds:
            s, value = member.getfield(pkt, s)
            if not isinstance(value, str):
                value = member.i2repr(pkt, value)
            pieces.append(value)
        return s, ''.join(pieces)
class XStrFixedLenField(StrFixedLenField):
    """
    StrFixedLenField whose internal value is a hexadecimal string.

    i2m decodes the hex string to raw bytes, keeping at most the field's
    length; m2i encodes raw bytes back into a hex string. None passes through
    unchanged in both directions.
    """

    def i2m(self, pkt, x):
        if x is None:
            return None
        # length_from() is a byte count. The decoded bytes are sliced by that
        # count directly, not by the hex-character count (length * 2), so the
        # result never exceeds the field's fixed length.
        return binascii.a2b_hex(x)[:self.length_from(pkt)]

    def m2i(self, pkt, x):
        return None if x is None else binascii.b2a_hex(x)
class MultipleTypeField(object):
    """MultipleTypeField are used for fields that can be implemented by
    various Field subclasses, depending on conditions on the packet.

    It is initialized with `flds` and `default`.

    `default` is the default field type, to be used when none of the
    conditions matched the current packet.

    `flds` is a list of tuples (`fld`, `cond`), where `fld` is a field
    type, and `cond` a "condition" to determine if `fld` is the field type
    that should be used.

    `cond` is either:

      - a callable `cond_pkt` that accepts one argument (the packet) and
        returns True if `fld` should be used, False otherwise.

      - a tuple (`cond_pkt`, `cond_pkt_val`), where `cond_pkt` is the same
        as in the previous case and `cond_pkt_val` is a callable that
        accepts two arguments (the packet, and the value to be set) and
        returns True if `fld` should be used, False otherwise.

    See scapy.layers.l2.ARP (type "help(ARP)" in Scapy) for an example of
    use.

    Attribute lookups that are not defined on this class are delegated to
    the `default` field.
    """
    __slots__ = ["flds", "default", "name"]

    def __init__(self, flds, default):
        self.flds = flds
        self.default = default
        # The compound field is known by the default field's name.
        self.name = self.default.name

    def _find_fld_pkt(self, pkt):
        """Given a Packet instance `pkt`, returns the field instance to be
        used. If you know the value to be set (e.g., in .addfield()), use
        ._find_fld_pkt_val() instead.
        """
        for fld, cond in self.flds:
            if isinstance(cond, tuple):
                cond = cond[0]
            if cond(pkt):
                return fld
        return self.default

    def _find_fld_pkt_val(self, pkt, val):
        """Given a Packet instance `pkt` and the value `val` to be set,
        returns the field instance to be used.
        """
        for fld, cond in self.flds:
            if isinstance(cond, tuple):
                if cond[1](pkt, val):
                    return fld
            elif cond(pkt):
                return fld
        return self.default

    def getfield(self, pkt, s):
        return self._find_fld_pkt(pkt).getfield(pkt, s)

    def addfield(self, pkt, s, val):
        return self._find_fld_pkt_val(pkt, val).addfield(pkt, s, val)

    def any2i(self, pkt, val):
        return self._find_fld_pkt_val(pkt, val).any2i(pkt, val)

    def h2i(self, pkt, val):
        return self._find_fld_pkt_val(pkt, val).h2i(pkt, val)

    def i2h(self, pkt, val):
        return self._find_fld_pkt_val(pkt, val).i2h(pkt, val)

    def i2m(self, pkt, val):
        return self._find_fld_pkt_val(pkt, val).i2m(pkt, val)

    def i2len(self, pkt, val):
        return self._find_fld_pkt_val(pkt, val).i2len(pkt, val)

    def i2repr(self, pkt, val):
        return self._find_fld_pkt_val(pkt, val).i2repr(pkt, val)

    def register_owner(self, cls):
        """Append `cls` to the `owners` list of every candidate field,
        including the default one."""
        for fld, _ in self.flds:
            fld.owners.append(cls)
        self.default.owners.append(cls)

    def __getattr__(self, attr):
        """Delegate attributes not found on this object to the default field.

        Only called when normal lookup fails. Dunder names and this class's
        own slots are never delegated. Otherwise an unset `default` slot
        (e.g. on an instance created without __init__) would re-enter this
        method forever.
        """
        if attr.startswith('__') or attr in MultipleTypeField.__slots__:
            raise AttributeError(attr)
        return getattr(self.default, attr)
class OmciSerialNumberField(StrCompoundField):
    """
    OMCI serial number field: a 4-character vendor id followed by a 4-byte
    vendor serial number represented as 8 hexadecimal characters.

    :param name: field name.
    :param default: None, or a 12-character string (vendor id + 8 hex chars).
    """

    def __init__(self, name, default=None):
        assert default is None or (isinstance(default, str) and len(default) == 12), 'invalid default serial number'
        if default is None:
            vendor_default = vendor_serial_default = None
        else:
            vendor_default, vendor_serial_default = default[:4], default[4:12]
        members = [
            StrFixedLenField('vendor_id', vendor_default, 4),
            XStrFixedLenField('vendor_serial_number', vendor_serial_default, 4),
        ]
        super(OmciSerialNumberField, self).__init__(name, members)
class OmciTableField(MultipleTypeField):
    """
    Field for an OMCI table attribute whose rows are described by a PacketField.

    The field used depends on the enclosing message's `message_id`:
      - Get response (0x29): IntField 'table_length'.
      - Get-Next response (0x3a): StrField 'me_type_table' padded to PDU_SIZE.
      - anything else (or no packet): the row PacketField given to __init__.

    Table contents are kept as a JSON list of row field dictionaries. Rows are
    keyed and ordered by the row object's index() method.
    """

    def __init__(self, tblfld):
        """
        :param tblfld: PacketField whose `cls` describes one table row. The row
                       class must provide index() and is_delete() methods.
        """
        assert isinstance(tblfld, PacketField)
        assert hasattr(tblfld.cls, 'index'), 'No index() method defined for OmciTableField row object'
        assert hasattr(tblfld.cls, 'is_delete'), 'No is_delete() method defined for OmciTableField row object'
        super(OmciTableField, self).__init__(
            [
                (IntField('table_length', 0), (self.cond_pkt, self.cond_pkt_val)),
                (PadField(StrField('me_type_table', None), OmciTableField.PDU_SIZE),
                 (self.cond_pkt2, self.cond_pkt_val2))
            ], tblfld)

    PDU_SIZE = 29  # Baseline message set raw get-next PDU size
    # Message ids are hard-coded here to avoid a circular dependency with the
    # OMCI message definitions.
    OmciGetResponseMessageId = 0x29
    OmciGetNextResponseMessageId = 0x3a

    def cond_pkt(self, pkt):
        return pkt is not None and pkt.message_id == self.OmciGetResponseMessageId

    def cond_pkt_val(self, pkt, val):
        return pkt is not None and pkt.message_id == self.OmciGetResponseMessageId

    def cond_pkt2(self, pkt):
        return pkt is not None and pkt.message_id == self.OmciGetNextResponseMessageId

    def cond_pkt_val2(self, pkt, val):
        return pkt is not None and pkt.message_id == self.OmciGetNextResponseMessageId

    def to_json(self, new_values, old_values_json):
        """
        Merge row objects into a table and serialize it to compact JSON.

        :param new_values: a single row object, merged into the table decoded
                           from old_values_json; or a list of row objects,
                           which replaces the old table entirely. A row whose
                           is_delete() is true removes the row with the same
                           index, if present.
        :param old_values_json: JSON string of the current table, or None.
        :return: JSON list of each row's `fields` dict, sorted by index().
        """
        if not isinstance(new_values, list):
            new_values = [new_values]  # A scalar augments the old table
        else:
            old_values_json = None  # A list of rows replaces all old values

        rows_by_index = dict()
        for old in self.load_json(old_values_json):
            rows_by_index[old.index()] = old

        for new in new_values:
            index = new.index()
            if new.is_delete():
                # Deleting a row that is not in the table is a no-op rather
                # than a KeyError.
                rows_by_index.pop(index, None)
            else:
                rows_by_index[index] = new

        new_table = []
        for index in sorted(rows_by_index):
            row = rows_by_index[index]
            assert isinstance(row, self.default.cls), 'object type for Omci Table row object invalid'
            new_table.append(row.fields)
        return json.dumps(new_table, separators=(',', ':'))

    def load_json(self, json_str):
        """
        Decode a JSON table into row objects sorted by index().

        :param json_str: JSON list of row field dicts, or None for an empty table.
        :return: list of self.default.cls instances. If several rows share an
                 index, only the last one is kept.
        """
        if json_str is None:
            json_str = '[]'
        rows_by_index = dict()
        for json_value in json.loads(json_str):
            row = self.default.cls(**json_value)
            rows_by_index[row.index()] = row
        return [rows_by_index[index] for index in sorted(rows_by_index)]