Zsolt Haraszti | bae1275 | 2016-10-10 09:55:30 -0700 | [diff] [blame] | 1 | # |
| 2 | # Copyright 2016 the original author or authors. |
| 3 | # |
| 4 | # Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | # you may not use this file except in compliance with the License. |
| 6 | # You may obtain a copy of the License at |
| 7 | # |
| 8 | # http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | # |
| 10 | # Unless required by applicable law or agreed to in writing, software |
| 11 | # distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | # See the License for the specific language governing permissions and |
| 14 | # limitations under the License. |
| 15 | # |
| 16 | import re |
| 17 | from collections import OrderedDict |
| 18 | from copy import copy |
| 19 | |
| 20 | from google.protobuf.descriptor import FieldDescriptor |
| 21 | |
# Matches one '/{...}' path segment and captures the text between the braces
# (used to enumerate the path parameters of an already pruned path).
re_path_param = re.compile(r'/{([^{]+)}')
# Matches either a literal '/<text>' segment (group 'absolute', captured
# without the leading slash) or a '{...}' template (group 'symbolic', captured
# including its braces).
re_segment = re.compile(r'/(?P<absolute>[^{}/]+)|(?P<symbolic>{[^}]+})')
| 24 | |
| 25 | |
class DuplicateMethodAndPathError(Exception):
    """Raised by generate_paths() when two bindings share verb and path."""


class ProtobufCompilationFailedError(Exception):
    """Protobuf compilation failure (not raised within this part of the file;
    presumably used by callers -- confirm)."""


class InvalidPathArgumentError(Exception):
    """Raised when a path parameter names a field the input type lacks."""
| 29 | |
| 30 | |
def native_descriptors_to_swagger(native_descriptors):
    """
    Generate a swagger data dict from the native descriptors extracted
    from protobuf file(s).
    :param native_descriptors:
        Sequence of file descriptor dicts as extracted from proto file
        descriptors (see DescriptorParser and its parse_file_descriptors()
        method). The last entry is assumed to be the top-most descriptor and
        provides the swagger title.
    :return: dict ready to be serialized to JSON as swagger.json file.
    :raises AssertionError: if a message type and an enum type share the
        same fully qualified name.
    """

    # gather all top-level and nested message and enum type definitions,
    # each keyed by fully qualified name
    message_types_dict = gather_all_message_types(native_descriptors)
    enum_types_dict = gather_all_enum_types(native_descriptors)

    # make sure none clashes (sanity check) before merging the two maps
    assert not set(message_types_dict).intersection(enum_types_dict)
    all_types = {}
    all_types.update(message_types_dict)
    all_types.update(enum_types_dict)

    # gather all method definitions and collect all referenced input/output
    # types
    types_referenced, methods_dict = gather_all_methods(native_descriptors)

    # process all directly and indirectly referenced types into JSON schema
    # type definitions
    definitions = generate_definitions(types_referenced, all_types)

    # process all methods and generate the swagger path entries
    paths = generate_paths(methods_dict, definitions)

    # static part
    # last descriptor is assumed to be the top-most one
    root_descriptor = native_descriptors[-1]
    return {
        'swagger': "2.0",
        'info': {
            'title': root_descriptor['name'],
            'version': "version not set"
        },
        'schemes': ["http", "https"],
        'consumes': ["application/json"],
        'produces': ["application/json"],
        'paths': paths,
        'definitions': definitions
    }
| 84 | |
| 85 | |
def gather_all_message_types(descriptors):
    """Return a dict mapping each message type's full name to its
    descriptor, covering top-level and nested message types."""
    by_name = {}
    for full_name, message_type in iterate_message_types(descriptors):
        by_name[full_name] = message_type
    return by_name
| 92 | |
| 93 | |
def gather_all_enum_types(descriptors):
    """Return a dict mapping each enum type's full name to its descriptor,
    as produced by iterate_enum_types()."""
    by_name = {}
    for full_name, enum_type in iterate_enum_types(descriptors):
        by_name[full_name] = enum_type
    return by_name
| 100 | |
| 101 | |
def gather_all_methods(descriptors):
    """Collect the methods yielded by iterate_methods().

    :return: tuple of (set of input/output type names with surrounding dots
        stripped, OrderedDict mapping method full name to (service, method)).
    """
    methods = OrderedDict()
    types_referenced = set()
    for full_name, service, method in iterate_methods(descriptors):
        methods[full_name] = (service, method)
        types_referenced.update(
            method[key].strip('.') for key in ('input_type', 'output_type'))
    return types_referenced, methods
