This commit consists of:
1) Differentiate between yang containers and groupings
2) Handle repeated fields
3) Add yang list to encapsulate reference to message types
Change-Id: Ideb5bf8c6ff847a580b5e27339f238d463c091f2
diff --git a/experiments/proto2Yang/proto2yang.py b/experiments/proto2Yang/proto2yang.py
index 9b039fd..88ce1a7 100755
--- a/experiments/proto2Yang/proto2yang.py
+++ b/experiments/proto2Yang/proto2yang.py
@@ -56,20 +56,68 @@
description "Initial revision.";
{% endif %}
}
+
+ {% for enum in module.enums %}
+ {% if enum.description %}
+ /* {{ enum.description }} */
+ {% endif %}
+ typedef {{ enum.name }} {
+ type enumeration {
+ {% for v in enum.value %}
+ {% if v.description %}
+ enum {{ v.name }} {
+ description "{{ v.description }}";
+ }
+ {% else %}
+ enum {{ v.name }} ;
+ {% endif %}
+ {% endfor %}
+ }
+ }
+ {% endfor %}
+
{% for message in module.messages recursive %}
+ {% if message.name in module.referenced_messages %}
grouping {{ message.name }} {
+ {% else %}
+ container {{ message.name }} {
+ {% endif %}
{% if message.description %}
/* {{ message.description }} */
{% endif %}
- {% if message.key %}
- key {{ message.key_name }} ;
- {% endif %}
{% for field in message.fields %}
{% if field.type_ref %}
{% if field.description %}
/* {{ field.description }} */
{% endif %}
- uses {{ field.type }} ;
+ {% for dict_item in module.referred_messages_with_keys %}
+ {% if dict_item.name == field.type %}
+ list {{ field.name }} {
+ key "{{ dict_item.key }}";
+ {% if not field.repeated %}
+ max-elements 1;
+ {% endif %}
+ uses {{ field.type }};
+ }
+ {% endif %}
+ {% endfor %}
+ {% elif field.repeated %}
+ list {{ field.name }} {
+ key "{{ field.name }}";
+ leaf {{ field.name }} {
+ {% if field.type == "decimal64" %}
+ type {{ field.type }} {
+ fraction-digits 5;
+ }
+ {% else %}
+ type {{ field.type }} ;
+ {% endif %}
+ {% if field.description %}
+ description
+ "{{ field.description }}" ;
+ {% endif %}
+ }
+ }
{% else %}
leaf {{ field.name }} {
{% if field.type == "decimal64" %}
@@ -85,10 +133,11 @@
{% endif %}
}
{% endif %}
+
{% endfor %}
{% for enum_type in message.enums %}
{% if enum_type.description %}
- /* {{ enum_type.description }} */ ;
+ /* {{ enum_type.description }} */
{% endif %}
typedef {{ enum_type.name }} {
type enumeration {
@@ -103,22 +152,18 @@
{% endfor %}
}
}
- {% endfor %}
- {% for oneof in message.oneof %}
- choice {{ oneof.name }} {
- }
+
{% endfor %}
{% if message.messages %}
{{ loop (message.messages)|indent(4, false) }}
{% endif %}
}
- {% endfor %}
+ {% endfor %}
{% for service in module.services %}
{% if service.description %}
/* {{ service.description }}" */
{% endif %}
-
{% for method in service.methods %}
{% if method.description %}
/* {{ method.description }} */
@@ -147,33 +192,42 @@
}
{% endif %}
}
+
{% endfor %}
+
{% endfor %}
}
""", trim_blocks=True, lstrip_blocks=True)
-def _traverse_messages(message_types):
+def traverse_messages(message_types, prefix, referenced_messages):
messages = []
for message_type in message_types:
assert message_type['_type'] == 'google.protobuf.DescriptorProto'
+
+ # full_name = prefix + '-' + message_type['name']
+ full_name = message_type['name']
+
# parse the fields
- fields = _traverse_fields(message_type.get('field', []))
+ fields = traverse_fields(message_type.get('field', []), full_name,
+ referenced_messages)
# parse the enums
- enums = _traverse_enums(message_type.get('enum_type', []))
+ enums = traverse_enums(message_type.get('enum_type', []), full_name)
# parse nested messages
nested = message_type.get('nested_type', [])
- nested_messages = _traverse_messages(nested)
+ nested_messages = traverse_messages(nested, full_name,
+ referenced_messages)
messages.append(
{
- 'name': message_type.get('name', ''),
+ 'name': full_name,
'fields': fields,
'enums': enums,
# 'extensions': extensions,
'messages': nested_messages,
- 'description': message_type.get('_description', ''),
+ 'description': remove_unsupported_characters(
+ message_type.get('_description', '')),
# 'extension_ranges': extension_ranges,
# 'oneof': oneof
}
@@ -181,19 +235,26 @@
return messages
-def _traverse_fields(fields_desc):
+def traverse_fields(fields_desc, prefix, referenced_messages):
fields = []
for field in fields_desc:
assert field['_type'] == 'google.protobuf.FieldDescriptorProto'
+ yang_base_type = is_base_type(field['type'])
+ _type = get_yang_type(field)
+ if not yang_base_type:
+ referenced_messages.append(_type)
+
fields.append(
{
+ # 'name': prefix + '-' + field.get('name', ''),
'name': field.get('name', ''),
'label': field.get('label', ''),
+ 'repeated': field['label'] == FieldDescriptor.LABEL_REPEATED,
'number': field.get('number', ''),
'options': field.get('options', ''),
'type_name': field.get('type_name', ''),
- 'type': get_yang_type(field),
- 'type_ref': not is_base_type(field['type']),
+ 'type': _type,
+ 'type_ref': not yang_base_type,
'description': remove_unsupported_characters(field.get(
'_description', ''))
}
@@ -201,36 +262,48 @@
return fields
-def _traverse_enums(enums_desc):
+def traverse_enums(enums_desc, prefix):
enums = []
for enum in enums_desc:
assert enum['_type'] == 'google.protobuf.EnumDescriptorProto'
+ # full_name = prefix + '-' + enum.get('name', '')
+ full_name = enum.get('name', '')
enums.append(
{
- 'name': enum.get('name', ''),
+ 'name': full_name,
'value': enum.get('value', ''),
- 'description': enum.get('_description', '')
+ 'description': remove_unsupported_characters(enum.get(
+ '_description', ''))
}
)
return enums
-def _traverse_services(service_desc):
+def traverse_services(service_desc, referenced_messages):
services = []
for service in service_desc:
methods = []
for method in service.get('method', []):
assert method['_type'] == 'google.protobuf.MethodDescriptorProto'
+
input_name = method.get('input_type')
input_ref = False
if not is_base_type(input_name):
+ input_name = remove_first_character_if_match(input_name, '.')
+ # input_name = input_name.replace(".", "-")
input_name = input_name.split('.')[-1]
+ referenced_messages.append(input_name)
input_ref = True
+
output_name = method.get('output_type')
output_ref = False
if not is_base_type(output_name):
+ output_name = remove_first_character_if_match(output_name, '.')
+ # output_name = output_name.replace(".", "-")
output_name = output_name.split('.')[-1]
+ referenced_messages.append(output_name)
output_ref = True
+
methods.append(
{
'method': method.get('name', ''),
@@ -238,7 +311,8 @@
'input_ref': input_ref,
'output': output_name,
'output_ref': output_ref,
- 'description': method.get('_description', ''),
+ 'description': remove_unsupported_characters(method.get(
+ '_description', '')),
'server_streaming': method.get('server_streaming',
False) == True
}
@@ -247,42 +321,85 @@
{
'service': service.get('name', ''),
'methods': methods,
- 'description': service.get('_description', ''),
+ 'description': remove_unsupported_characters(service.get(
+ '_description', '')),
}
)
return services
-def _rchop(thestring, ending):
+def rchop(thestring, ending):
if thestring.endswith(ending):
return thestring[:-len(ending)]
return thestring
-def _traverse_desc(descriptor):
- name = _rchop(descriptor.get('name', ''), '.proto')
+def traverse_desc(descriptor):
+ referenced_messages = []
+ name = rchop(descriptor.get('name', ''), '.proto')
package = descriptor.get('package', '')
description = descriptor.get('_description', '')
- messages = _traverse_messages(descriptor.get('message_type', []))
- enums = _traverse_enums(descriptor.get('enum_type', []))
- services = _traverse_services(descriptor.get('service', []))
+ messages = traverse_messages(descriptor.get('message_type', []),
+ package, referenced_messages)
+ enums = traverse_enums(descriptor.get('enum_type', []), package)
+ services = traverse_services(descriptor.get('service', []),
+ referenced_messages)
# extensions = _traverse_extensions(descriptors)
# options = _traverse_options(descriptors)
+ set_messages_keys(messages)
+ unique_referred_messages_with_keys = []
+ for message_name in list(set(referenced_messages)):
+ unique_referred_messages_with_keys.append(
+ {
+ 'name': message_name,
+ 'key': get_message_key(message_name, messages)
+ }
+ )
data = {
'name': name,
'package': package,
- 'description' : description,
+ 'description': description,
'messages': messages,
'enums': enums,
'services': services,
+ 'referenced_messages': list(set(referenced_messages)),
+ # TODO: simplify for easier jinja2 template use
+ 'referred_messages_with_keys': unique_referred_messages_with_keys,
# 'extensions': extensions,
# 'options': options
}
-
return data
+def set_messages_keys(messages):
+ for message in messages:
+ message['key'] = _get_message_key(message)
+ if message['messages']:
+ set_messages_keys(message['messages'])
+
+
+def _get_message_key(message):
+ # assume key is first yang base type field
+ for field in message['fields']:
+ if not field['type_ref']:
+ return field['name']
+ # no key yet - search nested messaged
+ if message['messages']:
+ return get_message_key(message['messages'])
+ else:
+ return None
+
+
+def get_message_key(message_name, messages):
+ for message in messages:
+ if message_name == message['name']:
+ return message['key']
+ if message['messages']:
+ return get_message_key(message_name, message['messages'])
+ return None
+
+
def generate_code(request, response):
assert isinstance(request, plugin.CodeGeneratorRequest)
@@ -295,10 +412,10 @@
fold_comments=True)
# print native_data
- yang_data = _traverse_desc(native_data)
+ yang_data = traverse_desc(native_data)
f = response.file.add()
- #TODO: We should have a separate file for each output. There is an
+ # TODO: We should have a separate file for each output. There is an
# issue reusing the same filename with an incremental suffix. Using
# a different file name works but not the actual proto file name
f.name = proto_file.name.replace('.proto', '.yang')
@@ -307,17 +424,21 @@
# idx += 1
f.content = template_yang.render(module=yang_data)
+
def get_yang_type(field):
type = field['type']
if type in YANG_TYPE_MAP.keys():
_type, _ = YANG_TYPE_MAP[type]
if _type in ['enumeration', 'message', 'group']:
return field['type_name'].split('.')[-1]
+ # return remove_first_character_if_match(field['type_name'],
+ # '.').replace('.', '-')
else:
return _type
else:
return type
+
def is_base_type(type):
# check numeric value of the type first
if type in YANG_TYPE_MAP.keys():
@@ -325,13 +446,22 @@
return _type not in ['message', 'group']
else:
# proto name of the type
- result = [ _format for ( _ , _format) in YANG_TYPE_MAP.values() if
- _format == type and _format not in ['message', 'group']]
+ result = [_format for (_, _format) in YANG_TYPE_MAP.values() if
+ _format == type and _format not in ['message', 'group']]
return len(result) > 0
+
def remove_unsupported_characters(text):
- unsupported_characters = ["{", "}", "[", "]", "\"", "/", "\\"]
- return ''.join([i if i not in unsupported_characters else ' ' for i in text])
+ unsupported_characters = ["{", "}", "[", "]", "\"", "\\", "*", "/"]
+ return ''.join([i if i not in unsupported_characters else ' ' for i in
+ text])
+
+
+def remove_first_character_if_match(str, char):
+ if str.startswith(char):
+ return str[1:]
+ return str
+
YANG_TYPE_MAP = {
FieldDescriptor.TYPE_BOOL: ('boolean', 'boolean'),
@@ -373,4 +503,4 @@
# Write to stdout
sys.stdout.write(output)
- # print is_base_type(9)
\ No newline at end of file
+ # print is_base_type(9)