blob: e2abff04a9723326c1349ca7cb7e9b547994821e [file] [log] [blame]
The Android Open Source Projectcf31fe92008-10-21 07:00:00 -07001# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3# http://code.google.com/p/protobuf/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9# * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11# * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15# * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31# This code is meant to work on Python 2.4 and above only.
32#
33# TODO(robinson): Helpers for verbose, common checks like seeing if a
34# descriptor's cpp_type is CPPTYPE_MESSAGE.
35
36"""Contains a metaclass and helper functions used to create
37protocol message classes from Descriptor objects at runtime.
38
39Recall that a metaclass is the "type" of a class.
40(A class is to a metaclass what an instance is to a class.)
41
42In this case, we use the GeneratedProtocolMessageType metaclass
43to inject all the useful functionality into the classes
44output by the protocol compiler at compile-time.
45
46The upshot of all this is that the real implementation
47details for ALL pure-Python protocol buffers are *here in
48this file*.
49"""
50
51__author__ = 'robinson@google.com (Will Robinson)'
52
53import heapq
54import threading
55import weakref
56# We use "as" to avoid name collisions with variables.
57from froofle.protobuf.internal import decoder
58from froofle.protobuf.internal import encoder
59from froofle.protobuf.internal import message_listener as message_listener_mod
60from froofle.protobuf.internal import type_checkers
61from froofle.protobuf.internal import wire_format
62from froofle.protobuf import descriptor as descriptor_mod
63from froofle.protobuf import message as message_mod
64
65_FieldDescriptor = descriptor_mod.FieldDescriptor
66
67
class GeneratedProtocolMessageType(type):

  """Metaclass for protocol message classes created at runtime from Descriptors.

  We add implementations for all methods described in the Message class.  We
  also create properties to allow getting/setting all fields in the protocol
  message.  Finally, we create slots to prevent users from accidentally
  "setting" nonexistent fields in the protocol message, which then wouldn't get
  serialized / deserialized properly.

  The protocol compiler currently uses this metaclass to create protocol
  message classes at runtime.  Clients can also manually create their own
  classes at runtime, as in this example:

    mydescriptor = Descriptor(.....)
    class MyProtoClass(Message):
      __metaclass__ = GeneratedProtocolMessageType
      DESCRIPTOR = mydescriptor
    myproto_instance = MyProtoClass()
    myproto_instance.foo_field = 23
    ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Custom allocation for runtime-generated class types.

    We override __new__ because this is apparently the only place
    where we can meaningfully set __slots__ on the class we're creating(?).
    (The interplay between metaclasses and slots is not very well-documented).

    Args:
      name: Name of the class (ignored, but required by the
        metaclass protocol).
      bases: Base classes of the class we're constructing.
        (Should be message.Message).  We ignore this field, but
        it's required by the metaclass protocol
      dictionary: The class dictionary of the class we're
        constructing.  dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.

    Returns:
      Newly-allocated class.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
    # Both helpers mutate |dictionary| in place, so the additions are part of
    # the class namespace by the time type.__new__ sees it.
    _AddSlots(descriptor, dictionary)
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    superclass = super(GeneratedProtocolMessageType, cls)
    # __new__ is a static method, so cls has to be passed explicitly here.
    return superclass.__new__(cls, name, bases, dictionary)

  def __init__(cls, name, bases, dictionary):
    """Here we perform the majority of our work on the class.
    We add enum getters, an __init__ method, implementations
    of all Message methods, and properties for all fields
    in the protocol type.

    Args:
      name: Name of the class (ignored, but required by the
        metaclass protocol).
      bases: Base classes of the class we're constructing.
        (Should be message.Message).  We ignore this field, but
        it's required by the metaclass protocol
      dictionary: The class dictionary of the class we're
        constructing.  dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
    # We act as a "friend" class of the descriptor, setting
    # its _concrete_class attribute the first time we use a
    # given descriptor to initialize a concrete protocol message
    # class.
    concrete_class_attr_name = '_concrete_class'
    if not hasattr(descriptor, concrete_class_attr_name):
      setattr(descriptor, concrete_class_attr_name, cls)
    # Extension handles registered through cls.RegisterExtension() end up here.
    cls._known_extensions = []
    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(cls)
    superclass = super(GeneratedProtocolMessageType, cls)
    # The super object is already bound to cls, so cls must not be passed a
    # second time: type.__init__ only accepts (name, bases, dict) here and
    # newer interpreters raise TypeError when given the extra argument.
    superclass.__init__(name, bases, dictionary)
156
157
158# Stateless helpers for GeneratedProtocolMessageType below.
159# Outside clients should not access these directly.
160#
161# I opted not to make any of these methods on the metaclass, to make it more
162# clear that I'm not really using any state there and to keep clients from
163# thinking that they have direct access to these construction helpers.
164
165
def _PropertyName(proto_field_name):
  """Maps a proto field name to the public property name used on messages.

  Args:
    proto_field_name: The protocol message field name, exactly
      as it appears (or would appear) in a .proto file.

  Returns:
    The attribute name clients use to get (and, in some cases, set) the
    field's value.  At present this is the proto field name, unmodified.
  """
  # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
  # nnorwitz pointed at the stdlib keyword module; it could be as simple as:
  #
  #   if keyword.iskeyword(proto_field_name):
  #     return proto_field_name + "_"
  #   return proto_field_name
  return proto_field_name
185
186
def _ValueFieldName(proto_field_name):
  """Builds the name of the internal attribute that stores a field's value.

  Args:
    proto_field_name: The protocol message field name, exactly
      as it appears (or would appear) in a .proto file.

  Returns:
    The field name prefixed with '_value_'.
  """
  return ''.join(('_value_', proto_field_name))
196
197
def _HasFieldName(proto_field_name):
  """Builds the name of the internal attribute that stores a field's "has" bit.

  The "has" bit is a boolean telling whether the field is explicitly set.

  Args:
    proto_field_name: The protocol message field name, exactly
      as it appears (or would appear) in a .proto file.

  Returns:
    The field name prefixed with '_has_'.
  """
  return ''.join(('_has_', proto_field_name))
208
209
def _AddSlots(message_descriptor, dictionary):
  """Stores a '__slots__' list in dictionary naming every valid attribute.

  The list holds one value slot per field, one "has" slot per non-repeated
  field, and a fixed set of bookkeeping slots.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
    dictionary: Class dictionary to which we'll add a '__slots__' entry.
  """
  value_slots = [_ValueFieldName(field.name)
                 for field in message_descriptor.fields]
  # Repeated fields carry no "has" bit, so they get no "has" slot.
  has_slots = [_HasFieldName(field.name)
               for field in message_descriptor.fields
               if field.label != _FieldDescriptor.LABEL_REPEATED]
  bookkeeping_slots = [
      'Extensions',
      '_cached_byte_size',
      '_cached_byte_size_dirty',
      '_called_transition_to_nonempty',
      '_listener',
      '_lock',
      '__weakref__',
  ]
  dictionary['__slots__'] = value_slots + has_slots + bookkeeping_slots
228
229
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Adds each extension declared on this message type to the class dictionary.

  Args:
    descriptor: Descriptor object for this message type.  Its
      extensions_by_name mapping goes from extension name to extension field.
    dictionary: Class dictionary of the class being constructed.  Each
      extension field is stored in it under the extension's name.
  """
  # items() instead of the Python-2-only iteritems(): the mapping is only
  # walked once per class, and items() exists on every dict implementation.
  for extension_name, extension_field in descriptor.extensions_by_name.items():
    # An extension must not shadow something the class already defines.
    assert extension_name not in dictionary
    dictionary[extension_name] = extension_field
235
236
def _AddEnumValues(descriptor, cls):
  """Publishes every enum value declared in this message as a class attribute.

  Each attribute is named after the enum value and holds its number.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_descriptor in descriptor.enum_types:
    for value_descriptor in enum_descriptor.values:
      setattr(cls, value_descriptor.name, value_descriptor.number)
