blob: d93c8dad7b2cbfa4ce38e9ff8dc271ef486bab36 [file] [log] [blame]
Zsolt Haraszti46c72002016-10-10 09:55:30 -07001#
2# Copyright 2016 the original author or authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16import re
17from collections import OrderedDict
18from copy import copy
19
20from google.protobuf.descriptor import FieldDescriptor
21
# Captures the text between the braces of every '/{...}' piece of a path,
# e.g. '/devices/{id}' yields 'id'.
re_path_param = re.compile(r'/{([^{]+)}')

# Tokenizes a path template. Each match fills exactly one of two named groups:
#   absolute - a literal piece following a '/' (the slash is not captured)
#   symbolic - a braced '{...}' piece, braces included
re_segment = re.compile(r'/(?P<absolute>[^{}/]+)|(?P<symbolic>{[^}]+})')
24
25
class DuplicateMethodAndPathError(Exception):
    """Raised when the same HTTP verb is declared twice for one path."""
class ProtobufCompilationFailedError(Exception):
    """Signals a failed protobuf compilation (not raised in this module)."""
class InvalidPathArgumentError(Exception):
    """Raised when a path parameter names a field the input type lacks."""
29
30
def native_descriptors_to_swagger(native_descriptors):
    """
    Generate a swagger data dict from the native descriptors extracted
    from protobuf file(s).

    :param native_descriptors:
        Sequence of dicts as extracted from proto file descriptors.
        See DescriptorParser and its parse_file_descriptors() method.
        The last element is treated as the top-most descriptor and its
        'name' becomes the swagger title, so the sequence must not be empty.
    :return: dict ready to be serialized to JSON as swagger.json file.
    :raises AssertionError: if a message type and an enum type share the
        same fully qualified name.
    """

    # gather all top-level and nested message type definitions, and build a
    # similar map for all top-level and nested enum definitions
    message_types_dict = gather_all_message_types(native_descriptors)
    enum_types_dict = gather_all_enum_types(native_descriptors)

    # make sure none clashes (sanity check); iterating a dict yields its
    # keys on both Python 2 and Python 3, unlike iterkeys()
    assert not set(message_types_dict).intersection(enum_types_dict)

    all_types = {}
    all_types.update(message_types_dict)
    all_types.update(enum_types_dict)

    # gather all method definitions and collect all referenced input/output
    # types
    types_referenced, methods_dict = gather_all_methods(native_descriptors)

    # process all directly and indirectly referenced types into JSON schema
    # type definitions
    definitions = generate_definitions(types_referenced, all_types)

    # process all methods and generate the swagger path entries
    paths = generate_paths(methods_dict, definitions)

    # static part
    # last descriptor is assumed to be the top-most one
    root_descriptor = native_descriptors[-1]
    swagger = {
        'swagger': "2.0",
        'info': {
            'title': root_descriptor['name'],
            'version': "version not set"
        },
        'schemes': ["http", "https"],
        'consumes': ["application/json"],
        'produces': ["application/json"],
        'paths': paths,
        'definitions': definitions
    }

    return swagger
84
85
def gather_all_message_types(descriptors):
    """Map fully qualified name -> message type for all (nested) messages."""
    # iterate_message_types() already yields (full_name, message_type) pairs
    return dict(iterate_message_types(descriptors))
92
93
def gather_all_enum_types(descriptors):
    """Map fully qualified name -> enum type for every enum that is found."""
    # iterate_enum_types() already yields (full_name, enum_type) pairs
    return dict(iterate_enum_types(descriptors))
100
101
def gather_all_methods(descriptors):
    """
    Collect the methods yielded by iterate_methods() together with the
    names of the types they consume and produce.

    :return: tuple (types_referenced, methods) where types_referenced is a
        set of input/output type names with surrounding dots stripped, and
        methods is an OrderedDict of full method name -> (service, method).
    """
    methods = OrderedDict()
    types_referenced = set()
    for full_name, service, method in iterate_methods(descriptors):
        methods[full_name] = (service, method)
        for key in ('input_type', 'output_type'):
            types_referenced.add(method[key].strip('.'))
    return types_referenced, methods
110
111
def iterate_methods(descriptors):
    """
    Yield every service method that carries an 'http' option.

    :param descriptors: iterable of file descriptor dicts; each must have a
        'package' entry and may have a 'service' list.
    :return: generator of (full_name, service, method) tuples where
        full_name is '<package>.<service name>.<method name>'.
    """
    for descriptor in descriptors:
        package = descriptor['package']
        for service in descriptor.get('service', []):
            service_prefix = package + '.' + service['name']
            for method in service.get('method', []):
                # skip methods that do not have http options; a method with
                # no 'options' entry at all is skipped rather than raising
                options = method.get('options') or {}
                if 'http' not in options:
                    continue
                full_name = service_prefix + '.' + method['name']
                yield full_name, service, method
