More convenient & safe parent setting and visitor.
diff --git a/plyproto/model.py b/plyproto/model.py
index 5369fde..802cf47 100644
--- a/plyproto/model.py
+++ b/plyproto/model.py
@@ -45,9 +45,22 @@
def v(self, obj, visitor):
if obj == None:
return
- if not hasattr(obj, "accept"):
- return
- obj.accept(visitor)
+ elif hasattr(obj, "accept"):
+ obj.accept(visitor)
+ elif isinstance(obj, list):
+ for s in obj:
+ self.v(s, visitor)
+ pass
+ pass
+
+ @staticmethod
+ def p(obj, parent):
+ if isinstance(obj, list):
+ for s in obj:
+ Base.p(s, parent)
+
+ if hasattr(obj, "parent"):
+ obj.parent = parent
# Lexical unit - contains lexspan and linespan for later analysis.
class LU(Base):
@@ -132,7 +145,7 @@
super(PackageStatement, self).__init__(linespan=linespan, lexspan=lexspan, p=p)
self._fields += ['name']
self.name = name
- self.name.parent = self
+ Base.p(self.name, self)
def accept(self, visitor):
visitor.visit_PackageStatement(self)
@@ -142,7 +155,7 @@
super(ImportStatement, self).__init__(linespan=linespan, lexspan=lexspan, p=p)
self._fields += ['name']
self.name = name
- self.name.parent = self
+ Base.p(self.name, self)
def accept(self, visitor):
visitor.visit_ImportStatement(self)
@@ -152,9 +165,9 @@
super(OptionStatement, self).__init__(linespan=linespan, lexspan=lexspan, p=p)
self._fields += ['name', 'value']
self.name = name
- self.name.parent = self
+ Base.p(self.name, self)
self.value = value
- self.value.parent = self
+ Base.p(self.value, self)
def accept(self, visitor):
visitor.visit_OptionStatement(self)
@@ -164,9 +177,9 @@
super(FieldDirective, self).__init__(linespan=linespan, lexspan=lexspan, p=p)
self._fields += ['name', 'value']
self.name = name
- self.name.parent = self
+ Base.p(self.name, self)
self.value = value
- self.value.parent = self
+ Base.p(self.value, self)
def accept(self, visitor):
if visitor.visit_FieldDirective(self):
@@ -178,7 +191,7 @@
super(FieldType, self).__init__(linespan=linespan, lexspan=lexspan, p=p)
self._fields += ['name']
self.name = name
- self.name.parent = self
+ Base.p(self.name, self)
def accept(self, visitor):
if visitor.visit_FieldType(self):
@@ -189,16 +202,15 @@
super(FieldDefinition, self).__init__(linespan=linespan, lexspan=lexspan, p=p)
self._fields += ['field_modifier', 'ftype', 'name', 'fieldId', 'fieldDirective']
self.name = name
- self.name.parent = self
+ Base.p(self.name, self)
self.field_modifier = field_modifier
- self.field_modifier.parent = self
+ Base.p(self.field_modifier, self)
self.ftype = ftype
- self.ftype.parent = self
+ Base.p(self.ftype, self)
self.fieldId = fieldId
- self.fieldId.parent = self
+ Base.p(self.fieldId, self)
self.fieldDirective = fieldDirective
- for s in self.fieldDirective:
- s.parent = self
+ Base.p(self.fieldDirective, self)
def accept(self, visitor):
if visitor.visit_FieldDefinition(self):
@@ -206,17 +218,16 @@
self.v(self.field_modifier, visitor)
self.v(self.ftype, visitor)
self.v(self.fieldId, visitor)
- for s in self.fieldDirective:
- self.v(s, visitor)
+ self.v(self.fieldDirective, visitor)
class EnumFieldDefinition(SourceElement):
def __init__(self, name, fieldId, linespan=None, lexspan=None, p=None):
super(EnumFieldDefinition, self).__init__(linespan=linespan, lexspan=lexspan, p=p)
self._fields += ['name', 'fieldId']
self.name = name
- self.name.parent = self
+ Base.p(self.name, self)
self.fieldId = fieldId
- self.fieldId.parent = self
+ Base.p(self.fieldId, self)
def accept(self, visitor):
if visitor.visit_EnumFieldDefinition(self):
@@ -228,59 +239,53 @@
super(EnumDefinition, self).__init__(linespan=linespan, lexspan=lexspan, p=p)
self._fields += ['name', 'body']
self.name = name
- self.name.parent = self
+ Base.p(self.name, self)
self.body = body
- for s in self.body:
- s.parent = self
+ Base.p(self.body, self)
def accept(self, visitor):
if visitor.visit_EnumDefinition(self):
self.v(self.name, visitor)
- for s in self.body:
- s.accept(visitor)
+ self.v(self.body, visitor)
class MessageDefinition(SourceElement):
def __init__(self, name, body, linespan=None, lexspan=None, p=None):
super(MessageDefinition, self).__init__(linespan=linespan, lexspan=lexspan, p=p)
self._fields += ['name', 'body']
self.name = name
- self.name.parent = self
+ Base.p(self.name, self)
self.body = body
- for s in self.body:
- s.parent = self
+ Base.p(self.body, self)
def accept(self, visitor):
if visitor.visit_MessageDefinition(self):
self.v(self.name, visitor)
- for s in self.body:
- s.accept(visitor)
+ self.v(self.body, visitor)
class MessageExtension(SourceElement):
def __init__(self, name, body, linespan=None, lexspan=None, p=None):
super(MessageExtension, self).__init__(linespan=linespan, lexspan=lexspan, p=p)
self._fields += ['name', 'body']
self.name = name
- self.name.parent = self
+ Base.p(self.name, self)
self.body = body
- for s in self.body:
- s.parent = self
+ Base.p(self.body, self)
def accept(self, visitor):
if visitor.visit_MessageExtension(self):
self.v(self.name, visitor)
- for s in self.body:
- s.accept(visitor)
+ self.v(self.body, visitor)
class MethodDefinition(SourceElement):
def __init__(self, name, name2, name3, linespan=None, lexspan=None, p=None):
super(MethodDefinition, self).__init__(linespan=linespan, lexspan=lexspan, p=p)
self._fields += ['name', 'name2', 'name3']
self.name = name
- self.name.parent = self
+ Base.p(self.name, self)
self.name2 = name2
- self.name.parent = self
+ Base.p(self.name, self)
self.name3 = name3
- self.name.parent = self
+ Base.p(self.name, self)
def accept(self, visitor):
if visitor.visit_MethodDefinition(self):
@@ -293,14 +298,14 @@
super(ServiceDefinition, self).__init__(linespan=linespan, lexspan=lexspan, p=p)
self._fields += ['name', 'body']
self.name = name
- self.name.parent = self
+ Base.p(self.name, self)
self.body = body
+ Base.p(self.body, self)
def accept(self, visitor):
if visitor.visit_ServiceDefinition(self):
self.v(self.name, visitor)
- for s in self.body:
- s.accept(visitor)
+ self.v(self.body, visitor)
class ExtensionsMax(SourceElement):
pass
@@ -310,9 +315,9 @@
super(ExtensionsDirective, self).__init__(linespan=linespan, lexspan=lexspan, p=p)
self._fields += ['fromVal', 'toVal']
self.fromVal = fromVal
- self.fromVal.parent = self
+ Base.p(self.fromVal, self)
self.toVal = toVal
- self.toVal.parent = self
+ Base.p(self.toVal, self)
def accept(self, visitor):
if visitor.visit_ExtensionsDirective(self):
@@ -380,13 +385,11 @@
super(ProtoFile, self).__init__(linespan=linespan, lexspan=lexspan, p=p)
self._fields += ['pkg', 'body']
self.pkg = pkg
- self.pkg.parent = self
+ Base.p(self.pkg, self)
self.body = body
- for s in self.body:
- s.parent = self
+ Base.p(self.body, self)
def accept(self, visitor):
if visitor.visit_Proto(self):
self.v(self.pkg, visitor)
- for s in self.body:
- s.accept(visitor)
+ self.v(self.body, visitor)