blob: 671bf281545dd37ea8151783349f97bd7af52e1f [file] [log] [blame]
Khen Nursimulu8ffb8932017-01-26 13:40:49 -05001#!/usr/bin/env python
2#
3# Copyright 2017 the original author or authors.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18import sys
19
20from google.protobuf.compiler import plugin_pb2 as plugin
21from google.protobuf.descriptor_pb2 import ServiceDescriptorProto, \
22 MethodOptions
23from jinja2 import Template
24from simplejson import dumps
25
26from netconf.protos.third_party.google.api import annotations_pb2, http_pb2
27
28_ = annotations_pb2, http_pb2 # to keep import line from being optimized out
29
# Jinja2 template for the generated "<proto>_rpc_gw.py" module.
#
# Render-time variables (supplied by generate_gw_code):
#   file_name -- name of the .proto file being processed
#   includes  -- (python_package, module) pairs; an empty package yields a
#                plain "import module" line
#   methods   -- method dicts as yielded by traverse_methods()
#   type_map  -- full proto name -> python name, used for request types and
#                for the "<Service>Stub" class
#
# For each RPC method, the generated function does the following:
#   - merges ``params`` and ``kw`` into a request protobuf via ParseDict;
#   - invokes the method through ``grpc_client.invoke``;
#   - converts the reply with ``grpc_client.convertToDict``.
#
# Notes on the generated code:
#   - ``params`` is copied before ``kw`` is merged in, so the caller's dict is
#     not modified.
#   - ``except ... as e`` and ``open(..., 'wb')`` are used because they work
#     on both Python 2 and 3, and SerializeToString() returns bytes.
template = Template("""
# Generated file; please do not edit

from simplejson import dumps, load
from structlog import get_logger
from google.protobuf.json_format import MessageToDict, ParseDict
from twisted.internet.defer import inlineCallbacks, returnValue

{% set package = file_name.replace('.proto', '') %}

{% for pypackage, module in includes %}
{% if pypackage %}
from {{ pypackage }} import {{ module }}
{% else %}
import {{ module }}
{% endif %}
{% endfor %}

log = get_logger()

{% for method in methods %}
{% set method_name = method['service'].rpartition('.')[2] + '_' + method['method'] %}
@inlineCallbacks
def {{ method_name }}(grpc_client, params, metadata, **kw):
    log.info('{{ method_name }}', params=params, metadata=metadata, **kw)
    data = dict(params)
    data.update(kw)
    try:
        req = ParseDict(data, {{ type_map[method['input_type']] }}())
    except Exception as e:
        log.error('cannot-convert-to-protobuf', e=e, data=data)
        raise
    res, _ = yield grpc_client.invoke(
        {{ type_map[method['service']] }}Stub,
        '{{ method['method'] }}', req, metadata)
    try:
        out_data = grpc_client.convertToDict(res)
    except AttributeError:
        filename = '/tmp/netconf_failed_to_convert_data.pbd'
        with open(filename, 'wb') as f:
            f.write(res.SerializeToString())
        log.error('cannot-convert-from-protobuf', outdata_saved=filename)
        raise
    log.info('{{ method_name }}', **out_data)
    returnValue(out_data)

{% endfor %}

""", trim_blocks=True, lstrip_blocks=True)
79
80
def traverse_methods(proto_file):
    """Yield one descriptor dict for every RPC method declared in a file.

    Each yielded dict has these keys:
      * ``package``, ``filename`` -- both set to ``proto_file.name``
      * ``service`` -- ``<proto package>.<service name>``
      * ``method`` -- the method name
      * ``input_type``, ``output_type`` -- the method's type names, with a
        single leading '.' removed if present

    NOTE(review): ``package`` holds the file name, not the proto package.
    The value is written into the generated .json file, so it is kept as-is.
    """
    def strip_leading_dot(type_name):
        # Drop exactly one leading '.', leaving any other characters intact.
        return type_name[1:] if type_name.startswith('.') else type_name

    file_name = proto_file.name
    for svc in proto_file.service:
        assert isinstance(svc, ServiceDescriptorProto)
        service_full_name = proto_file.package + '.' + svc.name
        for rpc in svc.method:
            yield {
                'package': file_name,
                'filename': file_name,
                'service': service_full_name,
                'method': rpc.name,
                'input_type': strip_leading_dot(rpc.input_type),
                'output_type': strip_leading_dot(rpc.output_type),
            }