247
248
def _DefaultValueForField(message, field):
  """Produces the initial (default) value for one field of a message.

  Args:
    message: Message instance containing this field, or a weakref proxy
      of same.
    field: FieldDescriptor object for this field.

  Returns:
    For a repeated field, a fresh container wired to |message| through a
    listener (a composite container for message-typed fields, a type-checked
    scalar container otherwise).  For any other field, field.default_value.

  Raises:
    ValueError: If a repeated field declares a default other than [].
  """
  # TODO(robinson): Only the repeated fields need a reference to 'message' (so
  # that they can set the 'has' bit on the containing Message when someone
  # append()s a value).  We could special-case this, and avoid an extra
  # function call on __init__() and Clear() for non-repeated fields.

  # TODO(robinson): Find a better place for the default value assertion in this
  # function.  No need to repeat them every time the client calls Clear('foo').
  # (We should probably just assert these things once and as early as possible,
  # by tightening checking in the descriptor classes.)
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # Singular submessages start out unset.
      assert field.default_value is None
    return field.default_value

  if field.default_value != []:
    raise ValueError('Repeated field default value not empty list: %s' % (
        field.default_value))
  container_listener = _Listener(message, None)
  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    # We can't look at _concrete_class yet since it might not have
    # been set.  (Depends on order in which we initialize the classes).
    return _RepeatedCompositeFieldContainer(container_listener,
                                            field.message_type)
  element_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
  return _RepeatedScalarFieldContainer(container_listener, element_checker)
286
287
def _AddInitMethod(message_descriptor, cls):
  """Installs a generated, zero-argument __init__ on cls.

  The generated initializer resets the byte-size cache, installs a null
  listener and a lock, gives every field its default value, clears the "has"
  bit of every non-repeated field and creates the Extensions dictionary.

  Args:
    message_descriptor: Descriptor object for this message type.
    cls: The class we're constructing.
  """
  field_descriptors = message_descriptor.fields

  def init(self):
    self._cached_byte_size = 0
    self._cached_byte_size_dirty = False
    self._listener = message_listener_mod.NullMessageListener()
    self._called_transition_to_nonempty = False
    # TODO(robinson): We should only create a lock if we really need one
    # in this class.
    self._lock = threading.Lock()
    for field_descriptor in field_descriptors:
      initial_value = _DefaultValueForField(self, field_descriptor)
      setattr(self, _ValueFieldName(field_descriptor.name), initial_value)
      if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
        # Repeated fields have no "has" bit to reset.
        continue
      setattr(self, _HasFieldName(field_descriptor.name), False)
    self.Extensions = _ExtensionDict(self, cls._known_extensions)

  # Blank out the metadata of the generated function.
  init.__doc__ = None
  init.__module__ = None
  cls.__init__ = init
310
311
def _AddPropertiesForFields(descriptor, cls):
  """Installs one public property on cls per field of the message type.

  Args:
    descriptor: Descriptor object for this message type.
    cls: The class we're constructing.
  """
  for field_descriptor in descriptor.fields:
    _AddPropertiesForField(field_descriptor, cls)
316
317
def _AddPropertiesForField(field, cls):
  """Installs the public property for a single protocol message field.

  Clients can use this property to get and (in the case
  of non-repeated scalar fields) directly set the value
  of a protocol message field.  The work is delegated to the helper matching
  the field's kind: repeated, singular composite, or singular scalar.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_properties = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_properties = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_properties = _AddPropertiesForNonRepeatedScalarField
  add_properties(field, cls)
338
339
def _AddPropertiesForRepeatedField(field, cls):
  """Installs a read-only public property for a "repeated" field.

  Reading the property returns the container stored for the field, which will
  be either a _RepeatedScalarFieldContainer or a
  _RepeatedCompositeFieldContainer.  Assigning to the property always raises.

  Note that when clients add values to these containers, we perform
  type-checking in the case of repeated scalar fields, and we also set any
  necessary "has" bits as a side-effect.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  value_attr = _ValueFieldName(proto_field_name)

  def getter(self):
    return getattr(self, value_attr)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only so that assignment fails with a more helpful
  # error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = property(getter, setter, doc=doc)
  setattr(cls, _PropertyName(proto_field_name), field_property)
371
372
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Installs a read/write public property for a singular scalar field.

  Setting the property type-checks the new value, sets the field's "has" bit,
  marks the cached byte size dirty, gives the message a chance to fire its
  transition-to-nonempty callback, and finally stores the value.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  value_attr = _ValueFieldName(proto_field_name)
  has_attr = _HasFieldName(proto_field_name)
  value_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)

  def getter(self):
    return getattr(self, value_attr)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def setter(self, new_value):
    # Validate first so a rejected value leaves the message untouched.
    value_checker.CheckValue(new_value)
    setattr(self, has_attr, True)
    self._MarkByteSizeDirty()
    self._MaybeCallTransitionToNonemptyCallback()
    setattr(self, value_attr, new_value)
  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Wrap the getter/setter pair in a property.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = property(getter, setter, doc=doc)
  setattr(cls, _PropertyName(proto_field_name), field_property)
406
407
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Installs a read-only public property for a singular composite field.

  A composite field is a "group" or "message" field.

  Reading the property returns the submessage, creating it on first access.
  Assigning to the property always raises.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  value_attr = _ValueFieldName(proto_field_name)
  has_attr = _HasFieldName(proto_field_name)
  message_type = field.message_type

  def getter(self):
    # TODO(robinson): Appropriately scary note about double-checked locking.
    submessage = getattr(self, value_attr)
    if submessage is not None:
      return submessage
    self._lock.acquire()
    try:
      # Look again now that we hold the lock; the value may have been
      # stored between the first check and the acquire.
      submessage = getattr(self, value_attr)
      if submessage is None:
        submessage = message_type._concrete_class()
        # The listener lets the child report changes back to this message.
        submessage._SetListener(_Listener(self, has_attr))
        setattr(self, value_attr, submessage)
    finally:
      self._lock.release()
    return submessage
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only so that assignment fails with a more helpful
  # error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  # Wrap the getter (and the always-failing setter) in a property.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = property(getter, setter, doc=doc)
  setattr(cls, _PropertyName(proto_field_name), field_property)
454
455
def _AddStaticMethods(cls):
  """Installs the RegisterExtension() static method on cls.

  RegisterExtension(extension_handle) sets the handle's containing_type to
  cls.DESCRIPTOR and appends the handle to cls._known_extensions.

  Args:
    cls: The class we're constructing.
  """
  # TODO(robinson): This probably needs to be thread-safe(?)
  def RegisterExtension(extension_handle):
    extension_handle.containing_type = cls.DESCRIPTOR
    cls._known_extensions.append(extension_handle)

  setattr(cls, 'RegisterExtension', staticmethod(RegisterExtension))
