blob: 56e985b2e90d2b17d29b3c276adee720482b98ba [file] [log] [blame]
Chip Boling32aab302019-01-23 10:50:18 -06001#
2# Copyright 2017 the original author or authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16import binascii
17import json
18from scapy.fields import Field, StrFixedLenField, PadField, IntField, FieldListField, ByteField, StrField, \
19 StrFixedLenField, PacketField
20from scapy.packet import Raw
21
class FixedLenField(PadField):
    """
    Pad field that confines parsing of the wrapped field to a fixed window.

    Only the first ``align`` bytes of the input are handed to the wrapped
    field; everything after that window is returned untouched to the caller.
    """
    def __init__(self, fld, align, padwith='\x00'):
        super(FixedLenField, self).__init__(fld, align, padwith)

    def getfield(self, pkt, s):
        """
        Dissect the wrapped field from the first ``self._align`` bytes of `s`.

        :param pkt: packet being dissected (passed through to the wrapped field)
        :param s: input to dissect
        :return: tuple of (unconsumed input, dissected value)
        """
        window = s[:self._align]
        beyond_window = s[self._align:]

        leftover, value = self._fld.getfield(pkt, window)

        trailing = value.payload
        if isinstance(trailing, Raw):
            # A raw payload that is empty once the pad character is removed
            # is nothing but padding, so detach it from the value.
            if not trailing.load.replace(self._padwith, ''):
                value.remove_payload()

        return leftover + beyond_window, value
36
37
class StrCompoundField(Field):
    """
    String field assembled from an ordered list of member fields.

    The value is handled as one string; each member field consumes its own
    portion of that string when building and contributes its own portion
    when dissecting.
    """
    __slots__ = ['flds']

    def __init__(self, name, flds):
        """
        :param name: field name
        :param flds: ordered list of member fields; none may hold packets
        """
        super(StrCompoundField, self).__init__(name=name, default=None, fmt='s')
        self.flds = flds
        for member in self.flds:
            assert not member.holds_packets, 'compound field cannot have packet field members'

    def addfield(self, pkt, s, val):
        """
        Append each member field's share of `val` to `s`.

        :param pkt: packet being built
        :param s: data built so far
        :param val: compound value still to be distributed over the members
        :return: `s` with every member field appended
        """
        packed = s
        pending = val
        for member in self.flds:
            # Trial add followed by a get tells us how much of the pending
            # value this member actually consumes.
            probe = member.addfield(pkt, '', pending)
            _, consumed = member.getfield(pkt, probe)
            width = len(consumed)

            packed = member.addfield(pkt, packed, pending[:width])
            pending = pending[width:]
        return packed

    def getfield(self, pkt, s):
        """
        Dissect every member field from `s` and concatenate the results.

        :param pkt: packet being dissected
        :param s: input to dissect
        :return: tuple of (unconsumed input, concatenated string value)
        """
        text = ''
        remaining = s
        for member in self.flds:
            remaining, item = member.getfield(pkt, remaining)
            # Non-string member values are rendered through the member's repr
            text += item if isinstance(item, str) else member.i2repr(pkt, item)
        return remaining, text
64
65
class XStrFixedLenField(StrFixedLenField):
    """
    Fixed-length string field whose internal value is a hexadecimal string.
    """
    def i2m(self, pkt, x):
        """
        Convert the internal hex string into its binary machine form.

        :param pkt: packet the field belongs to (used to look up the length)
        :param x: hexadecimal string, or None
        :return: decoded bytes, or None when `x` is None
        """
        # NOTE(review): the bound is (2 * length + 1) applied to the *decoded*
        # bytes, so it only trims values far longer than the field - confirm
        # whether trimming the hex string itself was intended.
        limit = self.length_from(pkt) * 2 + 1
        if x is None:
            return None
        return binascii.a2b_hex(x)[:limit]

    def m2i(self, pkt, x):
        """
        Convert the binary machine form into the internal hex string.

        :param pkt: packet the field belongs to (unused)
        :param x: binary data, or None
        :return: hexadecimal representation of `x`, or None when `x` is None
        """
        if x is None:
            return None
        return binascii.b2a_hex(x)