105
106
def generate_gw_code(file_name, methods, type_map, includes):
    """Render the gateway module source for one .proto file.

    :param file_name: name of the .proto file the methods come from
    :param methods: method dicts as yielded by traverse_methods()
    :param type_map: mapping of full proto names to python names
    :param includes: (python_package, module) pairs to import
    :return: the rendered module source text
    """
    context = dict(
        file_name=file_name,
        methods=methods,
        type_map=type_map,
        includes=includes,
    )
    return template.render(**context)
110
111
class IncludeManager(object):
    """Track which .proto file defines each top-level symbol.

    We need to record which file defines which top-level message, enum and
    service, and under which proto package. Later, when the RPC methods are
    analyzed, this is used for two things:
      * deriving the list of python modules the generated code must import;
      * replacing ``<proto-package-name>.<artifact-name>`` references with
        ``<python-module-name>.<artifact-name>`` so python can resolve them.
    """

    def __init__(self):
        self.package_to_localname = {}  # proto package -> [top-level names]
        self.fullname_to_filename = {}  # 'pkg.Name' -> defining .proto file
        self.prefix_table = []  # top-level full names, reverse-sorted
        self.type_map = {}  # full name as used in .proto -> python name
        self.includes_needed = set()  # names of files needed to be included
        self.filename_to_module = {}  # filename -> (package, module)

    def extend_symbol_tables(self, proto_file):
        """Register the top-level messages, enums and services of a file.

        :param proto_file: object exposing ``package``, ``name``,
            ``message_type``, ``enum_type`` and ``service`` (as a
            FileDescriptorProto does); each definition must have ``name``.
        """
        package_name = proto_file.package
        file_name = proto_file.name
        self._add_filename(file_name)
        all_defs = list(proto_file.message_type)
        all_defs.extend(proto_file.enum_type)
        all_defs.extend(proto_file.service)
        for typedef in all_defs:
            name = typedef.name
            fullname = package_name + '.' + name
            self.fullname_to_filename[fullname] = file_name
            self.package_to_localname.setdefault(package_name, []).append(name)
        self._update_prefix_table()

    def _add_filename(self, filename):
        # Map e.g. 'a/b/foo.proto' -> ('a.b', 'foo_pb2'). Only the trailing
        # '.proto' suffix is rewritten, so a '.proto' substring elsewhere in
        # the path (e.g. a directory name) is left untouched.
        if filename not in self.filename_to_module:
            suffix = '.proto'
            if filename.endswith(suffix):
                base = filename[:-len(suffix)] + '_pb2'
            else:
                base = filename
            python_path = base.replace('/', '.')
            package_name, _, module_name = python_path.rpartition('.')
            self.filename_to_module[filename] = (package_name, module_name)

    def _update_prefix_table(self):
        # Keep a reverse-sorted list of top-level full names. A name always
        # sorts after its own prefixes, so the longest candidate prefix is
        # seen first when resolving nested symbols.
        self.prefix_table = sorted(self.fullname_to_filename, reverse=True)

    def _find_matching_prefix(self, fullname):
        # Return the longest registered top-level name that is fullname itself
        # or a dotted parent of it. The '.' boundary check stops 'pkg.Foo'
        # from matching an unrelated name such as 'pkg.Foobaz'.
        for prefix in self.prefix_table:
            if fullname == prefix or fullname.startswith(prefix + '.'):
                return prefix
        # This should never happen
        raise Exception('No match for type name "{}"'.format(fullname))

    def add_needed_symbol(self, fullname):
        """Record that generated code references ``fullname``.

        Adds the defining file to the needed includes and maps ``fullname``
        to its python name. For example, when 'pkg.Foo' is defined in
        foo.proto, 'pkg.Foo.Inner' maps to 'foo_pb2.Foo.Inner'.

        :raises Exception: if no registered top-level symbol covers fullname
        """
        if fullname in self.type_map:
            return
        top_level_symbol = self._find_matching_prefix(fullname)
        name = top_level_symbol.rpartition('.')[2]
        # '.Inner...' for nested symbols, '' for the top-level symbol itself
        nested_name = fullname[len(top_level_symbol):]
        file_name = self.fullname_to_filename[top_level_symbol]
        self.includes_needed.add(file_name)
        module_name = self.filename_to_module[file_name][1]
        python_name = module_name + '.' + name + nested_name
        self.type_map[fullname] = python_name

    def get_type_map(self):
        """Return the (live) full-proto-name -> python-name mapping."""
        return self.type_map

    def get_includes(self):
        """Return sorted (python_package, module) tuples for needed files."""
        return sorted(
            self.filename_to_module[fn] for fn in self.includes_needed)