462
463
def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs ListFields() on cls.  ListFields() returns a list of
  (field_descriptor, value) tuples covering every non-empty repeated field,
  every singular field whose "has" bit is set, and every set extension, in
  ascending field-number order.

  Args:
    message_descriptor: Descriptor object for this message type.
    cls: The class we're constructing.
  """

  # Ensure that we always list in ascending field-number order.
  # For non-extension fields, we can do the sort once, here, at import-time.
  # For extensions, we sort on each ListFields() call, though
  # we could do better if we have to.
  fields = sorted(message_descriptor.fields, key=lambda f: f.number)
  # Materialized as a list because every ListFields() call walks it again; a
  # lazy zip() would be exhausted after the first call on interpreters where
  # zip() returns an iterator.
  triplets = [(_HasFieldName(f.name), _ValueFieldName(f.name), f)
              for f in fields]

  def ListFields(self):
    # We need to list all extension and non-extension fields
    # together, in sorted order by field number.

    # Step 0: Get an iterator over all "set" non-extension fields,
    # sorted by field number.
    # This iterator yields (field_number, field_descriptor, value) tuples.
    def SortedSetFieldsIter():
      # Note that triplets is already sorted by field number.
      for has_field_name, value_field_name, field_descriptor in triplets:
        if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
          # Repeated fields have no "has" bit; non-emptiness stands in for it.
          value = getattr(self, value_field_name)
          if len(value) > 0:
            yield (field_descriptor.number, field_descriptor, value)
        elif getattr(self, has_field_name):
          value = getattr(self, value_field_name)
          yield (field_descriptor.number, field_descriptor, value)
    sorted_fields = SortedSetFieldsIter()

    # Step 1: Get an iterator over all "set" extension fields,
    # sorted by field number.
    # This iterator ALSO yields (field_number, field_descriptor, value) tuples.
    # TODO(robinson): It's not necessary to repeat this with each
    # serialization call.  We can do better.
    sorted_extension_fields = sorted(
        [(f.number, f, v) for f, v in self.Extensions._ListSetExtensions()])

    # Step 2: Create a composite iterator that merges the extension-
    # and non-extension fields, and that still yields fields in
    # sorted order.
    all_set_fields = _ImergeSorted(sorted_fields, sorted_extension_fields)

    # Step 3: Strip off the field numbers and return.
    return [field[1:] for field in all_set_fields]

  cls.ListFields = ListFields
512
def _AddHasFieldMethod(cls):
  """Helper for _AddMessageMethods().

  Installs HasField(field_name) on cls.  It returns the stored "has" bit of
  the named field and raises ValueError when the message has no such
  attribute.
  """
  def HasField(self, field_name):
    has_attr = _HasFieldName(field_name)
    try:
      return getattr(self, has_attr)
    except AttributeError:
      raise ValueError('Protocol message has no "%s" field.' % field_name)

  setattr(cls, 'HasField', HasField)
521
522
def _AddClearFieldMethod(cls):
  """Helper for _AddMessageMethods().

  Installs ClearField(field_name) on cls.  It resets the named field to its
  default value and raises ValueError for names missing from
  DESCRIPTOR.fields_by_name.
  """
  def ClearField(self, field_name):
    try:
      field = self.DESCRIPTOR.fields_by_name[field_name]
    except KeyError:
      raise ValueError('Protocol message has no "%s" field.' % field_name)
    value_attr = _ValueFieldName(field.name)
    has_attr = _HasFieldName(field.name)
    # Computed up front, before any state on this message is touched.
    fresh_value = _DefaultValueForField(self, field)
    if field.label == _FieldDescriptor.LABEL_REPEATED:
      self._MarkByteSizeDirty()
    else:
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        previous_value = getattr(self, value_attr)
        if previous_value is not None:
          # Snip the old object out of the object tree.
          previous_value._SetListener(None)
      if getattr(self, has_attr):
        setattr(self, has_attr, False)
        # Set dirty bit on ourself and parents only if
        # we're actually changing state.
        self._MarkByteSizeDirty()
    setattr(self, value_attr, fresh_value)

  setattr(cls, 'ClearField', ClearField)
549
550
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods().

  Installs ClearExtension(extension_handle) on cls, which forwards to the
  message's Extensions._ClearExtension().
  """
  def ClearExtension(self, extension_handle):
    extension_dict = self.Extensions
    extension_dict._ClearExtension(extension_handle)

  setattr(cls, 'ClearExtension', ClearExtension)
556
557
def _AddClearMethod(cls):
  """Helper for _AddMessageMethods().

  Installs Clear() on cls.  Clear() calls ClearField() for every field in
  DESCRIPTOR.fields and then ClearExtension() for every set extension.
  """
  def Clear(self):
    # Regular fields first.
    for field_descriptor in self.DESCRIPTOR.fields:
      self.ClearField(field_descriptor.name)
    # Then extensions; the handle is the first item of each listed entry.
    for set_extension in self.Extensions._ListSetExtensions():
      self.ClearExtension(set_extension[0])

  setattr(cls, 'Clear', Clear)
570
571
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods().

  Installs HasExtension(extension_handle) on cls, which returns whatever the
  message's Extensions._HasExtension() reports for the handle.
  """
  def HasExtension(self, extension_handle):
    extension_dict = self.Extensions
    return extension_dict._HasExtension(extension_handle)

  setattr(cls, 'HasExtension', HasExtension)
577
578
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs __eq__ on cls.  Two messages are equal when they are of the same
  message type, agree on the "has" bit and value of every singular field,
  hold equal repeated fields, and carry equal extensions.

  Args:
    message_descriptor: Descriptor object for this message type.
    cls: The class we're constructing.
  """
  def __eq__(self, other):
    if self is other:
      return True

    # Anything that is not a message of our own type is simply unequal.
    # Without this guard, comparing against None or an unrelated object
    # fails with AttributeError from the HasField() calls below.
    if (not isinstance(other, message_mod.Message) or
        other.DESCRIPTOR != self.DESCRIPTOR):
      return False

    # Compare all fields contained directly in this message.
    for field_descriptor in message_descriptor.fields:
      label = field_descriptor.label
      property_name = _PropertyName(field_descriptor.name)
      # Non-repeated field equality requires matching "has" bits as well
      # as having an equal value.
      if label != _FieldDescriptor.LABEL_REPEATED:
        self_has = self.HasField(property_name)
        other_has = other.HasField(property_name)
        if self_has != other_has:
          return False
        if not self_has:
          # If the "has" bit for this field is False, we must stop here.
          # Otherwise we will recurse forever on recursively-defined protos.
          continue
      if getattr(self, property_name) != getattr(other, property_name):
        return False

    # Compare the extensions present in both messages.
    return self.Extensions == other.Extensions
  cls.__eq__ = __eq__
606
607
def _AddSetListenerMethod(cls):
  """Helper for _AddMessageMethods().

  Installs _SetListener(listener) on cls.  Passing None installs a fresh
  NullMessageListener; any other value is stored as the listener as-is.
  """
  def SetListener(self, listener):
    if listener is not None:
      self._listener = listener
      return
    self._listener = message_listener_mod.NullMessageListener()

  setattr(cls, '_SetListener', SetListener)
616
617
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Computes the serialized size, in bytes, of one non-repeated element.

  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field.  One of the TYPE_* constants
      within FieldDescriptor.

  Raises:
    message_mod.EncodeError: If a KeyError is raised while looking up or
      running the byte-size function for field_type.
  """
  try:
    # The lookup and the call share one try block, so a KeyError from either
    # is reported as an unrecognized field type.
    return type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type](field_number, value)
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
636
637
def _AddByteSizeMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs ByteSize() on cls.  ByteSize() returns the cached size unless the
  cache is marked dirty; otherwise it recomputes the size from all set fields
  and extensions and refreshes the cache.

  Args:
    message_descriptor: Descriptor object for this message type.
    cls: The class we're constructing.
  """

  def BytesForField(message, field, value):
    """Returns the number of bytes required to serialize a single field
    in message.  The field may be repeated or not, composite or not.

    Args:
      message: The Message instance containing a field of the given type.
      field: A FieldDescriptor describing the field of interest.
      value: The value whose byte size we're interested in.

    Returns: The number of bytes required to serialize the current value
      of "field" in "message", including space for tags and any other
      necessary information.
    """

    if _MessageSetField(field):
      return wire_format.MessageSetItemByteSize(field.number, value)

    field_number, field_type = field.number, field.type

    # Treat a singular value as a one-element sequence so that repeated and
    # singular fields share the summation below.
    if field.label == _FieldDescriptor.LABEL_REPEATED:
      elements = value
    else:
      elements = [value]

    size = sum(_BytesForNonRepeatedElement(element, field_number, field_type)
               for element in elements)
    return size

  # Attribute names are computed once per class.  The result is a list
  # because every ByteSize() call walks it again; a lazy zip() would be
  # exhausted after the first call on interpreters where zip() returns an
  # iterator.
  field_entries = [(_HasFieldName(f.name), _ValueFieldName(f.name), f)
                   for f in message_descriptor.fields]

  def ByteSize(self):
    if not self._cached_byte_size_dirty:
      return self._cached_byte_size

    size = 0
    # Hardcoded fields first.
    for has_field_name, value_field_name, field in field_entries:
      # Repeated fields have no "has" attribute, so the label test has to
      # come first and short-circuit the getattr.
      if (field.label == _FieldDescriptor.LABEL_REPEATED
          or getattr(self, has_field_name)):
        value = getattr(self, value_field_name)
        size += BytesForField(self, field, value)
    # Extensions next.
    for field, value in self.Extensions._ListSetExtensions():
      size += BytesForField(self, field, value)

    self._cached_byte_size = size
    self._cached_byte_size_dirty = False
    return size
  cls.ByteSize = ByteSize
693
694
def _MessageSetField(field_descriptor):
  """Tells whether a field is serialized using the message set wire format.

  That is the case for a non-repeated, message-typed extension whose
  containing type has the message_set_wire_format option set.

  Args:
    field_descriptor: Descriptor of the field.

  Returns:
    True if the field should be serialized using the message set wire format,
    false otherwise.
  """
  is_extension = field_descriptor.is_extension
  if not is_extension:
    # Hand back the falsy value itself, as a chained "and" would.
    return is_extension
  if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
    return False
  if field_descriptor.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
    return False
  # The containing type's options are only consulted once every other
  # condition holds.
  return field_descriptor.containing_type.GetOptions().message_set_wire_format
