blob: bf4538ad8add41b22cacc87cda27f76ac6919ebc [file] [log] [blame]
# -*- coding:utf-8 -*-

# Copied from https://github.com/kaporzhu/protobuf-to-dict
# All credit for this script goes to Kapor Zhu (kapor.zhu@gmail.com).
#
# This copy needed a fix for the bug "Use enum_label when setting the default
# value if use_enum_labels is true"; the fix lives in the
# including_default_value_fields branch of protobuf_to_dict.

8import base64
9
10import six
11
12from google.protobuf.message import Message
13from google.protobuf.descriptor import FieldDescriptor
14
15
__all__ = [
    "protobuf_to_dict",
    "TYPE_CALLABLE_MAP",
    "dict_to_protobuf",
    "REVERSE_TYPE_CALLABLE_MAP",
]


# Key under which protobuf_to_dict stores extension values and from which
# dict_to_protobuf reads them back.
EXTENSION_CONTAINER = '___X'


# Converter for 64-bit integer fields: plain ``int`` on Python 3, and
# six.integer_types[1] (``long``) on Python 2.
_LONG_TYPE = int if six.PY3 else six.integer_types[1]

# Default scalar converters used by protobuf_to_dict, keyed by the
# FieldDescriptor.TYPE_* id of the field being converted.
TYPE_CALLABLE_MAP = {
    FieldDescriptor.TYPE_DOUBLE: float,
    FieldDescriptor.TYPE_FLOAT: float,
    FieldDescriptor.TYPE_INT32: int,
    FieldDescriptor.TYPE_INT64: _LONG_TYPE,
    FieldDescriptor.TYPE_UINT32: int,
    FieldDescriptor.TYPE_UINT64: _LONG_TYPE,
    FieldDescriptor.TYPE_SINT32: int,
    FieldDescriptor.TYPE_SINT64: _LONG_TYPE,
    FieldDescriptor.TYPE_FIXED32: int,
    FieldDescriptor.TYPE_FIXED64: _LONG_TYPE,
    FieldDescriptor.TYPE_SFIXED32: int,
    FieldDescriptor.TYPE_SFIXED64: _LONG_TYPE,
    FieldDescriptor.TYPE_BOOL: bool,
    FieldDescriptor.TYPE_STRING: six.text_type,
    FieldDescriptor.TYPE_BYTES: six.binary_type,
    FieldDescriptor.TYPE_ENUM: int,
}
41
42
def repeated(type_callable):
    """Lift a per-element converter into one that converts a whole sequence.

    :param type_callable: callable applied to every element, in order.
    :return: a function taking an iterable and returning a new list with
        ``type_callable`` applied to each element.
    """
    def convert_all(items):
        return list(map(type_callable, items))
    return convert_all
45
46
def enum_label_name(field, value):
    """Return the label of the enum value numbered *value* in *field*'s enum.

    :param field: an enum-typed field descriptor.
    :param value: the enum number (anything ``int()`` accepts).
    :return: the label string of the matching enum value. If the number is
        not declared in the enum type (which a proto3 enum field can hold),
        the integer itself is returned instead of raising ``KeyError``.
    """
    number = int(value)
    try:
        return field.enum_type.values_by_number[number].name
    except KeyError:
        # Unknown number: keep the raw value rather than failing the whole
        # conversion.
        return number
49
50
def _is_map_entry(field):
    """Tell whether *field* is a protobuf ``map<K, V>`` field.

    A map field is a message-typed field whose message type has options with
    ``map_entry`` set.
    """
    if field.type != FieldDescriptor.TYPE_MESSAGE:
        return False
    entry_type = field.message_type
    if not entry_type.has_options:
        return False
    return entry_type.GetOptions().map_entry
55
56
def protobuf_to_dict(pb, type_callable_map=TYPE_CALLABLE_MAP, use_enum_labels=False,
                     including_default_value_fields=False):
    """Convert a protobuf message instance into a plain dictionary.

    :param pb: the message instance to convert.
    :param dict type_callable_map: maps ``FieldDescriptor.TYPE_*`` ids to the
        callable used to convert a scalar field value.
    :param bool use_enum_labels: emit enum values as their label names instead
        of their numbers.
    :param bool including_default_value_fields: also emit fields that are not
        set, using their default value. Singular message fields and members of
        a oneof are still omitted.
    :return: a dict keyed by field name. Extension values, if any, are stored
        under ``EXTENSION_CONTAINER`` in a nested dict keyed by the extension's
        field number as a string.
    :raises TypeError: if a field's type id has no converter in
        ``type_callable_map`` (raised by ``_get_field_value_adaptor``).
    """
    result_dict = {}
    extensions = {}
    for field, value in pb.ListFields():
        if _is_map_entry(field):
            # Map fields: convert every value with the adaptor for the entry's
            # ``value`` field; keys are copied unchanged.
            value_field = field.message_type.fields_by_name['value']
            type_callable = _get_field_value_adaptor(
                pb, value_field, type_callable_map,
                use_enum_labels, including_default_value_fields)
            result_dict[field.name] = {k: type_callable(v) for k, v in value.items()}
            continue

        type_callable = _get_field_value_adaptor(
            pb, field, type_callable_map,
            use_enum_labels, including_default_value_fields)
        if field.label == FieldDescriptor.LABEL_REPEATED:
            type_callable = repeated(type_callable)

        if field.is_extension:
            extensions[str(field.number)] = type_callable(value)
            continue

        result_dict[field.name] = type_callable(value)

    # Serialize default values for unset fields if requested.
    if including_default_value_fields:
        for field in pb.DESCRIPTOR.fields:
            # Singular message fields and oneof fields will not be affected.
            is_singular_message = (
                field.label != FieldDescriptor.LABEL_REPEATED and
                field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE)
            if is_singular_message or field.containing_oneof:
                continue
            if field.name in result_dict:
                # Skip fields that were already serialized above.
                continue
            result_dict[field.name] = _default_field_value(field, use_enum_labels)

    if extensions:
        result_dict[EXTENSION_CONTAINER] = extensions
    return result_dict


