This commit makes the following changes to the proto-to-YANG generator:
1) Emit referenced messages as YANG groupings and all others as containers
2) Render repeated fields as YANG lists
3) Wrap references to message types in a YANG list that uses the grouping

Change-Id: Ideb5bf8c6ff847a580b5e27339f238d463c091f2
diff --git a/experiments/proto2Yang/proto2yang.py b/experiments/proto2Yang/proto2yang.py
index 9b039fd..88ce1a7 100755
--- a/experiments/proto2Yang/proto2yang.py
+++ b/experiments/proto2Yang/proto2yang.py
@@ -56,20 +56,68 @@
         description "Initial revision.";
     {% endif %}
     }
+
+    {% for enum in module.enums %}
+    {% if enum.description %}
+    /* {{ enum.description }} */
+    {% endif %}
+    typedef {{ enum.name }} {
+        type enumeration {
+        {% for v in enum.value %}
+        {% if v.description %}
+            enum {{ v.name }} {
+                description "{{ v.description }}";
+            }
+        {% else %}
+            enum {{ v.name }} ;
+        {% endif %}
+        {% endfor %}
+        }
+    }
+    {% endfor %}
+
     {% for message in module.messages recursive %}
+    {% if message.name in module.referenced_messages %}
     grouping {{ message.name }} {
+    {% else %}
+    container {{ message.name }} {
+    {% endif %}
         {% if message.description %}
         /* {{ message.description }} */
         {% endif %}
-        {% if message.key %}
-        key {{ message.key_name }} ;
-        {% endif %}
         {% for field in message.fields %}
         {% if field.type_ref %}
         {% if field.description %}
         /* {{ field.description }} */
         {% endif %}
-        uses {{ field.type }} ;
+        {% for dict_item in module.referred_messages_with_keys %}
+                {% if dict_item.name == field.type %}
+        list {{ field.name }} {
+            key "{{ dict_item.key }}";
+            {% if not field.repeated %}
+            max-elements 1;
+            {% endif %}
+            uses {{ field.type }};
+        }
+                {% endif %}
+        {% endfor %}
+        {% elif field.repeated %}
+        list {{ field.name }} {
+            key "{{ field.name }}";
+            leaf {{ field.name }} {
+                {% if field.type == "decimal64" %}
+                type {{ field.type }} {
+                   fraction-digits 5;
+                }
+                {% else %}
+                type {{ field.type }} ;
+                {% endif %}
+                {% if field.description %}
+                description
+                    "{{ field.description }}" ;
+                {% endif %}
+            }
+        }
         {% else %}
         leaf {{ field.name }} {
             {% if field.type == "decimal64" %}
@@ -85,10 +133,11 @@
             {% endif %}
         }
         {% endif %}
+
         {% endfor %}
         {% for enum_type in message.enums %}
         {% if enum_type.description %}
-        /* {{ enum_type.description }} */ ;
+        /* {{ enum_type.description }} */
         {% endif %}
         typedef {{ enum_type.name }} {
             type enumeration {
@@ -103,22 +152,18 @@
             {% endfor %}
             }
         }
-        {% endfor %}
-        {% for oneof in message.oneof %}
-        choice {{ oneof.name }} {
-        }
+
         {% endfor %}
     {% if message.messages %}
     {{ loop (message.messages)|indent(4, false) }}
     {% endif %}
     }
-    {% endfor %}
 
+    {% endfor %}
     {% for service in module.services %}
     {% if service.description %}
     /*  {{ service.description }}" */
     {% endif %}
-
     {% for method in service.methods %}
     {% if method.description %}
     /* {{ method.description }} */
@@ -147,33 +192,42 @@
         }
         {% endif %}
     }
+
     {% endfor %}