709
710
def _SerializeValueToEncoder(value, field_number, field_descriptor, encoder):
  """Appends the serialized form of one value to encoder.

  Message-set fields are written with encoder.AppendMessageSetItem(); every
  other field goes through the serialize method registered for its type.

  Args:
    value: Value to serialize.
    field_number: Field number of this value.
    field_descriptor: Descriptor of the field to serialize.
    encoder: encoder.Encoder object to which we should serialize this value.

  Raises:
    message_mod.EncodeError: If a KeyError is raised while looking up or
      running the serialize method for the field's type.
  """
  if not _MessageSetField(field_descriptor):
    try:
      serialize = type_checkers.TYPE_TO_SERIALIZE_METHOD[field_descriptor.type]
      serialize(encoder, field_number, value)
    except KeyError:
      raise message_mod.EncodeError('Unrecognized field type: %d' %
                                    field_descriptor.type)
    return

  encoder.AppendMessageSetItem(field_number, value)
730
731
def _ImergeSorted(*streams):
  """Merges N sorted iterators into a single sorted iterator.

  Each element in streams must be an iterable that yields
  its elements in sorted order, and the elements contained
  in each stream must all be comparable.

  There may be repeated elements in the component streams or
  across the streams; the repeated elements will all be repeated
  in the merged iterator as well.

  Args:
    *streams: Zero or more iterables, each already sorted.

  Yields:
    Every element of every stream, in ascending order.  Equal elements are
    yielded in the order of the streams they came from.
  """
  iters = [iter(stream) for stream in streams]
  # Heap entries are (head_value, stream_index).  The index is unique, so it
  # breaks ties between equal heads and tells us which stream to advance.
  heap = []
  for index, it in enumerate(iters):
    # for/break pulls at most one element.  Unlike calling the iterator's
    # next method directly, this works no matter how the running interpreter
    # spells that method; an exhausted stream just adds nothing.
    for first_value in it:
      heap.append((first_value, index))
      break
  heapq.heapify(heap)

  while heap:
    smallest_value, index = heap[0]
    yield smallest_value
    # Refill from the stream we just consumed from, or retire that stream
    # when it has nothing left (the for/else branch).
    for next_value in iters[index]:
      heapq.heapreplace(heap, (next_value, index))
      break
    else:
      heapq.heappop(heap)
762
763
def _AddSerializeToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs SerializeToString() on cls.  It raises EncodeError, listing the
  collected problems one per line, when _InternalIsInitialized() reports the
  message as uninitialized; otherwise it returns
  SerializePartialToString().

  Args:
    message_descriptor: Descriptor object for this message type (unused).
    cls: The class we're constructing.
  """

  def SerializeToString(self):
    # Refuse to serialize a message that is missing required fields.
    problems = []
    if _InternalIsInitialized(self, problems):
      return self.SerializePartialToString()
    raise message_mod.EncodeError('\n'.join(problems))

  setattr(cls, 'SerializeToString', SerializeToString)
774
775
def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs SerializePartialToString() on cls.  It serializes everything
  ListFields() reports, in that order, and returns the encoder's output.

  Args:
    message_descriptor: Descriptor object for this message type (unused).
    cls: The class we're constructing.
  """
  # Resolve the class once, when the method is installed.
  Encoder = encoder.Encoder

  def SerializePartialToString(self):
    wire_encoder = Encoder()
    # We need to serialize all extension and non-extension fields
    # together, in sorted order by field number.
    for field_descriptor, field_value in self.ListFields():
      # A singular value is handled as a one-element sequence.
      if field_descriptor.label != _FieldDescriptor.LABEL_REPEATED:
        elements = [field_value]
      else:
        elements = field_value
      for element in elements:
        _SerializeValueToEncoder(element, field_descriptor.number,
                                 field_descriptor, wire_encoder)
    return wire_encoder.ToString()

  setattr(cls, 'SerializePartialToString', SerializePartialToString)
794
795
def _WireTypeForFieldType(field_type):
  """Looks up the wire type expected for a given field type.

  Args:
    field_type: The field type to look up.

  Returns:
    The entry for field_type in type_checkers.FIELD_TYPE_TO_WIRE_TYPE.

  Raises:
    message_mod.DecodeError: If field_type has no entry in that table.
  """
  wire_types = type_checkers.FIELD_TYPE_TO_WIRE_TYPE
  try:
    return wire_types[field_type]
  except KeyError:
    raise message_mod.DecodeError('Unknown field type: %d' % field_type)
802
803
def _RecursivelyMerge(field_number, field_type, decoder, message):
  """Decodes a nested message or group from decoder into message.

  message is either a group or a nested message within some containing
  protocol message.  A nested message is read with decoder.ReadMessageInto()
  and a group with decoder.ReadGroupInto().

  Args:
    field_number: The field number of message in its enclosing protocol buffer.
    field_type: The field type of message.  Must be either TYPE_MESSAGE
      or TYPE_GROUP.
    decoder: Decoder to read from.
    message: Message to deserialize into.

  Raises:
    message_mod.DecodeError: If field_type is neither TYPE_MESSAGE nor
      TYPE_GROUP.
  """
  if field_type == _FieldDescriptor.TYPE_MESSAGE:
    decoder.ReadMessageInto(message)
    return
  if field_type == _FieldDescriptor.TYPE_GROUP:
    # Reading a group additionally takes the group's field number.
    decoder.ReadGroupInto(field_number, message)
    return
  raise message_mod.DecodeError('Unexpected field type: %d' % field_type)
824
825
def _DeserializeScalarFromDecoder(field_type, decoder):
  """Deserializes a scalar of the requested type from decoder.

  Args:
    field_type: A scalar (non-group, non-message) field type; it is used as
      the key into type_checkers.TYPE_TO_DESERIALIZE_METHOD.
    decoder: Decoder to read the value from.

  Returns:
    Whatever the deserialize method registered for field_type returns when
    called with decoder.

  Raises:
    message_mod.DecodeError: if field_type has no registered deserialize
      method.
  """
  # Only the table lookup is guarded.  A KeyError raised while the value is
  # actually being decoded is a different failure and must not be reported as
  # an unrecognized field type.
  try:
    read_scalar = type_checkers.TYPE_TO_DESERIALIZE_METHOD[field_type]
  except KeyError:
    raise message_mod.DecodeError('Unrecognized field type: %d' % field_type)
  return read_scalar(decoder)
835
836
def _SkipField(field_number, wire_type, decoder):
  """Advances decoder past one field's payload without storing it.

  Args:
    field_number: Tag number of the field to skip.
    wire_type: Wire type of the field to skip.
    decoder: Decoder used to deserialize the message.  It must be positioned
      just after the tag and wire type of the field have been read.

  Raises:
    message_mod.DecodeError: if wire_type is none of the wire types handled
      below.
  """
  if wire_type == wire_format.WIRETYPE_VARINT:
    decoder.ReadUInt64()
    return
  if wire_type == wire_format.WIRETYPE_FIXED64:
    decoder.ReadFixed64()
    return
  if wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
    # The payload is preceded by its length.
    payload_length = decoder.ReadInt32()
    decoder.SkipBytes(payload_length)
    return
  if wire_type == wire_format.WIRETYPE_START_GROUP:
    _SkipGroup(field_number, decoder)
    return
  if wire_type == wire_format.WIRETYPE_END_GROUP:
    # Nothing follows an END_GROUP tag, so there is nothing to consume.
    return
  if wire_type == wire_format.WIRETYPE_FIXED32:
    decoder.ReadFixed32()
    return
  raise message_mod.DecodeError('Unexpected wire type: %d' % wire_type)
860
861
def _SkipGroup(group_number, decoder):
  """Skips a nested group from the decoder.

  Fields are read and discarded one at a time until an END_GROUP tag that
  carries group_number is found; that tag is consumed as well.

  Args:
    group_number: Tag number of the group to skip.
    decoder: Decoder used to deserialize the message.  It must be positioned
      exactly at the beginning of the message that should be skipped.
  """
  end_group = wire_format.WIRETYPE_END_GROUP
  tag_number, tag_wire_type = decoder.ReadFieldNumberAndWireType()
  while tag_wire_type != end_group or tag_number != group_number:
    # Not our terminator: discard this field (nested groups recurse through
    # _SkipField) and look at the next tag.
    _SkipField(tag_number, tag_wire_type, decoder)
    tag_number, tag_wire_type = decoder.ReadFieldNumberAndWireType()