def _default_field_value(field, use_enum_labels):
    """Return the dict representation of *field* when it is not set.

    Only called for fields that are neither singular messages nor oneof members.
    """
    if _is_map_entry(field):
        return {}
    if field.label == FieldDescriptor.LABEL_REPEATED:
        # For repeated fields protobuf's ``default_value`` is a list, not a
        # scalar (NOTE(review): as seen in generated descriptors; confirm for
        # your protobuf runtime). Passing it to enum_label_name would fail on
        # int([]), and returning it directly would hand out a shared list
        # object, so always return a fresh empty list.
        return []
    if use_enum_labels and field.type == FieldDescriptor.TYPE_ENUM:
        return enum_label_name(field, field.default_value)
    return field.default_value
105
106
def _get_field_value_adaptor(pb, field, type_callable_map=TYPE_CALLABLE_MAP, use_enum_labels=False,
                             including_default_value_fields=False):
    """Return a callable that converts one value of *field* for protobuf_to_dict.

    :param pb: the message owning *field*; only used in the error message.
    :param field: descriptor of the field whose values will be converted.
    :param dict type_callable_map: scalar converters keyed by
        ``FieldDescriptor.TYPE_*`` ids.
    :param bool use_enum_labels: convert enum numbers to their labels.
    :param bool including_default_value_fields: forwarded to nested
        protobuf_to_dict calls for sub-messages.
    :return: a one-argument callable. For repeated fields it converts a single
        element; protobuf_to_dict wraps it with ``repeated``.
    :raises TypeError: if no converter applies to the field's type id.
    """
    def convert_message(sub_pb):
        # recursively encode protobuf sub-message
        return protobuf_to_dict(
            sub_pb, type_callable_map=type_callable_map,
            use_enum_labels=use_enum_labels,
            including_default_value_fields=including_default_value_fields,
        )

    if field.type == FieldDescriptor.TYPE_MESSAGE:
        return convert_message

    if use_enum_labels and field.type == FieldDescriptor.TYPE_ENUM:
        return lambda value: enum_label_name(field, value)

    if field.type in type_callable_map:
        return type_callable_map[field.type]

    if field.type == FieldDescriptor.TYPE_GROUP:
        # proto2 groups hold sub-messages just like TYPE_MESSAGE fields but
        # have their own type id. Checked after type_callable_map so a
        # caller-supplied TYPE_GROUP converter still takes precedence.
        return convert_message

    raise TypeError("Field %s.%s has unrecognised type id %d" % (
        pb.__class__.__name__, field.name, field.type))
125
126
# Default converter map for dict_to_protobuf, keyed by FieldDescriptor.TYPE_*
# ids. It is empty, so values are assigned unconverted unless a caller passes
# its own map. Because it is dict_to_protobuf's default argument, every call
# that omits type_callable_map shares this object: treat it as read-only.
REVERSE_TYPE_CALLABLE_MAP = dict()
129
130
def dict_to_protobuf(pb_klass_or_instance, values, type_callable_map=REVERSE_TYPE_CALLABLE_MAP, strict=True, ignore_none=False):
    """Populates a protobuf model from a dictionary.

    :param pb_klass_or_instance: a protobuf message class, or a protobuf instance
    :type pb_klass_or_instance: a type or instance of a subclass of google.protobuf.message.Message
    :param dict values: a dictionary of values. Repeated and nested values are
        fully supported.
    :param dict type_callable_map: a mapping of protobuf types to callables for setting
        values on the target instance.
    :param bool strict: complain if keys in the map are not fields on the message.
    :param bool ignore_none: ignore None-values of fields, treat them as empty field
    :return: the populated message. When an instance is passed in, that same
        instance is modified in place and returned; when a class is passed in,
        a new instance is created.
    :raises KeyError: for unknown fields/extensions when *strict*, or for an
        unknown enum label.
    :raises ValueError: for an extension key that is not an integer.
    """
    if isinstance(pb_klass_or_instance, Message):
        instance = pb_klass_or_instance
    else:
        instance = pb_klass_or_instance()
    return _dict_to_protobuf(instance, values, type_callable_map, strict, ignore_none)
