Martin Cosyns | f88ed6e | 2020-12-02 10:30:10 +0100 | [diff] [blame^] | 1 | # -*- coding:utf-8 -*- |
| 2 | |
| 3 | # copied from https://github.com/kaporzhu/protobuf-to-dict |
| 4 | # all credits to this script go to Kapor Zhu (kapor.zhu@gmail.com) |
| 5 | # |
# This copy carries a fix for the bug "Use enum_label when setting the default
# value if use_enum_labels is true" (see the default-value handling at the end
# of protobuf_to_dict).
| 7 | |
| 8 | import base64 |
| 9 | |
| 10 | import six |
| 11 | |
| 12 | from google.protobuf.message import Message |
| 13 | from google.protobuf.descriptor import FieldDescriptor |
| 14 | |
| 15 | |
# Names exported by ``from <module> import *``.
__all__ = [
    "protobuf_to_dict",
    "TYPE_CALLABLE_MAP",
    "dict_to_protobuf",
    "REVERSE_TYPE_CALLABLE_MAP",
]
| 18 | |
| 19 | |
# Dictionary key under which extension values are grouped, keyed by the
# extension's field number rendered as a string.
EXTENSION_CONTAINER = "___X"
| 21 | |
| 22 | |
# Converter used for every 64-bit integer field type: ``int`` on Python 3,
# otherwise the second entry of ``six.integer_types``.
_INT64_CALLABLE = int if six.PY3 else six.integer_types[1]

# Default mapping from protobuf field type id to the callable that turns a
# field value into its plain-Python representation.
TYPE_CALLABLE_MAP = {
    # floating point
    FieldDescriptor.TYPE_DOUBLE: float,
    FieldDescriptor.TYPE_FLOAT: float,
    # integers (32-bit types use ``int``, 64-bit types use _INT64_CALLABLE)
    FieldDescriptor.TYPE_INT32: int,
    FieldDescriptor.TYPE_INT64: _INT64_CALLABLE,
    FieldDescriptor.TYPE_UINT32: int,
    FieldDescriptor.TYPE_UINT64: _INT64_CALLABLE,
    FieldDescriptor.TYPE_SINT32: int,
    FieldDescriptor.TYPE_SINT64: _INT64_CALLABLE,
    FieldDescriptor.TYPE_FIXED32: int,
    FieldDescriptor.TYPE_FIXED64: _INT64_CALLABLE,
    FieldDescriptor.TYPE_SFIXED32: int,
    FieldDescriptor.TYPE_SFIXED64: _INT64_CALLABLE,
    # everything else
    FieldDescriptor.TYPE_BOOL: bool,
    FieldDescriptor.TYPE_STRING: six.text_type,
    FieldDescriptor.TYPE_BYTES: six.binary_type,
    FieldDescriptor.TYPE_ENUM: int,
}
| 41 | |
| 42 | |
def repeated(type_callable):
    """Lift a single-value converter into one that converts a whole sequence.

    :param type_callable: callable applied to each individual element.
    :return: a callable that takes an iterable and returns a new ``list``
        holding ``type_callable(element)`` for every element, in order.
    """
    def convert_all(value_list):
        return [type_callable(item) for item in value_list]

    return convert_all
| 45 | |
| 46 | |
def enum_label_name(field, value):
    """Return the symbolic name of a numeric enum value.

    :param field: enum field descriptor; its ``enum_type.values_by_number``
        mapping is used for the lookup.
    :param value: the enum number (anything accepted by ``int()``).
    :return: the ``name`` of the matching enum value descriptor.
    """
    by_number = field.enum_type.values_by_number
    return by_number[int(value)].name
| 49 | |
| 50 | |
def _is_map_entry(field):
    """Tell whether ``field`` is a protobuf map field.

    A field counts as a map when it is message-typed and its message type
    carries options whose ``map_entry`` flag is set.

    :param field: the field descriptor to inspect.
    :return: a falsy value for non-map fields, otherwise the ``map_entry``
        option of the field's message type.
    """
    if field.type != FieldDescriptor.TYPE_MESSAGE:
        return False
    entry_type = field.message_type
    return entry_type.has_options and entry_type.GetOptions().map_entry