876
877
def _DeserializeMessageSetItem(message, decoder):
  """Deserializes one item using the message set wire format.

  The item is expected to consist of, in order: a varint field 2 holding the
  type id, a length-delimited field 3 holding the serialized extension
  message, and the END_GROUP tag (field 1) that closes the item.

  If the type id is not a registered extension of message, the payload is
  skipped.  In either case the closing END_GROUP tag is consumed, so that the
  decoder is left positioned at the start of whatever follows the item.
  (Leaving the tag unread would make the next _DeserializeOneEntity() call
  see a bare END_GROUP and report the end of the enclosing message.)

  Args:
    message: Message to be parsed to.
    decoder: The decoder to be used to deserialize encoded data.  Note that the
      decoder should be positioned just after reading the START_GROUP tag that
      began the messageset item.

  Raises:
    message_mod.DecodeError: if any of the three tags is not the expected one.
  """
  _ReadExpectedMessageSetTag(decoder, 2, wire_format.WIRETYPE_VARINT)
  type_id = decoder.ReadInt32()
  field_number, wire_type = _ReadExpectedMessageSetTag(
      decoder, 3, wire_format.WIRETYPE_LENGTH_DELIMITED)

  extension_dict = message.Extensions
  extensions_by_number = extension_dict._AllExtensionsByNumber()
  if type_id in extensions_by_number:
    field_descriptor = extensions_by_number[type_id]
    decoder.ReadMessageInto(extension_dict[field_descriptor])
  else:
    # Unknown extension: discard the payload but keep parsing the item.
    _SkipField(field_number, wire_type, decoder)

  # Read the END_GROUP tag.
  _ReadExpectedMessageSetTag(decoder, 1, wire_format.WIRETYPE_END_GROUP)


def _ReadExpectedMessageSetTag(decoder, expected_field_number,
                               expected_wire_type):
  """Reads the next tag; raises DecodeError unless it is the expected one.

  Returns:
    The (field_number, wire_type) pair that was read.
  """
  field_number, wire_type = decoder.ReadFieldNumberAndWireType()
  if (wire_type != expected_wire_type or
      field_number != expected_field_number):
    raise message_mod.DecodeError(
        'Incorrect message set wire format. '
        'wire_type: %d, field_number: %d' % (wire_type, field_number))
  return field_number, wire_type
915
916
def _DeserializeOneEntity(message_descriptor, message, decoder):
  """Deserializes the next wire entity from decoder into message.

  The next wire entity is either a scalar or a nested message,
  and may also be an element in a repeated field (the wire encoding
  is the same).

  Args:
    message_descriptor: A Descriptor instance describing all fields
      in message.
    message: The Message instance into which we're decoding our fields.
    decoder: The Decoder we're using to deserialize encoded data.

  Returns:
    The number of bytes read from decoder during this call, or 0 if the
    entity was an END_GROUP tag (the tag itself has been consumed).

  Raises:
    RuntimeError: if the wire type read from the stream differs from the wire
      type expected for a known field.  Handling this is still a TODO.
  """
  initial_position = decoder.Position()
  field_number, wire_type = decoder.ReadFieldNumberAndWireType()

  # END_GROUP is checked before any field lookup.  The tag carries the
  # group's field number in the *enclosing* message, which may coincide with
  # the number of a field inside the group; looking it up first would then
  # fail the wire type check below instead of ending the group.
  # (Assumes no field type maps to WIRETYPE_END_GROUP -- verify against
  # type_checkers.FIELD_TYPE_TO_WIRE_TYPE.)
  if wire_type == wire_format.WIRETYPE_END_GROUP:
    # We assume we're being parsed as the group that's ended.
    return 0

  fields_by_number = message_descriptor.fields_by_number
  if field_number in fields_by_number:
    # Non-extension field.
    field_descriptor = fields_by_number[field_number]
    property_name = _PropertyName(field_descriptor.name)
    value = getattr(message, property_name)
    def scalar_setter_fn(scalar):
      setattr(message, property_name, scalar)
  else:
    # The extension table is rebuilt on every _AllExtensionsByNumber() call,
    # so only ask for it once we know this isn't a regular field.
    extension_dict = message.Extensions
    extensions_by_number = extension_dict._AllExtensionsByNumber()
    if field_number in extensions_by_number:
      # Extension field.
      field_descriptor = extensions_by_number[field_number]
      value = extension_dict[field_descriptor]
      def scalar_setter_fn(scalar):
        extension_dict[field_descriptor] = scalar
    elif (wire_type == wire_format.WIRETYPE_START_GROUP and
          field_number == 1 and
          message_descriptor.GetOptions().message_set_wire_format):
      # A Message Set item.
      _DeserializeMessageSetItem(message, decoder)
      return decoder.Position() - initial_position
    else:
      # Unknown field: discard it.
      _SkipField(field_number, wire_type, decoder)
      return decoder.Position() - initial_position

  # If we reach this point, we've identified the field as either
  # hardcoded or extension, and set |field_descriptor|, |scalar_setter_fn|,
  # and |value| appropriately.  Now actually deserialize the thing.
  #
  # field_descriptor: Describes the field we're deserializing.
  # value: The value currently stored in the field to deserialize.
  #   Used only if the field is composite and/or repeated.
  # scalar_setter_fn: A function F such that F(scalar) will
  #   set a nonrepeated scalar value for this field.  Used only
  #   if this field is a nonrepeated scalar.

  field_number = field_descriptor.number
  field_type = field_descriptor.type
  expected_wire_type = _WireTypeForFieldType(field_type)
  if wire_type != expected_wire_type:
    # Need to fill in uninterpreted_bytes.  Work for the next CL.
    # NOTE(review): this surfaces malformed input as RuntimeError rather than
    # DecodeError; left unchanged because callers may rely on the type.
    raise RuntimeError('TODO(robinson): Wiretype mismatches not handled.')

  is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
  is_composite = field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE

  if is_composite:
    if is_repeated:
      # Repeated composite: append a fresh element and merge into it.
      composite = value.add()
    else:
      # Nonrepeated composite: merge into the existing submessage.
      composite = value
    _RecursivelyMerge(field_number, field_type, decoder, composite)
  else:
    scalar = _DeserializeScalarFromDecoder(field_type, decoder)
    if is_repeated:
      # Repeated scalar.
      value.append(scalar)
    else:
      # Nonrepeated scalar.  Just set the field directly.
      scalar_setter_fn(scalar)
  return decoder.Position() - initial_position
1008
1009
def _FieldOrExtensionValues(message, field_or_extension):
  """Retrieves the list of values for the specified field or extension.

  The target field or extension can be optional, required or repeated, but it
  must have value(s) set.  The assumption is that the target field or extension
  is set (e.g. _HasFieldOrExtension holds true).

  Args:
    message: Message which contains the target field or extension.
    field_or_extension: Field or extension for which the list of values is
      required.  Must be an instance of FieldDescriptor.

  Returns:
    A list of values for the specified field or extension.  For a non-repeated
    field this is a new one-element list; for a repeated field it is the
    stored container itself, not a copy.
  """
  descriptor = field_or_extension
  if not descriptor.is_extension:
    stored = getattr(message, _ValueFieldName(descriptor.name))
  else:
    stored = message.Extensions[descriptor]

  if descriptor.label == _FieldDescriptor.LABEL_REPEATED:
    # In this case the stored value is already a list of repeated values.
    return stored
  return [stored]
1035
1036
def _HasFieldOrExtension(message, field_or_extension):
  """Checks if a message has the specified field or extension set.

  The field or extension specified can be optional, required or repeated.  A
  repeated one always counts as set; for the others the answer comes from
  message.HasExtension() or message.HasField().

  Args:
    message: Message which contains the target field or extension.
    field_or_extension: Field or extension to check.  This must be a
      FieldDescriptor instance.

  Returns:
    True if the message has a value set for the specified field or extension,
    or if the field or extension is repeated.
  """
  descriptor = field_or_extension
  if descriptor.label == _FieldDescriptor.LABEL_REPEATED:
    return True
  if not descriptor.is_extension:
    return message.HasField(descriptor.name)
  return message.HasExtension(descriptor)