123
124
def iterate_for_type_in(message_types, prefix):
    """
    Walk message types depth-first, parents before their nested types.

    Yields (full_name, message_type) pairs where full_name is the prefix,
    a dot and the type's 'name'; types found under 'nested_type' use their
    parent's full name as prefix.
    """
    # each stack entry pairs an iterator over sibling types with the prefix
    # those siblings live under
    pending = [(iter(message_types), prefix)]
    while pending:
        siblings, scope = pending[-1]
        try:
            current = next(siblings)
        except StopIteration:
            pending.pop()
            continue
        qualified_name = scope + '.' + current['name']
        yield qualified_name, current
        # descend into the children before continuing with the siblings
        pending.append((iter(current.get('nested_type', [])), qualified_name))
132
133
def iterate_message_types(descriptors):
    """
    Yield (full_name, message_type) for every top-level and nested message
    type of every descriptor, names being prefixed with the package.
    """
    for descriptor in descriptors:
        package = descriptor['package']
        for pair in iterate_for_type_in(
                descriptor.get('message_type', []), package):
            yield pair
140
141
def iterate_enum_types(descriptors):
    """
    Yield (full_name, enum_type) for every enum of every descriptor.

    For each descriptor the package-level enums come first, followed by the
    enums declared inside its (possibly nested) message types; the latter
    are prefixed with the full name of the enclosing message.
    """
    for descriptor in descriptors:
        package = descriptor['package']

        # enums declared directly in the package
        for top_enum in descriptor.get('enum_type', []):
            yield package + '.' + top_enum['name'], top_enum

        # enums declared inside message types
        messages = descriptor.get('message_type', [])
        for owner_name, owner in iterate_for_type_in(messages, package):
            for inner_enum in owner.get('enum_type', []):
                yield owner_name + '.' + inner_enum['name'], inner_enum
153
154
def generate_definitions(types_referenced, types):
    """Walk all the referenced types and for each, generate a JSON schema
    definition. These may also refer to other types, so keep the needed
    set up-to-date.

    :param types_referenced: iterable of fully qualified type names to start
        from; it is copied, never modified.
    :param types: dict of fully qualified type name -> type descriptor dict.
    :return: dict of fully qualified type name -> JSON schema definition.
    :raises KeyError: if a referenced type name is missing from `types`.
    """
    definitions = {}
    wanted = set(types_referenced)
    while wanted:
        full_name = wanted.pop()
        definition, newly_referenced = make_definition(types[full_name], types)
        definitions[full_name] = definition
        # only queue what has no definition yet, so recursive and mutually
        # referencing types terminate
        wanted.update(
            name for name in newly_referenced if name not in definitions)
    return definitions
171
172
def make_definition(type, types):
    """
    Build the definition for one type descriptor.

    :return: tuple (definition, referenced type names); enums reference
        nothing, so their set is always empty.
    """
    is_enum = type['_type'] == 'google.protobuf.EnumDescriptorProto'
    if not is_enum:
        return make_object_definition(type, types)
    return make_enum_definition(type), set()
178
179
def make_enum_definition(type):
    """
    Turn an enum descriptor into a string-typed schema definition.

    The first value is the default. The description starts with the enum's
    '_description' (or its name if that is missing or empty) followed by
    one line per value, including the value's own description if it has
    one.
    """
    values = type['value']

    names = []
    value_lines = []
    for value in values:
        name = value['name']
        names.append(name)
        line = ' - {}'.format(name)
        note = value.get('_description', '')
        if note:
            line += ': {}'.format(note)
        value_lines.append(line)

    default_name = values[0]['name']
    heading = type.get('_description', '') or type['name']

    return {
        'type': 'string',
        'enum': names,
        'default': default_name,
        'description': heading + '\nValid values:\n' + '\n'.join(value_lines)
    }
205
206
def make_object_definition(type, types):
    """
    Turn a message type descriptor into an 'object' schema definition.

    :param type: message type descriptor dict; its optional 'field' list
        provides the properties and its optional '_description' the
        description.
    :param types: dict of all known types, passed on to make_property().
    :return: tuple (definition, set of type names the fields refer to).
    """
    referenced = set()
    properties = {}
    for field in type.get('field', []):
        field_name, field_schema, field_refs = make_property(field, types)
        properties[field_name] = field_schema
        referenced.update(field_refs)

    definition = {
        'type': 'object'
    }

    # 'properties' is left out altogether for field-less messages
    if properties:
        definition['properties'] = properties

    if '_description' in type:
        definition['description'] = type['_description']

    return definition, referenced