76
77
class MultipleTypeField(object):
    """MultipleTypeField are used for fields that can be implemented by
    various Field subclasses, depending on conditions on the packet.

    It is initialized with `flds` and `default`.

    `default` is the default field type, to be used when none of the
    conditions matched the current packet.

    `flds` is a list of tuples (`fld`, `cond`), where `fld` is a field
    type, and `cond` a "condition" to determine if `fld` is the field type
    that should be used.

    `cond` is either:

      - a callable `cond_pkt` that accepts one argument (the packet) and
        returns True if `fld` should be used, False otherwise.

      - a tuple (`cond_pkt`, `cond_pkt_val`), where `cond_pkt` is the same
        as in the previous case and `cond_pkt_val` is a callable that
        accepts two arguments (the packet, and the value to be set) and
        returns True if `fld` should be used, False otherwise.

    See scapy.layers.l2.ARP (type "help(ARP)" in Scapy) for an example of
    use.
    """

    __slots__ = ["flds", "default", "name"]

    def __init__(self, flds, default):
        self.flds = flds
        self.default = default
        # The compound field is known by the name of its default field
        self.name = self.default.name

    def _find_fld_pkt(self, pkt):
        """Given a Packet instance `pkt`, returns the Field subclass to be
        used. If you know the value to be set (e.g., in .addfield()), use
        ._find_fld_pkt_val() instead.
        """
        for fld, cond in self.flds:
            if isinstance(cond, tuple):
                # Only the packet-only half of the condition applies here
                cond = cond[0]
            if cond(pkt):
                return fld
        return self.default

    def _find_fld_pkt_val(self, pkt, val):
        """Given a Packet instance `pkt` and the value `val` to be set,
        returns the Field subclass to be used.
        """
        for fld, cond in self.flds:
            if isinstance(cond, tuple):
                if cond[1](pkt, val):
                    return fld
            elif cond(pkt):
                return fld
        return self.default

    def getfield(self, pkt, s):
        return self._find_fld_pkt(pkt).getfield(pkt, s)

    def addfield(self, pkt, s, val):
        return self._find_fld_pkt_val(pkt, val).addfield(pkt, s, val)

    def any2i(self, pkt, val):
        return self._find_fld_pkt_val(pkt, val).any2i(pkt, val)

    def h2i(self, pkt, val):
        return self._find_fld_pkt_val(pkt, val).h2i(pkt, val)

    def i2h(self, pkt, val):
        return self._find_fld_pkt_val(pkt, val).i2h(pkt, val)

    def i2m(self, pkt, val):
        return self._find_fld_pkt_val(pkt, val).i2m(pkt, val)

    def i2len(self, pkt, val):
        return self._find_fld_pkt_val(pkt, val).i2len(pkt, val)

    def i2repr(self, pkt, val):
        return self._find_fld_pkt_val(pkt, val).i2repr(pkt, val)

    def register_owner(self, cls):
        """Record `cls` as an owner of every candidate field, including the
        default one.
        """
        for fld, _ in self.flds:
            fld.owners.append(cls)
        # The fallback field lives in the `default` slot; there is no `dflt`.
        self.default.owners.append(cls)

    def __getattr__(self, attr):
        """Delegate attributes not defined on this object to the default
        field.

        No packet is available at this point, so the conditions in `flds`
        cannot be evaluated and the default field is the only candidate.
        """
        if attr in MultipleTypeField.__slots__:
            # An unset slot lands here; looking it up again through
            # self.default would recurse without end.
            raise AttributeError(attr)
        return getattr(self.default, attr)
167
class OmciSerialNumberField(StrCompoundField):
    """
    Compound serial number field: a 4-character vendor id followed by a
    4-byte vendor serial number that is represented as hexadecimal.
    """
    def __init__(self, name, default=None):
        """
        :param name: field name
        :param default: None, or a 12-character string made of the 4-character
                        vendor id followed by 8 more characters used as the
                        vendor serial number default
        """
        assert default is None or (isinstance(default, str) and len(default) == 12), 'invalid default serial number'

        if default is None:
            vendor_default = None
            vendor_serial_default = None
        else:
            vendor_default = default[:4]
            vendor_serial_default = default[4:12]

        members = [
            StrFixedLenField('vendor_id', vendor_default, 4),
            XStrFixedLenField('vendor_serial_number', vendor_serial_default, 4),
        ]
        super(OmciSerialNumberField, self).__init__(name, members)