| 110 | |
| 111 | |
def iterate_methods(descriptors):
    """Yield (full_name, service, method) for every service method that
    carries an 'http' entry in its options.

    :param descriptors: iterable of file descriptor dicts; each must have a
        'package' key and may have a 'service' list.
    :return: generator of tuples; full_name is
        '<package>.<service name>.<method name>'.
    """
    for descriptor in descriptors:
        package = descriptor['package']
        for service in descriptor.get('service', []):
            service_prefix = package + '.' + service['name']
            for method in service.get('method', []):
                # skip methods that do not have http options; a method
                # without any options at all is skipped the same way
                if 'http' in method.get('options', {}):
                    full_name = service_prefix + '.' + method['name']
                    yield full_name, service, method
| 123 | |
| 124 | |
def iterate_for_type_in(message_types, prefix):
    """Yield (full_name, message_type) for each given message type followed
    by its nested types, recursively; names are dot-joined onto prefix."""
    for entry in message_types:
        qualified = '.'.join((prefix, entry['name']))
        yield qualified, entry
        children = entry.get('nested_type', [])
        for item in iterate_for_type_in(children, qualified):
            yield item
| 132 | |
| 133 | |
def iterate_message_types(descriptors):
    """Yield (full_name, message_type) for all top-level and nested message
    types of every descriptor, prefixed by the descriptor's package."""
    for descriptor in descriptors:
        package = descriptor['package']
        found = iterate_for_type_in(descriptor.get('message_type', []),
                                    package)
        for item in found:
            yield item
| 140 | |
| 141 | |
def iterate_enum_types(descriptors):
    """Yield (full_name, enum) for each descriptor's file-level enums first,
    then for the enums declared inside its (nested) message types."""
    for descriptor in descriptors:
        package = descriptor['package']

        # enums declared directly at file level
        for enum in descriptor.get('enum_type', []):
            yield '.'.join((package, enum['name'])), enum

        # enums declared inside any message type, at any nesting depth
        messages = iterate_for_type_in(
            descriptor.get('message_type', []), package)
        for scope, message_type in messages:
            for enum in message_type.get('enum_type', []):
                yield '.'.join((scope, enum['name'])), enum
| 153 | |
| 154 | |
def generate_definitions(types_referenced, types):
    """Walk all the referenced types and for each, generate a JSON schema
    definition. These may also refer to other types, so keep the needed
    set up-to-date.

    :param types_referenced: iterable of fully qualified type names to
        start from; it is not modified.
    :param types: dict mapping fully qualified type name to descriptor.
    :return: dict mapping type name to its definition.
    :raises KeyError: if a referenced type name is missing from types.
    """
    definitions = {}
    wanted = set(types_referenced)  # private work-list copy
    while wanted:
        full_name = wanted.pop()
        definition, referenced = make_definition(types[full_name], types)
        # record before queuing so self-referencing types are not re-queued
        definitions[full_name] = definition
        wanted.update(
            name for name in referenced if name not in definitions)
    return definitions
| 171 | |
| 172 | |
def make_definition(type, types):
    """Build the definition for one type descriptor.

    :return: tuple of (definition, set of referenced type names); enum
        descriptors always yield an empty set.
    """
    is_enum = type['_type'] == 'google.protobuf.EnumDescriptorProto'
    if not is_enum:
        return make_object_definition(type, types)
    return make_enum_definition(type), set()
| 178 | |
| 179 | |
def make_enum_definition(type):
    """Build a string-typed definition for an enum descriptor.

    The value names become the allowed 'enum' values, the first value is the
    default, and the description lists every value (with its own
    description when one is present).
    """
    values = type['value']

    def describe(value):
        line = ' - {}'.format(value['name'])
        detail = value.get('_description', '')
        return (line + ': {}'.format(detail)) if detail else line

    names = [value['name'] for value in values]
    first_name = values[0]['name']
    heading = type.get('_description', '') or type['name']
    listing = '\n'.join(describe(value) for value in values)

    return {
        'type': 'string',
        'enum': names,
        'default': first_name,
        'description': heading + '\nValid values:\n' + listing
    }
| 205 | |
| 206 | |
def make_object_definition(type, types):
    """Build an 'object' definition for a message type descriptor.

    :param type: message descriptor dict; its optional 'field' list is
        converted with make_property().
    :param types: dict of all known types, passed through to make_property().
    :return: tuple of (definition, set of type names referenced by the
        fields). 'properties' is only present when there is at least one
        field; 'description' only when the descriptor has '_description'.
    """
    referenced = set()
    properties = {}
    for field in type.get('field', []):
        field_name, field_schema, referenced_by_field = make_property(
            field, types)
        properties[field_name] = field_schema
        referenced.update(referenced_by_field)

    definition = {
        'type': 'object'
    }
    if properties:
        definition['properties'] = properties
    if '_description' in type:
        definition['description'] = type['_description']

    return definition, referenced