227
228
def make_property(field, types):
    """
    Build the schema for a single message field.

    :param field: field descriptor dict with 'name', 'label' and 'type';
        message and enum fields also need 'type_name'.
    :param types: dict of fully qualified type name -> type descriptor,
        used to recognize map-entry message types.
    :return: tuple (field name, schema dict, set of referenced type names).
    :raises NotImplementedError: for group fields.
    :raises KeyError: if a message field names a type missing from `types`,
        or a scalar type is missing from TYPE_MAP.
    """
    referenced = set()
    repeated = field['label'] == FieldDescriptor.LABEL_REPEATED
    field_type = field['type']

    if field_type == FieldDescriptor.TYPE_MESSAGE:

        type_name = field['type_name'].strip('.')

        if type_name == 'google.protobuf.Timestamp':
            # time-stamp is mapped back to JSON schema date-time string;
            # checked before the types lookup so that it works even if the
            # well-known type is not part of `types`
            schema = {
                'type': 'string',
                'format': 'date-time'
            }

        else:
            message_type = types[type_name]
            if message_type.get('options', {}).get('map_entry', False):
                # map-entries are inlined: the entry's second field is used
                # as the value, whose schema becomes additionalProperties
                _, value_schema, value_refs = make_property(
                    message_type['field'][1], types)
                # keep the types the value refers to, otherwise the value's
                # $ref could point at a definition that never gets generated
                referenced.update(value_refs)
                repeated = False
                schema = {
                    'type': 'object',
                    'additionalProperties': value_schema
                }

            else:
                # normal nested object field
                schema = {
                    '$ref': '#/definitions/{}'.format(type_name)
                }
                referenced.add(type_name)

    elif field_type == FieldDescriptor.TYPE_ENUM:
        type_name = field['type_name'].strip('.')
        schema = {
            '$ref': '#/definitions/{}'.format(type_name)
        }
        referenced.add(type_name)

    elif field_type == FieldDescriptor.TYPE_GROUP:
        raise NotImplementedError()

    else:
        swagger_type, swagger_format = TYPE_MAP[field_type]
        schema = {
            'type': swagger_type,
            'format': swagger_format
        }

    if repeated:
        schema = {
            'type': 'array',
            'items': schema
        }

    if '_description' in field:
        schema['description'] = field['_description']

    return field['name'], schema, referenced
295
296
def generate_paths(methods_dict, definitions):
    """
    Build the swagger 'paths' section from the http options of the methods.

    :param methods_dict: dict of full method name -> (service, method);
        every method must have an 'http' entry in its 'options'.
    :param definitions: dict of type name -> schema definition, used to
        resolve the type of path parameters.
    :return: dict of path -> {verb -> operation entry}.
    :raises DuplicateMethodAndPathError: if two bindings resolve to the
        same verb and path.
    :raises InvalidPathArgumentError: if a path parameter names a field the
        input type does not have.
    """

    paths = {}

    def _iterate():
        """Yield (service, method, binding) for the main http rule of each
        method and for each of its additional bindings."""
        for service, method in methods_dict.values():
            http_option = method['options']['http']
            yield service, method, http_option
            for binding in http_option.get('additional_bindings', []):
                yield service, method, binding

    def prune_path(path):
        """rid '=<stuff>' pattern from path symbolic segments"""
        pruned_segments = []
        for absolute, symbolic in re_segment.findall(path):
            if symbolic:
                # drop the braces and keep what is left of the first '='
                pruned_symbol = symbolic[1:-1].split('=', 1)[0]
                pruned_segments.append('{' + pruned_symbol + '}')
            else:
                pruned_segments.append(absolute)

        return '/' + '/'.join(pruned_segments)

    def lookup_input_type(input_type_name):
        """Return the definition of the named type."""
        return definitions[input_type_name.strip('.')]

    def lookup_type(input_type, field_name):
        """Resolve a (possibly dotted) field name to (type, format)."""
        local_field_name, _, rest = field_name.partition('.')
        # a definition without any field has no 'properties' entry at all
        properties = input_type.get('properties', {})
        if local_field_name not in properties:
            raise InvalidPathArgumentError(
                'Input type has no field {}'.format(field_name))
        field = properties[local_field_name]
        if not rest:
            # NOTE: raises KeyError for fields whose schema lacks 'type' or
            # 'format' (e.g. $ref fields)
            return field['type'], field['format']

        field_type = field.get('type', 'object')
        assert field_type == 'object', (
            'Nested field name "%s" refers to field that of type "%s" '
            '(.%s should be nested object field)'
            % (field_name, field_type, local_field_name))
        ref = field['$ref']
        assert ref.startswith('#/definitions/')
        type_name = ref.replace('#/definitions/', '')
        nested_input_type = lookup_input_type(type_name)
        return lookup_type(nested_input_type, rest)

    def make_entry(service, method, http):
        """Return (path, verb, operation entry) for one http binding."""
        parameters = []
        verb = path = None
        for verb_candidate in ('get', 'delete', 'patch', 'post', 'put'):
            if verb_candidate in http:
                verb, path = verb_candidate, http[verb_candidate]
                break
        if 'custom' in http:
            assert verb is None
            verb = http['custom']['kind']
            path = http['custom']['path']
        assert verb is not None
        path = prune_path(path)

        # for each symbolic segment in path, add a path parameter entry
        input_type = lookup_input_type(method['input_type'])
        for segment in re_path_param.findall(path):
            symbol = segment.split('=')[0]
            _type, _format = lookup_type(input_type, symbol)
            parameters.append({
                'in': 'path',
                'name': symbol,
                'required': True,
                'type': _type,
                'format': _format
            })

        if 'body' in http:  # TODO validate if body lists fields
            parameters.append({
                'in': 'body',
                'name': 'body',
                'required': True,
                'schema': {'$ref': '#/definitions/{}'.format(
                    method['input_type'].strip('.'))}
            })

        entry = {
            'operationId': method['name'],
            'tags': [service['name']],
            'responses': {
                '200': {  # TODO: code is 201 and 209 in POST/DELETE?
                    'description': u"",  # TODO: ever filled by proto?
                    'schema': {
                        '$ref': '#/definitions/{}'.format(
                            method['output_type'].strip('.'))
                    }
                },
                # TODO shall we prefill with standard error (verb specific),
                # such as 400, 403, 404, 409, 509, 500, 503 etc.
            }
        }

        if parameters:
            entry['parameters'] = parameters

        summary, description = extract_summary_and_description(method)
        if summary:
            entry['summary'] = summary
        if description:
            entry['description'] = description

        return path, verb, entry

    for service, method, http in _iterate():
        path, verb, entry = make_entry(service, method, http)
        path_dict = paths.setdefault(path, {})
        if verb in path_dict:
            raise DuplicateMethodAndPathError(
                'There is already a {} method defined for path ({})'.format(
                    verb, path))
        path_dict[verb] = entry

    return paths