176
class OmciTableField(MultipleTypeField):
    """
    Table attribute field whose wire format depends on the message type.

    The field type is chosen from the packet's ``message_id``:

      - Get Response:      an IntField named 'table_length'
      - Get Next Response: a StrField named 'me_type_table', padded to
                           PDU_SIZE
      - anything else:     the supplied row PacketField (the default)

    Table contents are exchanged with callers as a JSON list of row field
    dictionaries, ordered by row index (see to_json() and load_json()).
    """

    PDU_SIZE = 29                        # Baseline message set raw get-next PDU size
    OmciGetResponseMessageId = 0x29      # Defined locally because of a circular dependency
    OmciGetNextResponseMessageId = 0x3a  # Defined locally because of a circular dependency

    def __init__(self, tblfld):
        """
        :param tblfld: PacketField describing one table row. Its packet class
                       must provide index() and is_delete().
        """
        assert isinstance(tblfld, PacketField)
        assert hasattr(tblfld.cls, 'index'), 'No index() method defined for OmciTableField row object'
        assert hasattr(tblfld.cls, 'is_delete'), 'No is_delete() method defined for OmciTableField row object'
        super(OmciTableField, self).__init__(
            [
                (IntField('table_length', 0), (self.cond_pkt, self.cond_pkt_val)),
                (PadField(StrField('me_type_table', None), OmciTableField.PDU_SIZE),
                 (self.cond_pkt2, self.cond_pkt_val2))
            ], tblfld)

    def cond_pkt(self, pkt):
        """True when `pkt` is a Get Response message."""
        return pkt is not None and pkt.message_id == self.OmciGetResponseMessageId

    def cond_pkt_val(self, pkt, val):
        """True when `pkt` is a Get Response message; `val` is not consulted."""
        return pkt is not None and pkt.message_id == self.OmciGetResponseMessageId

    def cond_pkt2(self, pkt):
        """True when `pkt` is a Get Next Response message."""
        return pkt is not None and pkt.message_id == self.OmciGetNextResponseMessageId

    def cond_pkt_val2(self, pkt, val):
        """True when `pkt` is a Get Next Response message; `val` is not consulted."""
        return pkt is not None and pkt.message_id == self.OmciGetNextResponseMessageId

    def to_json(self, new_values, old_values_json):
        """
        Merge new row value(s) into a table and serialize the result.

        :param new_values: a single row object, which is merged into the old
                           table, or a list of row objects, which replaces
                           the old table entirely
        :param old_values_json: JSON string of the existing table, or None
        :return: compact JSON string of the row field dictionaries, ordered
                 by row index
        """
        if not isinstance(new_values, list):
            # If setting a scalar, augment the old table
            new_values = [new_values]
        else:
            # If setting a vector of new values, erase all old_values
            old_values_json = None

        rows_by_index = dict()
        for old in self.load_json(old_values_json):
            rows_by_index[old.index()] = old

        for new in new_values:
            index = new.index()
            if new.is_delete():
                # Deleting a row that is not present is a no-op rather than
                # a KeyError.
                rows_by_index.pop(index, None)
            else:
                rows_by_index[index] = new

        new_table = []
        # Sorting the keys works on both Python 2 and 3 (no iteritems())
        for index in sorted(rows_by_index):
            row = rows_by_index[index]
            assert isinstance(row, self.default.cls), 'object type for Omci Table row object invalid'
            new_table.append(row.fields)

        return json.dumps(new_table, separators=(',', ':'))

    def load_json(self, json_str):
        """
        Rebuild the table rows from their JSON representation.

        :param json_str: JSON list of row field dictionaries, or None for an
                         empty table
        :return: list of row objects ordered by row index; when several
                 entries share an index, the last one wins
        """
        if json_str is None:
            json_str = '[]'

        rows_by_index = dict()
        for json_value in json.loads(json_str):
            row = self.default.cls(**json_value)
            rows_by_index[row.index()] = row

        return [rows_by_index[index] for index in sorted(rows_by_index)]
242 return table