| 55 | |
| 56 | |
def protobuf_to_dict(pb, type_callable_map=TYPE_CALLABLE_MAP, use_enum_labels=False,
                     including_default_value_fields=False):
    """Convert a protobuf message into a plain ``dict``.

    Fields reported by ``pb.ListFields()`` are converted with the callable
    chosen by ``_get_field_value_adaptor``. Map fields become dictionaries,
    repeated fields become lists, and extension fields are collected under
    the ``EXTENSION_CONTAINER`` key, keyed by their field number as a string.

    :param pb: the protobuf message instance to convert.
    :param dict type_callable_map: mapping of protobuf field type ids to the
        callables used to convert scalar values.
    :param bool use_enum_labels: emit enum value names instead of numbers.
    :param bool including_default_value_fields: also emit fields that are not
        set, using their default value. Singular message fields and oneof
        members are never filled in this way.
    :return: the dictionary representation of ``pb``.
    :raises TypeError: if a field has a type that is not message, not a
        label-converted enum, and not present in ``type_callable_map``.
    """
    result_dict = {}
    extensions = {}
    for field, value in pb.ListFields():
        if _is_map_entry(field):
            # Only map values are converted; keys are copied as they are.
            value_field = field.message_type.fields_by_name['value']
            convert_value = _get_field_value_adaptor(
                pb, value_field, type_callable_map,
                use_enum_labels, including_default_value_fields)
            result_dict[field.name] = {k: convert_value(v) for k, v in value.items()}
            continue

        type_callable = _get_field_value_adaptor(pb, field, type_callable_map,
                                                 use_enum_labels, including_default_value_fields)
        if field.label == FieldDescriptor.LABEL_REPEATED:
            type_callable = repeated(type_callable)

        if field.is_extension:
            extensions[str(field.number)] = type_callable(value)
            continue

        result_dict[field.name] = type_callable(value)

    # Serialize default value if including_default_value_fields is True.
    if including_default_value_fields:
        for field in pb.DESCRIPTOR.fields:
            is_repeated = field.label == FieldDescriptor.LABEL_REPEATED
            # Singular message fields and oneof fields will not be affected.
            if ((not is_repeated and
                    field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE) or
                    field.containing_oneof):
                continue
            if field.name in result_dict:
                # Skip the field which has been serialized already.
                continue
            if _is_map_entry(field):
                result_dict[field.name] = {}
            elif is_repeated:
                # Always hand out a fresh list: the descriptor's default for a
                # repeated field is a list object that must not be shared with
                # callers, and it cannot be passed to enum_label_name (which
                # would call int() on a list for repeated enum fields).
                result_dict[field.name] = []
            elif use_enum_labels and field.type == FieldDescriptor.TYPE_ENUM:
                result_dict[field.name] = enum_label_name(field, field.default_value)
            else:
                result_dict[field.name] = field.default_value

    if extensions:
        result_dict[EXTENSION_CONTAINER] = extensions
    return result_dict
| 105 | |
| 106 | |
def _get_field_value_adaptor(pb, field, type_callable_map=TYPE_CALLABLE_MAP, use_enum_labels=False,
                             including_default_value_fields=False):
    """Pick the callable that converts one value of ``field`` to plain Python.

    The choice is made in this order: message fields recurse through
    ``protobuf_to_dict``; enum fields map to their label when
    ``use_enum_labels`` is set; everything else is looked up by field type in
    ``type_callable_map``.

    :param pb: the message owning the field (used only for the error message).
    :param field: descriptor of the field whose values will be converted.
    :param dict type_callable_map: mapping of field type ids to callables.
    :param bool use_enum_labels: convert enum numbers to their names.
    :param bool including_default_value_fields: forwarded to
        ``protobuf_to_dict`` for nested messages.
    :return: a one-argument callable converting a single field value.
    :raises TypeError: if no conversion is known for the field's type.
    """
    field_type = field.type

    if field_type == FieldDescriptor.TYPE_MESSAGE:
        # recursively encode protobuf sub-message
        def convert_message(sub_message):
            return protobuf_to_dict(
                sub_message,
                type_callable_map=type_callable_map,
                use_enum_labels=use_enum_labels,
                including_default_value_fields=including_default_value_fields,
            )

        return convert_message

    if use_enum_labels and field_type == FieldDescriptor.TYPE_ENUM:
        def convert_enum(number):
            return enum_label_name(field, number)

        return convert_enum

    if field_type not in type_callable_map:
        raise TypeError("Field %s.%s has unrecognised type id %d" % (
            pb.__class__.__name__, field.name, field.type))

    return type_callable_map[field_type]
| 125 | |
| 126 | |
# Default ``type_callable_map`` for dict_to_protobuf: field type id -> callable
# applied to an incoming singular value before it is set on the message.
# Empty by default, so values are assigned unchanged.
REVERSE_TYPE_CALLABLE_MAP = {}
| 129 | |
| 130 | |
def dict_to_protobuf(pb_klass_or_instance, values, type_callable_map=REVERSE_TYPE_CALLABLE_MAP, strict=True, ignore_none=False):
    """Fill a protobuf message from a dictionary.

    :param pb_klass_or_instance: either a protobuf message class, which is
        instantiated with no arguments, or an existing message instance,
        which is populated in place.
    :type pb_klass_or_instance: a type or instance of a subclass of google.protobuf.message.Message
    :param dict values: the data to copy onto the message. Repeated and
        nested values are supported.
    :param dict type_callable_map: mapping of protobuf field types to
        callables applied to values before they are set on the message.
    :param bool strict: raise if a key in ``values`` is not a field of the
        message.
    :param bool ignore_none: skip entries whose value is ``None`` instead of
        trying to assign them.
    :return: the populated message instance.
    """
    already_instantiated = isinstance(pb_klass_or_instance, Message)
    target = pb_klass_or_instance if already_instantiated else pb_klass_or_instance()
    return _dict_to_protobuf(target, values, type_callable_map, strict, ignore_none)