1059
1060
def _IsFieldOrExtensionInitialized(message, field, errors=None):
  """Checks if a message field or extension is initialized.

  Args:
    message: The message which contains the field or extension.
    field: Field or extension to check.  This must be a FieldDescriptor
      instance.
    errors: If not None, a list to which a description is appended when a
      required field turns out to be unset.  It is also handed down to the
      checks of any submessages.

  Returns:
    True if the field/extension can be considered initialized.
  """
  label = field.label

  # A required field that is not set is not initialized.
  if (label == _FieldDescriptor.LABEL_REQUIRED and
      not _HasFieldOrExtension(message, field)):
    if errors is not None:
      errors.append('Required field %s is not set.' % field.full_name)
    return False

  # An optional field that is not set has nothing left to check.
  if (label == _FieldDescriptor.LABEL_OPTIONAL and
      not _HasFieldOrExtension(message, field)):
    return True

  # Anything that is not a submessage is initialized once we get here.
  if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
    return True

  # The field is set and is either a single or a repeated submessage; it is
  # initialized only if every one of those submessages is.
  for submessage in _FieldOrExtensionValues(message, field):
    if not _InternalIsInitialized(submessage, errors):
      return False
  return True
1095
1096
def _InternalIsInitialized(message, errors=None):
  """Checks if all required fields of a message are set.

  The check covers every field listed in message.DESCRIPTOR.fields plus every
  extension reported by message.Extensions._ListSetExtensions(), and stops at
  the first one that is not initialized.

  Args:
    message: The message to check.
    errors: If set, initialization errors will be appended to it.

  Returns:
    True iff the specified message has all required fields set.
  """
  set_extension_handles = [
      pair[0] for pair in message.Extensions._ListSetExtensions()]
  to_check = list(message.DESCRIPTOR.fields) + set_extension_handles
  for descriptor in to_check:
    if not _IsFieldOrExtensionInitialized(message, descriptor, errors):
      return False
  return True
1115
1116
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Attaches a MergeFromString() method to cls.

  Helper for _AddMessageMethods().

  Args:
    message_descriptor: Descriptor handed to _DeserializeOneEntity() for every
      entity that is parsed.
    cls: The message class that receives the new method.
  """
  decoder_class = decoder.Decoder

  def MergeFromString(self, serialized):
    """Parses serialized into self, one wire entity at a time.

    Parsing ends at the end of the stream, or as soon as
    _DeserializeOneEntity() reports that it read nothing.

    Returns:
      The sum of the byte counts reported by _DeserializeOneEntity().
    """
    stream = decoder_class(serialized)
    total_bytes = 0
    while not stream.EndOfStream():
      consumed = _DeserializeOneEntity(message_descriptor, self, stream)
      if consumed:
        total_bytes += consumed
      else:
        break
    return total_bytes

  cls.MergeFromString = MergeFromString
1130
1131
def _AddIsInitializedMethod(cls):
  """Adds the IsInitialized method to the protocol message class.

  Args:
    cls: The message class that receives the method.
  """
  # _InternalIsInitialized(message, errors=None) takes the message as its
  # first argument, so the function itself is installed as the method.
  cls.IsInitialized = _InternalIsInitialized
1135
1136
def _MergeFieldOrExtension(destination_msg, field, value):
  """Merges a specified message field into another message.

  Args:
    destination_msg: The message that receives the data.
    field: Descriptor of the field or extension being merged.
    value: The source value: an iterable of elements if field is repeated,
      otherwise a single submessage or scalar.
  """
  property_name = _PropertyName(field.name)
  is_extension = field.is_extension
  is_repeated = field.label == _FieldDescriptor.LABEL_REPEATED
  is_composite = field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE

  # Fetch the object to merge into.  A singular scalar extension has no such
  # object; it is assigned through the Extensions dict at the bottom.
  if is_extension:
    if is_repeated or is_composite:
      destination = destination_msg.Extensions[field]
  else:
    destination = getattr(destination_msg, property_name)

  if is_composite:
    # Case 1 - a composite field.
    if is_repeated:
      for element in value:
        destination.add().MergeFrom(element)
    else:
      destination.MergeFrom(value)
  elif is_repeated:
    # Case 2 - a repeated (scalar) field.
    for element in value:
      destination.append(element)
  elif is_extension:
    # Case 3 - a singular field.
    destination_msg.Extensions[field] = value
  else:
    setattr(destination_msg, property_name, value)
1168
1169
def _AddMergeFromMethod(cls):
  """Attaches a MergeFrom() method to cls.

  Args:
    cls: The message class that receives the new method.
  """

  def MergeFrom(self, msg):
    """Merges every entry of msg.ListFields() into self.

    msg must be a different object than self.
    """
    assert msg is not self
    for descriptor, value in msg.ListFields():
      _MergeFieldOrExtension(self, descriptor, value)

  cls.MergeFrom = MergeFrom
1176
1177
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls.

  Each helper below installs one method (or a small group of them) on cls.

  Args:
    message_descriptor: Descriptor passed on to the helpers that take one.
    cls: The message class on which the methods are installed.
  """
  _AddListFieldsMethod(message_descriptor, cls)
  _AddHasFieldMethod(cls)
  _AddClearFieldMethod(cls)
  _AddClearExtensionMethod(cls)
  _AddClearMethod(cls)
  _AddHasExtensionMethod(cls)
  _AddEqualsMethod(message_descriptor, cls)
  _AddSetListenerMethod(cls)
  _AddByteSizeMethod(message_descriptor, cls)
  _AddSerializeToStringMethod(message_descriptor, cls)
  _AddSerializePartialToStringMethod(message_descriptor, cls)
  _AddMergeFromStringMethod(message_descriptor, cls)
  _AddIsInitializedMethod(cls)
  _AddMergeFromMethod(cls)
1194
1195
def _AddPrivateHelperMethods(cls):
  """Adds implementation of private helper methods to cls.

  Installs _MaybeCallTransitionToNonemptyCallback() and _MarkByteSizeDirty().

  Args:
    cls: The message class that receives the helpers.
  """

  def MaybeCallTransitionToNonemptyCallback(self):
    """Notifies self._listener via TransitionToNonempty() the first time this
    method is called.  On all subsequent calls, this is a no-op.
    """
    if self._called_transition_to_nonempty:
      return
    # The flag is only raised once the listener call has returned.
    self._listener.TransitionToNonempty()
    self._called_transition_to_nonempty = True

  def MarkByteSizeDirty(self):
    """Sets the _cached_byte_size_dirty bit to true,
    and propagates this to our listener iff this was a state change.
    """
    if self._cached_byte_size_dirty:
      return
    self._cached_byte_size_dirty = True
    self._listener.ByteSizeDirty()

  cls._MaybeCallTransitionToNonemptyCallback = (
      MaybeCallTransitionToNonemptyCallback)
  cls._MarkByteSizeDirty = MarkByteSizeDirty
1217
1218
class _Listener(object):

  """MessageListener implementation that a parent message registers with its
  child message.

  In order to support semantics like:

    foo.bar.baz = 23
    assert foo.HasField('bar')

  ...child objects must have back references to their parents.
  This helper class is at the heart of this support.
  """

  def __init__(self, parent_message, has_field_name):
    """Args:
      parent_message: The message whose _MaybeCallTransitionToNonemptyCallback()
        and _MarkByteSizeDirty() methods we should call when we receive
        TransitionToNonempty() and ByteSizeDirty() messages.  May already be a
        weakref proxy, in which case it is stored as given.
      has_field_name: The name of the "has" field that we should set in
        the parent message when we receive a TransitionToNonempty message,
        or None if there's no "has" field to set.  (This will be the case
        for child objects in "repeated" fields).
    """
    # The back reference from a child (contained) object to its parent
    # (containing) object is held weakly, so that no cyclic garbage is left
    # behind when the client finishes with the 'parent' object in the tree.
    if not isinstance(parent_message, weakref.ProxyType):
      parent_message = weakref.proxy(parent_message)
    self._parent_message_weakref = parent_message
    self._has_field_name = has_field_name

  def TransitionToNonempty(self):
    """Sets the parent's "has" field (if any) and forwards the signal."""
    parent = self._parent_message_weakref
    has_field_name = self._has_field_name
    try:
      if has_field_name is not None:
        setattr(parent, has_field_name, True)
      # Propagate the signal to our parents iff this is the first field set.
      parent._MaybeCallTransitionToNonemptyCallback()
    except ReferenceError:
      # A client may keep a reference to a child object and set a field on it
      # after the child's parent has been garbage-collected.  This is not an
      # error.
      pass

  def ByteSizeDirty(self):
    """Tells the parent that its cached byte size is no longer valid."""
    try:
      self._parent_message_weakref._MarkByteSizeDirty()
    except ReferenceError:
      # The parent is gone; as in TransitionToNonempty(), nothing to do.
      pass