420
421
def extract_summary_and_description(obj):
    """
    Split the raw '_description' of obj (if any) into a summary line and a
    detailed description:
    * no or empty text: neither summary nor description.
    * if the second line is missing or blank, the first line is the summary
      and whatever follows the second line is the description, provided it
      is not blank.
    * otherwise the whole text is the description and there is no summary.

    :param obj: dict that may hold a '_description' entry.
    :return: tuple (summary, description); either can be None.
    """
    assert isinstance(obj, dict)

    text = obj.get('_description', '')
    if not text:
        return None, None

    first_line, _, remainder = text.partition('\n')
    second_line, _, tail = remainder.partition('\n')

    if second_line.strip():
        # no blank separator line, so nothing qualifies as a summary
        return None, text

    return first_line, (tail if tail.strip() else None)
447
448
# Maps protobuf scalar field types to a (swagger type, swagger format) pair.
# TYPE_MESSAGE and TYPE_GROUP have no entry here.
TYPE_MAP = {
    # boolean and floating point
    FieldDescriptor.TYPE_BOOL: ('boolean', 'boolean'),
    FieldDescriptor.TYPE_DOUBLE: ('number', 'double'),
    FieldDescriptor.TYPE_FLOAT: ('number', 'float'),

    # textual and binary
    FieldDescriptor.TYPE_STRING: ('string', 'string'),
    FieldDescriptor.TYPE_BYTES: ('string', 'byte'),
    FieldDescriptor.TYPE_ENUM: ('string', 'string'),

    # 32-bit integers; the unsigned ones are given the wider int64 format
    FieldDescriptor.TYPE_INT32: ('integer', 'int32'),
    FieldDescriptor.TYPE_SINT32: ('integer', 'int32'),
    FieldDescriptor.TYPE_SFIXED32: ('integer', 'int32'),
    FieldDescriptor.TYPE_UINT32: ('integer', 'int64'),
    FieldDescriptor.TYPE_FIXED32: ('integer', 'int64'),

    # 64-bit integers are typed as strings
    FieldDescriptor.TYPE_INT64: ('string', 'int64'),
    FieldDescriptor.TYPE_SINT64: ('string', 'int64'),
    FieldDescriptor.TYPE_SFIXED64: ('string', 'int64'),
    FieldDescriptor.TYPE_UINT64: ('string', 'uint64'),
    FieldDescriptor.TYPE_FIXED64: ('string', 'uint64'),
}
468}