+
     {% endfor %}
 }
 """, trim_blocks=True, lstrip_blocks=True)
 
 
-def _traverse_messages(message_types):
+def traverse_messages(message_types, prefix, referenced_messages):
     messages = []
     for message_type in message_types:
         assert message_type['_type'] == 'google.protobuf.DescriptorProto'
+
+        # full_name = prefix + '-' + message_type['name']
+        full_name = message_type['name']
+
         # parse the fields
-        fields = _traverse_fields(message_type.get('field', []))
+        fields = traverse_fields(message_type.get('field', []), full_name,
+                                 referenced_messages)
 
         # parse the enums
-        enums = _traverse_enums(message_type.get('enum_type', []))
+        enums = traverse_enums(message_type.get('enum_type', []), full_name)
 
         # parse nested messages
         nested = message_type.get('nested_type', [])
-        nested_messages = _traverse_messages(nested)
+        nested_messages = traverse_messages(nested, full_name,
+                                            referenced_messages)
         messages.append(
             {
-                'name': message_type.get('name', ''),
+                'name': full_name,
                 'fields': fields,
                 'enums': enums,
                 # 'extensions': extensions,
                 'messages': nested_messages,
-                'description': message_type.get('_description', ''),
+                'description': remove_unsupported_characters(
+                    message_type.get('_description', '')),
                 # 'extension_ranges': extension_ranges,
                 # 'oneof': oneof
             }
@@ -181,19 +235,26 @@
     return messages
 
 
-def _traverse_fields(fields_desc):
+def traverse_fields(fields_desc, prefix, referenced_messages):
     fields = []
     for field in fields_desc:
         assert field['_type'] == 'google.protobuf.FieldDescriptorProto'
+        yang_base_type = is_base_type(field['type'])
+        _type = get_yang_type(field)
+        if not yang_base_type:
+            referenced_messages.append(_type)
+
         fields.append(
             {
+                # 'name': prefix + '-' + field.get('name', ''),
                 'name': field.get('name', ''),
                 'label': field.get('label', ''),
+                'repeated': field['label'] == FieldDescriptor.LABEL_REPEATED,
                 'number': field.get('number', ''),
                 'options': field.get('options', ''),
                 'type_name': field.get('type_name', ''),
-                'type': get_yang_type(field),
-                'type_ref': not is_base_type(field['type']),
+                'type': _type,
+                'type_ref': not yang_base_type,
                 'description': remove_unsupported_characters(field.get(
                     '_description', ''))
             }
@@ -201,36 +262,48 @@
     return fields
 
 
-def _traverse_enums(enums_desc):
+def traverse_enums(enums_desc, prefix):
     enums = []
     for enum in enums_desc:
         assert enum['_type'] == 'google.protobuf.EnumDescriptorProto'
+        # full_name = prefix + '-' + enum.get('name', '')
+        full_name = enum.get('name', '')
         enums.append(
             {
-                'name': enum.get('name', ''),
+                'name': full_name,
                 'value': enum.get('value', ''),
-                'description': enum.get('_description', '')
+                'description': remove_unsupported_characters(enum.get(
+                    '_description', ''))
             }
         )
     return enums
 
 
-def _traverse_services(service_desc):
+def traverse_services(service_desc, referenced_messages):
     services = []
     for service in service_desc:
         methods = []
         for method in service.get('method', []):
             assert method['_type'] == 'google.protobuf.MethodDescriptorProto'
+
             input_name = method.get('input_type')
             input_ref = False
             if not is_base_type(input_name):
+                input_name = remove_first_character_if_match(input_name, '.')
+                # input_name = input_name.replace(".", "-")
                 input_name = input_name.split('.')[-1]
+                referenced_messages.append(input_name)
                 input_ref = True
+
             output_name = method.get('output_type')
             output_ref = False
             if not is_base_type(output_name):
+                output_name = remove_first_character_if_match(output_name, '.')
+                # output_name = output_name.replace(".", "-")
                 output_name = output_name.split('.')[-1]
+                referenced_messages.append(output_name)
                 output_ref = True
+
             methods.append(
                 {
                     'method': method.get('name', ''),
@@ -238,7 +311,8 @@
                     'input_ref': input_ref,
                     'output': output_name,
                     'output_ref': output_ref,
-                    'description': method.get('_description', ''),
+                    'description': remove_unsupported_characters(method.get(
+                        '_description', '')),
                     'server_streaming': method.get('server_streaming',
                                                    False) == True
                 }
@@ -247,42 +321,85 @@
             {
                 'service': service.get('name', ''),
                 'methods': methods,
-                'description': service.get('_description', ''),
+                'description': remove_unsupported_characters(service.get(
+                    '_description', '')),
             }
         )
     return services
 
 
-def _rchop(thestring, ending):
+def rchop(thestring, ending):
     if thestring.endswith(ending):
         return thestring[:-len(ending)]
     return thestring
 
 
-def _traverse_desc(descriptor):
-    name = _rchop(descriptor.get('name', ''), '.proto')
+def traverse_desc(descriptor):
+    referenced_messages = []
+    name = rchop(descriptor.get('name', ''), '.proto')
     package = descriptor.get('package', '')
     description = descriptor.get('_description', '')
-    messages = _traverse_messages(descriptor.get('message_type', []))
-    enums = _traverse_enums(descriptor.get('enum_type', []))
-    services = _traverse_services(descriptor.get('service', []))
+    messages = traverse_messages(descriptor.get('message_type', []),
+                                 package, referenced_messages)
+    enums = traverse_enums(descriptor.get('enum_type', []), package)
+    services = traverse_services(descriptor.get('service', []),
+                                 referenced_messages)
     # extensions = _traverse_extensions(descriptors)
     # options = _traverse_options(descriptors)
+    set_messages_keys(messages)
+    unique_referred_messages_with_keys = []
+    for message_name in list(set(referenced_messages)):
+        unique_referred_messages_with_keys.append(
+            {
+                'name': message_name,
+                'key': get_message_key(message_name, messages)
+            }
+        )
 
     data = {
         'name': name,
         'package': package,
-        'description' : description,
+        'description': description,
         'messages': messages,
         'enums': enums,
         'services': services,
+        'referenced_messages': list(set(referenced_messages)),
+        # TODO:  simplify for easier jinja2 template use
+        'referred_messages_with_keys': unique_referred_messages_with_keys,
         # 'extensions': extensions,
         # 'options': options
     }
-
     return data
 
 
+def set_messages_keys(messages):
+    """Store each message's key under 'key', recursing into nested ones."""
+    for message in messages:
+        message['key'] = _get_message_key(message)
+        if message['messages']:
+            set_messages_keys(message['messages'])
+
+
+def _get_message_key(message):
+    """Key = first YANG base-type field name, searching nested messages too."""
+    for field in message['fields']:
+        if not field['type_ref']:
+            return field['name']
+    nested_keys = (_get_message_key(m) for m in message['messages'])
+    return next((k for k in nested_keys if k is not None), None)
+
+
+def get_message_key(message_name, messages):
+    """Return the 'key' of the message called `message_name`, else None."""
+    for message in messages:
+        if message_name == message['name']:
+            return message['key']
+        nested_key = get_message_key(message_name, message['messages'])
+        if nested_key is not None:  # on a miss, keep scanning siblings
+            return nested_key
+    return None
+
+
 def generate_code(request, response):
     assert isinstance(request, plugin.CodeGeneratorRequest)
 