| 227 | |
| 228 | |
def make_property(field, types):
    """Convert one field descriptor into a JSON schema property.

    :param field: field descriptor dict with 'name', 'label' and 'type',
        plus 'type_name' for message and enum fields and an optional
        '_description'.
    :param types: dict mapping fully qualified type name to descriptor; used
        to detect map-entry message types.
    :return: tuple of (field name, property dict, set of type names the
        property refers to and that therefore need a definition).
    :raises NotImplementedError: for group fields.
    :raises KeyError: if a message field (other than Timestamp) names a type
        that is missing from types.
    """
    referenced = set()
    repeated = field['label'] == FieldDescriptor.LABEL_REPEATED
    field_type = field['type']

    if field_type == FieldDescriptor.TYPE_MESSAGE:
        type_name = field['type_name'].strip('.')

        if type_name == 'google.protobuf.Timestamp':
            # time-stamp is mapped back to JSON schema date-time string;
            # handled before the lookup below so that the type does not have
            # to be present in `types`
            schema = {
                'type': 'string',
                'format': 'date-time'
            }

        elif types[type_name].get('options', {}).get('map_entry', False):
            # map-entries are inlined: the entry's second field describes
            # the map value
            value_field = types[type_name]['field'][1]
            _, value_schema, value_referenced = make_property(
                value_field, types)
            # keep whatever the value type refers to, otherwise a map of
            # messages/enums would leave a $ref without a definition
            referenced.update(value_referenced)
            repeated = False
            schema = {
                'type': 'object',
                'additionalProperties': value_schema
            }

        else:
            # normal nested object field
            schema = {
                '$ref': '#/definitions/{}'.format(type_name)
            }
            referenced.add(type_name)

    elif field_type == FieldDescriptor.TYPE_ENUM:
        type_name = field['type_name'].strip('.')
        schema = {
            '$ref': '#/definitions/{}'.format(type_name)
        }
        referenced.add(type_name)

    elif field_type == FieldDescriptor.TYPE_GROUP:
        raise NotImplementedError()

    else:
        json_type, json_format = TYPE_MAP[field_type]
        schema = {
            'type': json_type,
            'format': json_format
        }

    if repeated:
        schema = {
            'type': 'array',
            'items': schema
        }

    # for repeated fields the description lands on the array wrapper
    if '_description' in field:
        schema['description'] = field['_description']

    return field['name'], schema, referenced
| 295 | |
| 296 | |
def generate_paths(methods_dict, definitions):
    """Generate the swagger 'paths' dict from the gathered methods.

    :param methods_dict: mapping of method full name to (service, method)
        tuples, as built by gather_all_methods(); each method must carry
        options['http'].
    :param definitions: dict of type name to JSON schema definition, as
        built by generate_definitions(); used to resolve the type of path
        parameters.
    :return: dict mapping path to a dict of verb to operation entry.
    :raises DuplicateMethodAndPathError: if the same verb is bound to the
        same path more than once.
    :raises InvalidPathArgumentError: if a path parameter names a field
        that the method's input type does not have.
    :raises AssertionError: if a binding has no verb, has both a standard
        verb and a custom one, or a dotted path parameter crosses a
        non-object field.
    """

    paths = {}

    def _iterate():
        """Yield (service, method, binding) for the primary http option of
        each method followed by its additional bindings."""
        for service, method in methods_dict.values():
            http_option = method['options']['http']
            yield service, method, http_option
            for binding in http_option.get('additional_bindings', []):
                yield service, method, binding

    def prune_path(path):
        """rid '=<stuff>' pattern from path symbolic segments"""
        segments = re_segment.findall(path)
        pruned_segments = []
        for absolute, symbolic in segments:
            if symbolic:
                full_symbol = symbolic[1:-1]
                pruned_symbol = full_symbol.split('=', 2)[0]
                pruned_segments.append('{' + pruned_symbol + '}')
            else:
                pruned_segments.append(absolute)

        return '/' + '/'.join(pruned_segments)

    def lookup_input_type(input_type_name):
        return definitions[input_type_name.strip('.')]

    def lookup_type(input_type, field_name):
        """Resolve a (possibly dotted) field name to its (type, format)."""
        local_field_name, _, rest = field_name.partition('.')
        # a definition without any field has no 'properties' key at all
        properties = input_type.get('properties', {})
        if local_field_name not in properties:
            raise InvalidPathArgumentError(
                'Input type has no field {}'.format(field_name))
        field = properties[local_field_name]
        if not rest:
            return field['type'], field['format']

        # dotted name: descend into the referenced nested object type
        field_type = field.get('type', 'object')
        assert field_type == 'object', (
            'Nested field name "%s" refers to field that of type "%s" '
            '(.%s should be nested object field)'
            % (field_name, field_type, local_field_name))
        ref = field['$ref']
        assert ref.startswith('#/definitions/')
        type_name = ref.replace('#/definitions/', '')
        nested_input_type = lookup_input_type(type_name)
        return lookup_type(nested_input_type, rest)

    def make_entry(service, method, http):
        """Build (path, verb, operation entry) for one http binding."""
        parameters = []
        verb = None
        for verb_candidate in ('get', 'delete', 'patch', 'post', 'put'):
            if verb_candidate in http:
                verb, path = verb_candidate, http[verb_candidate]
                break
        if 'custom' in http:
            assert verb is None
            verb = http['custom']['kind']
            path = http['custom']['path']
        assert verb is not None
        path = prune_path(path)

        # for each symbolic segment in path, add a path parameter entry
        input_type = lookup_input_type(method['input_type'])
        for segment in re_path_param.findall(path):
            symbol = segment.split('=')[0]
            _type, format = lookup_type(input_type, symbol)
            parameters.append({
                'in': 'path',
                'name': symbol,
                'required': True,
                'type': _type,
                'format': format
            })

        if 'body' in http:  # TODO validate if body lists fields
            parameters.append({
                'in': 'body',
                'name': 'body',
                'required': True,
                'schema': {'$ref': '#/definitions/{}'.format(
                    method['input_type'].strip('.'))}
            })

        entry = {
            'operationId': method['name'],
            'tags': [service['name']],
            'responses': {
                '200': {  # TODO: code is 201 and 209 in POST/DELETE?
                    'description': u"",  # TODO: ever filled by proto?
                    'schema': {
                        '$ref': '#/definitions/{}'.format(
                            method['output_type'].strip('.'))
                    }
                },
                # TODO shall we prefill with standard error (verb specific),
                # such as 400, 403, 404, 409, 509, 500, 503 etc.
            }
        }

        if parameters:
            entry['parameters'] = parameters

        summary, description = extract_summary_and_description(method)
        if summary:
            entry['summary'] = summary
        if description:
            entry['description'] = description

        return path, verb, entry

    for service, method, http in _iterate():
        path, verb, entry = make_entry(service, method, http)
        path_dict = paths.setdefault(path, {})
        if verb in path_dict:
            raise DuplicateMethodAndPathError(
                'There is already a {} method defined for path ({})'.format(
                    verb, path))
        path_dict[verb] = entry

    return paths