180
181
def generate_code(request, response):
    """Populate ``response`` with generated files for each proto in ``request``.

    For every entry of ``request.proto_file``, processed in order, two files
    are appended to ``response``:
      * ``<name>.proto.json`` -- JSON summary of the RPC methods, the type
        rename map and the includes;
      * ``<name>_rpc_gw.py`` -- the rendered gateway module.

    :param request: plugin.CodeGeneratorRequest received from protoc
    :param response: plugin.CodeGeneratorResponse to append files to
    """
    assert isinstance(request, plugin.CodeGeneratorRequest)

    # NOTE(review): the protoc plugin API presumably lists dependencies
    # before the files that import them in request.proto_file, which this
    # loop relies on for symbol lookup. It also includes imported files, not
    # only request.file_to_generate, so outputs are produced for those too
    # -- confirm this is intended.
    include_manager = IncludeManager()
    for proto_file in request.proto_file:

        # register this file's symbols first so its own methods can
        # reference types defined in the same file
        include_manager.extend_symbol_tables(proto_file)

        methods = []

        for data in traverse_methods(proto_file):
            methods.append(data)
            include_manager.add_needed_symbol(data['input_type'])
            include_manager.add_needed_symbol(data['output_type'])
            include_manager.add_needed_symbol(data['service'])

        # NOTE(review): include_manager is shared across files, so the type
        # map and include list accumulate; each file's outputs also carry
        # entries needed by previously processed files.
        type_map = include_manager.get_type_map()
        includes = include_manager.get_includes()

        # as a nice side-effect, generate a json file capturing the essence
        # of the RPC method entries
        f = response.file.add()
        f.name = proto_file.name + '.json'
        f.content = dumps(dict(
            type_rename_map=type_map,
            includes=includes,
            methods=methods), indent=4)

        # generate the real Python code file
        f = response.file.add()
        name = proto_file.name
        assert name.endswith('.proto')
        # strip only the trailing '.proto'; a plain replace() would also
        # rewrite any '.proto' substring elsewhere in the path
        stem = name[:-len('.proto')] if name.endswith('.proto') else name
        f.name = stem + '_rpc_gw.py'
        f.content = generate_gw_code(proto_file.name,
                                     methods, type_map, includes)
216
217
if __name__ == '__main__':

    # The request and response are serialized protobuf (binary) data. On
    # Python 3 the raw byte streams are exposed as .buffer; on Python 2 the
    # std streams have no .buffer and are used directly.
    stdin = getattr(sys.stdin, 'buffer', sys.stdin)
    stdout = getattr(sys.stdout, 'buffer', sys.stdout)

    if len(sys.argv) >= 2:
        # read input from file, to allow troubleshooting; binary mode so the
        # serialized request is not altered by newline translation
        with open(sys.argv[1], 'rb') as f:
            data = f.read()
    else:
        # read input from stdin
        data = stdin.read()

    # parse request
    request = plugin.CodeGeneratorRequest()
    request.ParseFromString(data)

    # create response object
    response = plugin.CodeGeneratorResponse()

    # generate the output and the response
    generate_code(request, response)

    # serialize the response
    output = response.SerializeToString()

    # write response to stdout
    stdout.write(output)
    stdout.flush()
242 sys.stdout.write(output)