1271
1272
1273# TODO(robinson): Move elsewhere?
1274# TODO(robinson): Provide a clear() method here in addition to ClearField()?
class _RepeatedScalarFieldContainer(object):

  """Simple, type-checked, list-like container for holding repeated scalars."""

  # Minimizes memory usage and disallows assignment to other attributes.
  __slots__ = ['_message_listener', '_type_checker', '_values']

  def __init__(self, message_listener, type_checker):
    """
    Args:
      message_listener: A MessageListener implementation.
        The _RepeatedScalarFieldContainer will call this object's
        TransitionToNonempty() method when it transitions from being empty to
        being nonempty, and its ByteSizeDirty() method whenever its contents
        change.
      type_checker: A _ValueChecker instance to run on elements inserted
        into this container.
    """
    self._message_listener = message_listener
    self._type_checker = type_checker
    self._values = []

  def append(self, elem):
    """Type-checks elem, appends it, and notifies the listener."""
    self._type_checker.CheckValue(elem)
    self._values.append(elem)
    self._message_listener.ByteSizeDirty()
    if len(self._values) == 1:
      self._message_listener.TransitionToNonempty()

  def remove(self, elem):
    """Removes the first occurrence of elem, as list.remove() does."""
    self._values.remove(elem)
    self._message_listener.ByteSizeDirty()

  # List-like __getitem__() support also makes us iterable (via "iter(foo)"
  # or implicitly via "for i in mylist:") for free.
  def __getitem__(self, key):
    return self._values[key]

  def __setitem__(self, key, value):
    """Type-checks value and stores it at key."""
    # Validate and store first, notify last (the same order append() uses),
    # so that a rejected value or a bad index does not mark the parent dirty
    # when nothing actually changed.
    #
    # No need to call TransitionToNonempty(), since if we're able to
    # set the element at this index, we were already nonempty before
    # this method was called.
    self._type_checker.CheckValue(value)
    self._values[key] = value
    self._message_listener.ByteSizeDirty()

  def __len__(self):
    return len(self._values)

  def __eq__(self, other):
    """Compares element-wise against another container or sequence."""
    if self is other:
      return True
    # Special case for the same type which should be common and fast.
    if isinstance(other, self.__class__):
      return other._values == self._values
    # We are presumably comparing against some other sequence type.
    return other == self._values

  def __ne__(self, other):
    # Can't use != here since it would infinitely recurse.
    return not self == other
1335
1336
1337# TODO(robinson): Move elsewhere?
1338# TODO(robinson): Provide a clear() method here in addition to ClearField()?
# TODO(robinson): Unify common functionality with
# _RepeatedScalarFieldContainer?
class _RepeatedCompositeFieldContainer(object):

  """Simple, list-like container for holding repeated composite fields."""

  # Minimizes memory usage and disallows assignment to other attributes.
  __slots__ = ['_values', '_message_descriptor', '_message_listener']

  def __init__(self, message_listener, message_descriptor):
    """Note that we pass in a descriptor instead of the generated class
    directly, since at the time we construct a
    _RepeatedCompositeFieldContainer we haven't yet necessarily initialized
    the type that will be contained in the container.

    Args:
      message_listener: A MessageListener implementation.
        The _RepeatedCompositeFieldContainer will call this object's
        TransitionToNonempty() method when an element is added, and its
        ByteSizeDirty() method whenever its contents change.
      message_descriptor: A Descriptor instance describing the protocol type
        that should be present in this container.  We'll use the
        _concrete_class field of this descriptor when the client calls add().
    """
    self._message_listener = message_listener
    self._message_descriptor = message_descriptor
    self._values = []

  def add(self):
    """Appends a new, empty element and returns it.

    The element is given this container's listener.
    """
    new_element = self._message_descriptor._concrete_class()
    new_element._SetListener(self._message_listener)
    self._values.append(new_element)
    self._message_listener.ByteSizeDirty()
    self._message_listener.TransitionToNonempty()
    return new_element

  def __delitem__(self, key):
    """Deletes the element(s) at key, as "del some_list[key]" does."""
    # Delete first, notify afterwards: a bad index then raises without
    # marking the parent dirty for a change that never happened.
    del self._values[key]
    self._message_listener.ByteSizeDirty()

  # List-like __getitem__() support also makes us iterable (via "iter(foo)"
  # or implicitly via "for i in mylist:") for free.
  def __getitem__(self, key):
    return self._values[key]

  def __len__(self):
    return len(self._values)

  def __eq__(self, other):
    """Compares element-wise against another container of this class.

    Raises:
      TypeError: if other is not an instance of this class.
    """
    if self is other:
      return True
    if not isinstance(other, self.__class__):
      raise TypeError('Can only compare repeated composite fields against '
                      'other repeated composite fields.')
    return self._values == other._values

  def __ne__(self, other):
    # Can't use != here since it would infinitely recurse.
    return not self == other

  # TODO(robinson): Implement, document, and test slicing support.