| 420 | |
| 421 | |
def extract_summary_and_description(obj):
    """
    Split the raw _description text of obj (if any) into a summary line
    and/or a detailed description:
    * no text: neither summary nor description;
    * second line missing or blank: the first line is the summary and the
      text after the second line, when not blank, is the description;
    * second line not blank: the whole text is the description and there is
      no summary.
    :param obj: dict that may carry a '_description' entry
    :return: (summary, description) tuple; either element may be None
    """
    assert isinstance(obj, dict)
    text = obj.get('_description', '')
    if not text:
        return None, None

    # pad to exactly three pieces: first line, second line, everything else
    pieces = text.split('\n', 2)
    while len(pieces) < 3:
        pieces.append('')
    first_line, second_line, remainder = pieces

    if second_line.strip():
        return None, text
    return first_line, (remainder if remainder.strip() else None)
| 447 | |
| 448 | |
# Maps protobuf scalar field types to (JSON schema type, format) pairs.
# make_property() consults this table only for fields that are neither
# message, enum nor group, so the TYPE_ENUM entry is not reached from there.
# Note that all 64-bit integer types are emitted with JSON type 'string'
# (presumably to avoid precision loss in JSON numbers -- confirm).
TYPE_MAP = {
    FieldDescriptor.TYPE_BOOL: ('boolean', 'boolean'),
    FieldDescriptor.TYPE_BYTES: ('string', 'byte'),
    FieldDescriptor.TYPE_DOUBLE: ('number', 'double'),
    FieldDescriptor.TYPE_ENUM: ('string', 'string'),
    FieldDescriptor.TYPE_FIXED32: ('integer', 'int64'),
    FieldDescriptor.TYPE_FIXED64: ('string', 'uint64'),
    FieldDescriptor.TYPE_FLOAT: ('number', 'float'),
    FieldDescriptor.TYPE_INT32: ('integer', 'int32'),
    FieldDescriptor.TYPE_INT64: ('string', 'int64'),
    FieldDescriptor.TYPE_SFIXED32: ('integer', 'int32'),
    FieldDescriptor.TYPE_SFIXED64: ('string', 'int64'),
    FieldDescriptor.TYPE_STRING: ('string', 'string'),
    FieldDescriptor.TYPE_SINT32: ('integer', 'int32'),
    FieldDescriptor.TYPE_SINT64: ('string', 'int64'),
    FieldDescriptor.TYPE_UINT32: ('integer', 'int64'),
    FieldDescriptor.TYPE_UINT64: ('string', 'uint64'),
    # FieldDescriptor.TYPE_MESSAGE:
    # FieldDescriptor.TYPE_GROUP:
}