@@ -295,10 +412,10 @@
                                                    fold_comments=True)
 
         # print native_data
-        yang_data = _traverse_desc(native_data)
+        yang_data = traverse_desc(native_data)
 
         f = response.file.add()
-        #TODO: We should have a separate file for each output. There is an
+        # TODO: We should have a separate file for each output. There is an
         # issue reusing the same filename with an incremental suffix.  Using
         # a different file name works but not the actual proto file name
         f.name = proto_file.name.replace('.proto', '.yang')
@@ -307,17 +424,21 @@
         # idx += 1
         f.content = template_yang.render(module=yang_data)
 
+
 def get_yang_type(field):
     type = field['type']
     if type in YANG_TYPE_MAP.keys():
         _type, _ = YANG_TYPE_MAP[type]
         if _type in ['enumeration', 'message', 'group']:
             return field['type_name'].split('.')[-1]
+            # return remove_first_character_if_match(field['type_name'],
+            #                                        '.').replace('.', '-')
         else:
             return _type
     else:
         return type
 
+
 def is_base_type(type):
     # check numeric value of the type first
     if type in YANG_TYPE_MAP.keys():
@@ -325,13 +446,22 @@
         return _type not in ['message', 'group']
     else:
         # proto name of the type
-        result = [ _format for ( _ , _format) in YANG_TYPE_MAP.values() if
-                   _format == type and _format not in ['message', 'group']]
+        result = [_format for (_, _format) in YANG_TYPE_MAP.values() if
+                  _format == type and _format not in ['message', 'group']]
         return len(result) > 0
 
+
 def remove_unsupported_characters(text):
-    unsupported_characters = ["{", "}", "[", "]", "\"", "/", "\\"]
-    return ''.join([i if i not in unsupported_characters else ' ' for i in text])
+    unsupported_characters = ["{", "}", "[", "]", "\"", "\\", "*", "/"]
+    return ''.join([i if i not in unsupported_characters else ' ' for i in
+                    text])
+
+
+def remove_first_character_if_match(str, char):
+    """Return `str` without its first character if it starts with `char`."""
+    # NOTE: `str` shadows the builtin; kept so the signature is unchanged.
+    return str[1:] if str.startswith(char) else str
+
 
 YANG_TYPE_MAP = {
     FieldDescriptor.TYPE_BOOL: ('boolean', 'boolean'),
@@ -373,4 +503,4 @@
 
     # Write to stdout
     sys.stdout.write(output)
-    # print is_base_type(9)
\ No newline at end of file
+    # print is_base_type(9)