1400
1401
1402# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous...
1403# TODO(robinson): Unify error handling of "unknown extension" crap.
1404# TODO(robinson): There's so much similarity between the way that
1405# extensions behave and the way that normal fields behave that it would
1406# be really nice to unify more code. It's not immediately obvious
1407# how to do this, though, and I'd rather get the full functionality
1408# implemented (and, crucially, get all the tests and specs fleshed out
1409# and passing), and then come back to this thorny unification problem.
1410# TODO(robinson): Support iteritems()-style iteration over all
1411# extensions with the "has" bits turned on?
class _ExtensionDict(object):

  """Dict-like container for supporting an indexable "Extensions"
  field on proto instances.

  Note that in all cases we expect extension handles to be
  FieldDescriptors.  Handles are looked up by identity (id(handle)), not by
  equality.
  """

  class _ExtensionListener(object):

    """Adapts an _ExtensionDict to behave as a MessageListener."""

    def __init__(self, extension_dict, handle_id):
      self._extension_dict = extension_dict
      self._handle_id = handle_id

    def TransitionToNonempty(self):
      self._extension_dict._SubmessageTransitionedToNonempty(self._handle_id)

    def ByteSizeDirty(self):
      self._extension_dict._SubmessageByteSizeBecameDirty()

  # TODO(robinson): Somewhere, we need to blow up if people
  # try to register two extensions with the same field number.
  # (And we need a test for this of course).

  def __init__(self, extended_message, known_extensions):
    """extended_message: Message instance for which we are the Extensions dict.
      known_extensions: Iterable of known extension handles.
        These must be FieldDescriptors.
    """
    # We keep a weak reference to extended_message, since
    # it has a reference to this instance in turn.
    self._extended_message = weakref.proxy(extended_message)
    # We take a private snapshot of known_extensions to avoid any
    # thread-safety concerns, since the argument passed in
    # is the global (class-level) dict of known extensions for
    # this type of message, which could be modified at any time
    # via a RegisterExtension() call.
    #
    # This dict maps from handle id to handle (a FieldDescriptor).
    #
    # XXX
    # TODO(robinson): This isn't good enough.  The client could
    # instantiate an object in module A, then afterward import
    # module B and pass the instance to B.Foo().  If B imports
    # an extender of this proto and then tries to use it, B
    # will get a KeyError, even though the extension *is* registered
    # at the time of use.
    # XXX
    self._known_extensions = dict((id(e), e) for e in known_extensions)
    # Read lock around self._values, which may be modified by multiple
    # concurrent readers in the conceptually "const" __getitem__ method.
    # So, we grab this lock in every "read-only" method to ensure
    # that concurrent read access is safe without external locking.
    self._lock = threading.Lock()
    # Maps from extension handle ID to current value of that extension.
    self._values = {}
    # Maps from extension handle ID to a boolean "has" bit, but only
    # for non-repeated extension fields.
    singular_ids = (
        handle_id
        for handle_id, extension in self._known_extensions.iteritems()
        if extension.label != _FieldDescriptor.LABEL_REPEATED)
    self._has_bits = dict.fromkeys(singular_ids, False)

  def __getitem__(self, extension_handle):
    """Returns the current value of the given extension handle.

    A value that has not been accessed before is created on first use.

    Raises:
      KeyError: if extension_handle is not a known extension.
    """
    # We don't care as much about keeping critical sections short in the
    # extension support, since it's presumably much less of a common case.
    self._lock.acquire()
    try:
      handle_id = id(extension_handle)
      if handle_id not in self._known_extensions:
        raise KeyError('Extension not known to this class')
      if handle_id not in self._values:
        self._AddMissingHandle(extension_handle, handle_id)
      return self._values[handle_id]
    finally:
      self._lock.release()

  def __eq__(self, other):
    """Compares the set extensions of two extension dicts.

    Non-repeated extensions are equal when their "has" bits match and, where
    the bit is set, their values match.  Repeated extensions are equal when
    their contents match; one that was never accessed counts as empty.
    """
    # We have to grab read locks since we're accessing _values
    # in a "const" method.  See the comment in the constructor.
    if self is other:
      return True
    # Both locks are always taken in the same global order (by object id).
    # Taking self's lock first would let two threads that evaluate "a == b"
    # and "b == a" at the same time each hold one lock and wait forever for
    # the other.
    first, second = self, other
    if id(second) < id(first):
      first, second = second, first
    first._lock.acquire()
    try:
      second._lock.acquire()
      try:
        return self._EqualsWhileLocked(other)
      finally:
        second._lock.release()
    finally:
      first._lock.release()

  def _EqualsWhileLocked(self, other):
    """Helper internal to ExtensionDict; does the work of __eq__()."""
    # REQUIRES: self._lock and other._lock already held.
    if self._has_bits != other._has_bits:
      return False
    for handle_id in self._values:
      mine = self._values[handle_id]
      if handle_id in self._has_bits:
        # Non-repeated: only compare values where the "has" bit is true.
        if self._has_bits[handle_id] and mine != other._values[handle_id]:
          return False
      else:
        # Repeated: there is no "has" bit, so compare the contents.
        theirs = other._values.get(handle_id)
        if theirs is None:
          if len(mine) != 0:
            return False
        elif not mine == theirs:
          return False
    # Repeated extensions that only |other| has ever touched must be empty.
    for handle_id in other._values:
      if handle_id in other._has_bits or handle_id in self._values:
        continue
      if len(other._values[handle_id]) != 0:
        return False
    return True

  def __ne__(self, other):
    return not self == other

  # Note that this is only meaningful for non-repeated, scalar extension
  # fields.  Note also that we may have to call
  # MaybeCallTransitionToNonemptyCallback() when we do successfully set a field
  # this way, to set any necessary "has" bits in the ancestors of the extended
  # message.
  def __setitem__(self, extension_handle, value):
    """If extension_handle specifies a non-repeated, scalar extension
    field, sets the value of that field.

    Raises:
      KeyError: if extension_handle is not a known extension.
      TypeError: if the extension is repeated and/or a composite type.
    """
    handle_id = id(extension_handle)
    if handle_id not in self._known_extensions:
      raise KeyError('Extension not known to this class')
    field = extension_handle  # Just shorten the name.
    if (field.label == _FieldDescriptor.LABEL_OPTIONAL
        and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE):
      # It's slightly wasteful to lookup the type checker each time,
      # but we expect this to be a vanishingly uncommon case anyway.
      type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
      type_checker.CheckValue(value)
      self._values[handle_id] = value
      self._has_bits[handle_id] = True
      self._extended_message._MarkByteSizeDirty()
      self._extended_message._MaybeCallTransitionToNonemptyCallback()
    else:
      raise TypeError('Extension is repeated and/or a composite type.')

  def _AddMissingHandle(self, extension_handle, handle_id):
    """Helper internal to ExtensionDict."""
    # Special handling for non-repeated message extensions, which (like
    # normal fields of this kind) are initialized lazily.
    # REQUIRES: _lock already held.
    cpp_type = extension_handle.cpp_type
    label = extension_handle.label
    if (cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
        and label != _FieldDescriptor.LABEL_REPEATED):
      self._AddMissingNonRepeatedCompositeHandle(extension_handle, handle_id)
    else:
      self._values[handle_id] = _DefaultValueForField(
          self._extended_message, extension_handle)

  def _AddMissingNonRepeatedCompositeHandle(self, extension_handle, handle_id):
    """Helper internal to ExtensionDict."""
    # REQUIRES: _lock already held.
    value = extension_handle.message_type._concrete_class()
    value._SetListener(_ExtensionDict._ExtensionListener(self, handle_id))
    self._values[handle_id] = value

  def _SubmessageTransitionedToNonempty(self, handle_id):
    """Called when a submessage with a given handle id first transitions to
    being nonempty.  Called by _ExtensionListener.
    """
    assert handle_id in self._has_bits
    self._has_bits[handle_id] = True
    self._extended_message._MaybeCallTransitionToNonemptyCallback()

  def _SubmessageByteSizeBecameDirty(self):
    """Called whenever a submessage's cached byte size becomes invalid
    (goes from being "clean" to being "dirty").  Called by _ExtensionListener.
    """
    self._extended_message._MarkByteSizeDirty()

  # We may wish to widen the public interface of Message.Extensions
  # to expose some of this private functionality in the future.
  # For now, we make all this functionality module-private and just
  # implement what we need for serialization/deserialization,
  # HasField()/ClearField(), etc.

  def _HasExtension(self, extension_handle):
    """Method for internal use by this module.
    Returns true iff we "have" this extension in the sense of the
    "has" bit being set.

    Raises:
      KeyError: if extension_handle is unknown or is a repeated extension.
    """
    handle_id = id(extension_handle)
    # Note that this is different from the other checks.
    if handle_id not in self._has_bits:
      raise KeyError('Extension not known to this class, or is repeated field.')
    return self._has_bits[handle_id]

  # Intentionally pretty similar to ClearField() above.
  def _ClearExtension(self, extension_handle):
    """Method for internal use by this module.
    Clears the specified extension, unsetting its "has" bit.

    Raises:
      KeyError: if extension_handle is not a known extension.
    """
    handle_id = id(extension_handle)
    if handle_id not in self._known_extensions:
      raise KeyError('Extension not known to this class')
    # NOTE(review): the result of this call is not used.  The call is kept
    # because _DefaultValueForField() is defined elsewhere and may validate
    # the descriptor -- confirm, then remove it.
    _DefaultValueForField(self._extended_message, extension_handle)
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      self._extended_message._MarkByteSizeDirty()
    else:
      cpp_type = extension_handle.cpp_type
      if cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        if handle_id in self._values:
          # Future modifications to this object shouldn't set any
          # "has" bits here.
          self._values[handle_id]._SetListener(None)
      if self._has_bits[handle_id]:
        self._has_bits[handle_id] = False
        self._extended_message._MarkByteSizeDirty()
    if handle_id in self._values:
      del self._values[handle_id]

  def _ListSetExtensions(self):
    """Method for internal use by this module.

    Returns a sequence of all extensions that are currently "set"
    in this extension dict.  A "set" extension is a repeated extension
    that has been accessed, or a non-repeated extension with its "has" bit
    set.

    The returned sequence contains (field_descriptor, value) pairs,
    where value is the current value of the extension with the given
    field descriptor.

    The sequence values are in arbitrary order.
    """
    self._lock.acquire()  # Read-only methods must lock around self._values.
    try:
      set_extensions = []
      for handle_id, value in self._values.iteritems():
        handle = self._known_extensions[handle_id]
        if (handle.label == _FieldDescriptor.LABEL_REPEATED
            or self._has_bits[handle_id]):
          set_extensions.append((handle, value))
      return set_extensions
    finally:
      self._lock.release()

  def _AllExtensionsByNumber(self):
    """Method for internal use by this module.

    Returns: A new dict mapping field number to extension handle (a
      FieldDescriptor), for *all* registered extensions for this dict.
    """
    # TODO(robinson): Precompute and store this away.  Note that we'll have to
    # be careful when we move away from having _known_extensions as a
    # private snapshot held by this object.
    return dict((f.number, f) for f in self._known_extensions.itervalues())