148
149
def _get_field_mapping(pb, dict_value, strict):
    """Pair each entry of *dict_value* with the matching field of *pb*.

    :param pb: the message instance being populated.
    :param dict dict_value: field names to values; extension values may be
        nested under ``EXTENSION_CONTAINER``, keyed by extension number.
    :param bool strict: raise ``KeyError`` for names/numbers that *pb* does
        not define instead of silently skipping them.
    :return: a list of ``(field_descriptor, input_value, current_value)``
        tuples, where ``current_value`` is the field's (or extension's)
        current value on *pb*.
    :raises KeyError: for an unknown field name or extension number when *strict*.
    :raises ValueError: for a string extension key that is not an integer.
    """
    field_mapping = []
    fields_by_name = pb.DESCRIPTOR.fields_by_name
    for key, value in dict_value.items():
        if key == EXTENSION_CONTAINER:
            continue
        if key not in fields_by_name:
            if strict:
                raise KeyError("%s does not have a field called %s" % (pb, key))
            continue
        field_mapping.append((fields_by_name[key], value, getattr(pb, key, None)))

    for ext_num, ext_val in dict_value.get(EXTENSION_CONTAINER, {}).items():
        try:
            ext_num = int(ext_num)
        except ValueError:
            raise ValueError("Extension keys must be integers.")
        if ext_num not in pb._extensions_by_number:
            if strict:
                # Report the offending extension number; the ``key`` variable
                # from the field loop above is unrelated to this extension.
                raise KeyError("%s does not have a extension with number %s. Perhaps you forgot to import it?" % (pb, ext_num))
            continue
        ext_field = pb._extensions_by_number[ext_num]
        field_mapping.append((ext_field, ext_val, pb.Extensions[ext_field]))

    return field_mapping
176
177
def _dict_to_protobuf(pb, value, type_callable_map, strict, ignore_none):
    """Copy the entries of the dict *value* onto the message *pb* in place.

    :param pb: the message instance to populate.
    :param dict value: field names (and optionally ``EXTENSION_CONTAINER``)
        mapped to the values to assign.
    :param dict type_callable_map: converters keyed by ``FieldDescriptor.TYPE_*``
        ids, applied to singular scalar values before they are assigned.
    :param bool strict: raise on keys that are not fields (or known
        extensions) of *pb*.
    :param bool ignore_none: skip entries whose value is ``None``.
    :return: *pb*.
    :raises KeyError: for unknown keys when *strict*, or for an unknown enum label.
    :raises ValueError: for an extension key that is not an integer.
    """
    fields = _get_field_mapping(pb, value, strict)

    for field, input_value, pb_value in fields:
        if ignore_none and input_value is None:
            continue
        # cpp_type is CPPTYPE_MESSAGE for both TYPE_MESSAGE and TYPE_GROUP fields.
        is_message = field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE

        if field.label == FieldDescriptor.LABEL_REPEATED:
            if _is_map_entry(field):
                value_field = field.message_type.fields_by_name['value']
                container = getattr(pb, field.name)
                for map_key, map_value in input_value.items():
                    if value_field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
                        # Message-valued map: fill the entry obtained by
                        # indexing the map container.
                        _dict_to_protobuf(container[map_key], map_value,
                                          type_callable_map, strict, ignore_none)
                    elif (value_field.type == FieldDescriptor.TYPE_ENUM and
                          isinstance(map_value, six.string_types)):
                        # Accept enum labels, as emitted by
                        # protobuf_to_dict(use_enum_labels=True).
                        container[map_key] = _string_to_enum(value_field, map_value)
                    else:
                        container[map_key] = map_value
                continue
            for item in input_value:
                if is_message:
                    _dict_to_protobuf(pb_value.add(), item, type_callable_map, strict, ignore_none)
                elif field.type == FieldDescriptor.TYPE_ENUM and isinstance(item, six.string_types):
                    pb_value.append(_string_to_enum(field, item))
                else:
                    pb_value.append(item)
            continue

        if is_message:
            _dict_to_protobuf(pb_value, input_value, type_callable_map, strict, ignore_none)
            continue

        if field.type in type_callable_map:
            input_value = type_callable_map[field.type](input_value)

        # Resolve enum labels before the extension branch so labels work for
        # enum extensions as well as for regular enum fields.
        if field.type == FieldDescriptor.TYPE_ENUM and isinstance(input_value, six.string_types):
            input_value = _string_to_enum(field, input_value)

        if field.is_extension:
            pb.Extensions[field] = input_value
        else:
            setattr(pb, field.name, input_value)

    return pb
219
220
221def _string_to_enum(field, input_value):
222 enum_dict = field.enum_type.values_by_name
223 try:
224 input_value = enum_dict[input_value].number
225 except KeyError:
226 raise KeyError("`%s` is not a valid value for field `%s`" % (input_value, field.name))
227 return input_value