| 148 | |
| 149 | |
def _get_field_mapping(pb, dict_value, strict):
    """Pair every entry of ``dict_value`` with the matching field of ``pb``.

    Regular keys are resolved through ``pb.DESCRIPTOR.fields_by_name``. The
    entries stored under ``EXTENSION_CONTAINER`` are resolved by extension
    number through ``pb._extensions_by_number``.

    :param pb: the protobuf message instance being populated.
    :param dict dict_value: the input dictionary.
    :param bool strict: raise on unknown field names or extension numbers
        instead of silently skipping them.
    :return: a list of ``(field_descriptor, input_value, current_pb_value)``
        tuples.
    :raises KeyError: in strict mode, for an unknown field or extension.
    :raises ValueError: if an extension key cannot be parsed as an integer.
    """
    field_mapping = []
    fields_by_name = pb.DESCRIPTOR.fields_by_name

    for key, value in dict_value.items():
        if key == EXTENSION_CONTAINER:
            # Extensions are handled separately below.
            continue
        if key not in fields_by_name:
            if strict:
                raise KeyError("%s does not have a field called %s" % (pb, key))
            continue
        field_mapping.append((fields_by_name[key], value, getattr(pb, key, None)))

    for ext_key, ext_val in dict_value.get(EXTENSION_CONTAINER, {}).items():
        try:
            ext_num = int(ext_key)
        except ValueError:
            raise ValueError("Extension keys must be integers.")
        if ext_num not in pb._extensions_by_number:
            if strict:
                # Report the offending extension number, not the last regular
                # key seen by the loop above.
                raise KeyError("%s does not have a extension with number %s. Perhaps you forgot to import it?" % (pb, ext_num))
            continue
        ext_field = pb._extensions_by_number[ext_num]
        field_mapping.append((ext_field, ext_val, pb.Extensions[ext_field]))

    return field_mapping
| 176 | |
| 177 | |
def _dict_to_protobuf(pb, value, type_callable_map, strict, ignore_none):
    """Recursively copy the entries of ``value`` onto the message ``pb``.

    :param pb: the protobuf message instance to populate (modified in place).
    :param dict value: the input dictionary for this message.
    :param dict type_callable_map: mapping of field type ids to callables
        applied to singular, non-message values before assignment.
    :param bool strict: forwarded to ``_get_field_mapping``; raise on keys
        that do not match a field or extension.
    :param bool ignore_none: skip top-level entries whose value is ``None``.
    :return: ``pb``.
    :raises KeyError: for an enum label that does not exist, or (in strict
        mode) for an unknown field.
    """
    for field, input_value, pb_value in _get_field_mapping(pb, value, strict):
        if ignore_none and input_value is None:
            continue

        if field.label == FieldDescriptor.LABEL_REPEATED:
            if _is_map_entry(field):
                # pb_value is the map container that _get_field_mapping
                # already fetched, so no second lookup by name is needed (and
                # a lookup by name would fail for extension fields).
                value_field = field.message_type.fields_by_name['value']
                for map_key, map_value in input_value.items():
                    if value_field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
                        _dict_to_protobuf(pb_value[map_key], map_value,
                                          type_callable_map, strict, ignore_none)
                        continue
                    # Accept enum labels so that the output of
                    # protobuf_to_dict(use_enum_labels=True) can be read back.
                    if (value_field.type == FieldDescriptor.TYPE_ENUM and
                            isinstance(map_value, six.string_types)):
                        map_value = _string_to_enum(value_field, map_value)
                    pb_value[map_key] = map_value
                continue

            for item in input_value:
                if field.type == FieldDescriptor.TYPE_MESSAGE:
                    _dict_to_protobuf(pb_value.add(), item, type_callable_map, strict, ignore_none)
                elif field.type == FieldDescriptor.TYPE_ENUM and isinstance(item, six.string_types):
                    pb_value.append(_string_to_enum(field, item))
                else:
                    pb_value.append(item)
            continue

        if field.type == FieldDescriptor.TYPE_MESSAGE:
            _dict_to_protobuf(pb_value, input_value, type_callable_map, strict, ignore_none)
            continue

        if field.type in type_callable_map:
            input_value = type_callable_map[field.type](input_value)

        # Resolve enum labels before the extension branch so that extension
        # enum fields accept label strings just like regular fields do.
        if field.type == FieldDescriptor.TYPE_ENUM and isinstance(input_value, six.string_types):
            input_value = _string_to_enum(field, input_value)

        if field.is_extension:
            pb.Extensions[field] = input_value
            continue

        setattr(pb, field.name, input_value)

    return pb
| 219 | |
| 220 | |
| 221 | def _string_to_enum(field, input_value): |
| 222 | enum_dict = field.enum_type.values_by_name |
| 223 | try: |
| 224 | input_value = enum_dict[input_value].number |
| 225 | except KeyError: |
| 226 | raise KeyError("`%s` is not a valid value for field `%s`" % (input_value, field.name)) |
| 227 | return input_value |