CORD-1570: Several bug fixes, expanded unit tests for security refactoring

Change-Id: Ied8dca916d3c22a252f6de38a65ef1b20c9d639d
diff --git a/lib/xos-genx/xosgenx/jinja2_extensions/fol2.py b/lib/xos-genx/xosgenx/jinja2_extensions/fol2.py
index f7af1dc..0c8513a 100644
--- a/lib/xos-genx/xosgenx/jinja2_extensions/fol2.py
+++ b/lib/xos-genx/xosgenx/jinja2_extensions/fol2.py
@@ -4,15 +4,18 @@
 import string
 import jinja2
 from plyxproto.parser import *
-import pdb
 
 BINOPS = ['|', '&', '->']
 QUANTS = ['exists', 'forall']
 
-
 class PolicyException(Exception):
     pass
 
+class ConstructNotHandled(Exception):  # raised by the FOL walkers for an operator key they do not handle
+    pass
+
+class TrivialPolicy(Exception):  # raised when a policy reduces to the constant 'True' or 'False'
+    pass
 
 class AutoVariable:
     def __init__(self, base):
@@ -27,11 +30,9 @@
         self.idx += 1
         return var
 
-
 def gen_random_string():
     return ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(5))
 
-
 class FOL2Python:
     def __init__(self, context_map=None):
         # This will produce i0, i1, i2 etc.
@@ -114,25 +115,223 @@
 
     """ Convert a single leaf node from a string
         to an AST"""
-
     def str_to_ast(self, s):
         ast_module = ast.parse(s)
         return ast_module.body[0]
 
-    def hoist_constants(self, fol, var=None):
+    def reduce_operands(self, operands):
+        """Return (constant, other) if an operand is 'True'/'False', else None."""
+        first, second = operands[0], operands[1]
+        for constant, other in ((first, second), (second, first)):
+            if constant in ['True', 'False']:
+                return (constant, other)
+        return None
+
+    """ Simplify binops with constants """
+    def simplify_binop(self, binop):
+        """Fold a {op: [lhs, rhs]} node that has a 'True'/'False' operand.
+
+        Returns `binop` itself when neither operand is such a constant.
+        """
+        (op, operands), = binop.items()
+        if op == '->':
+            antecedent, consequent = operands
+            if antecedent == 'True':
+                return consequent
+            if consequent == 'True' or antecedent == 'False':
+                return 'True'
+            if consequent == 'False':
+                return {'not': antecedent}
+
+        folded = self.reduce_operands(operands)
+        if not folded:
+            return binop
+        constant, other = folded
+        if op == '|':
+            if constant == 'True':
+                return 'True'
+            if constant == 'False':
+                return other
+            raise Exception("Internal error - variable read as constant")
+        if op == '&':
+            if constant == 'True':
+                return other
+            if constant == 'False':
+                return 'False'
+
+    def is_constant(self, var, fol):
+        """Return True if no term inside `fol` starts with the name `var`.
+
+        `fol` is a bare term or a single-key dict {operator: operands}.
+        Raises ConstructNotHandled for an operator this walker does not know.
+        """
+        try:
+            (k, v), = fol.items()
+        except AttributeError:
+            k, v = 'term', fol  # no .items(): a bare term
+        if k in ['python', 'policy']:
+            # Treat as a constant and hoist, since it cannot be quantified
+            return True
+        elif k == 'term':
+            return not v.startswith(var)
+        elif k == 'not':
+            # Recurse into the operand; recursing on `fol` never terminates.
+            return self.is_constant(var, v)
+        elif k in ['in', '='] or k in BINOPS:
+            lhs, rhs = v
+            return self.is_constant(var, lhs) and self.is_constant(var, rhs)
+        elif k in QUANTS:
+            return self.is_constant(var, v[1])  # v is [bound variable, body]
+        else:
+            raise ConstructNotHandled(k)
+
+    def find_constants(self, var, fol, constants):  # appends the parts of `fol` that do not reference `var` to `constants`; returns that list
         try:
             (k, v), = fol.items()
         except AttributeError:
             k = 'term'
             v = fol
 
-        if k in ['python', 'policy'] :
-            # Tainted, don't optimize
-            if var:
-                return {'hoist': []}
+        if k in ['python', 'policy']:
+           # Treat as a constant and hoist, since it cannot be quantified
+           if fol not in constants:  # only this branch skips entries already collected
+               constants.append(fol)
+           return constants
+        elif k == 'term':
+           if not v.startswith(var):  # a term that is not prefixed by the quantified variable
+               constants.append(v)
+           return constants
+        elif k == 'not':
+            return self.find_constants(var, v, constants)
+        elif k in ['in', '=']:
+            lhs, rhs = v
+            if isinstance(lhs, str) and isinstance(rhs, str):  # both operands are plain terms
+                if not lhs.startswith(var) and not rhs.startswith(var):
+                    constants.append(fol)  # hoist the whole comparison, not its operands
+                return constants
             else:
-                return fol
+                constants = self.find_constants(var, lhs, constants)
+                return self.find_constants(var, rhs, constants)
+        elif k in BINOPS:
+            lhs, rhs = v
+            constants = self.find_constants(var, lhs, constants)
+            constants = self.find_constants(var, rhs, constants)
+            return constants
+        elif k in QUANTS:
+            is_constant = self.is_constant(var, v[1])  # v is [bound variable, body]
+            if is_constant: constants.append(fol)  # a nested quantifier is collected whole
+            return constants
+        else:
+            raise ConstructNotHandled(k)
 
+    """ Hoist constants out of quantifiers. Depth-first. """
+    def hoist_outer(self, fol):
+        """Rewrite `fol` depth-first, hoisting constants out of quantifiers.
+
+        'python', 'policy' and bare-term leaves come back unchanged; other
+        nodes are rebuilt from their rewritten operands. Raises
+        ConstructNotHandled for an operator this walker does not know.
+        """
+        try:
+            (op, args), = fol.items()
+        except AttributeError:
+            op, args = 'term', fol  # no .items(): a bare term
+
+        if op in ['python', 'policy', 'term']:
+            # Tainted or atomic: optimization and distribution not possible
+            return fol
+
+        if op == 'not':
+            return {'not': self.hoist_outer(args)}
+        if op in QUANTS:
+            new_body = self.hoist_outer(args[1])
+            return self.hoist_quant(op, [args[0], new_body])
+        if op not in ['in', '='] and op not in BINOPS:
+            raise ConstructNotHandled(op)
+
+        # Two-operand node: rewrite both sides before rebuilding it.
+        left, right = args
+        new_left = self.hoist_outer(left)
+        new_right = self.hoist_outer(right)
+        rebuilt = {op: [new_left, new_right]}
+        if op in ['in', '=']:
+            return rebuilt
+        # Logical connectives additionally get their constants folded.
+        return self.simplify_binop(rebuilt)
+
+    def replace_const(self, fol, c, value):
+        """Return `fol` with every occurrence of `c` replaced by `value`.
+
+        Sub-formulas that become 'True'/'False' are folded on the way up.
+        Raises ConstructNotHandled for an operator this walker does not know.
+        """
+        if fol == c:
+            return value
+        try:
+            (k, v), = fol.items()
+        except AttributeError:
+            k, v = 'term', fol  # no .items(): a bare term
+
+        if k in ['term', 'python', 'policy']:
+            # A leaf other than `c`; 'python'/'policy' nodes are left opaque.
+            return fol
+        elif k == 'not':
+            new_expr = self.replace_const(v, c, value)
+            if new_expr == 'True':
+                return 'False'
+            elif new_expr == 'False':
+                return 'True'
+            else:
+                return {'not': new_expr}
+        elif k in ['in', '=']:
+            lhs, rhs = v
+            rlhs = self.replace_const(lhs, c, value)
+            rrhs = self.replace_const(rhs, c, value)
+            if rlhs == rrhs:
+                return 'True'
+            else:
+                return {k: [rlhs, rrhs]}
+        elif k in BINOPS:
+            lhs, rhs = v
+            rlhs = self.replace_const(lhs, c, value)
+            rrhs = self.replace_const(rhs, c, value)
+            return self.simplify_binop({k: [rlhs, rrhs]})
+        elif k in QUANTS:
+            var, expr = v
+            new_expr = self.replace_const(expr, c, value)
+            if new_expr in ['True', 'False']:
+                return new_expr
+            else:
+                return {k: [var, new_expr]}
+        else:
+            raise ConstructNotHandled(k)
+
+    def shannon_expand(self, c, fol):
+        """Expand `fol` around `c`: (c & fol[c=True]) | (not c & fol[c=False]).
+
+        Both conjunctions and the final disjunction go through
+        simplify_binop, so constant branches collapse.
+        """
+        if_true = self.replace_const(fol, c, 'True')
+        if_false = self.replace_const(fol, c, 'False')
+
+        with_c = self.simplify_binop({'&': [c, if_true]})
+        without_c = self.simplify_binop({'&': [{'not': c}, if_false]})
+        return self.simplify_binop({'|': [with_c, without_c]})
+
+    def hoist_quant(self, k, expr):
+        """Shannon-expand the quantifier {k: expr} around each constant
+        that find_constants reports for its body, and return the result.
+        `expr` is the pair [bound variable, body].
+        """
+        bound_var, body = expr
+        hoistable = self.find_constants(bound_var, body, constants=[])
+        result = {k: expr}
+        for const in hoistable:
+            result = self.shannon_expand(const, result)
+        return result
+
+        """
         if var:
             if k == 'term':
                 if not v.startswith(var):
@@ -197,6 +396,7 @@
                     return fol
             else:
                 return fol
+        """
 
     def gen_validation_function(self, fol, policy_name, message, tag):
         if not tag:
@@ -250,7 +450,7 @@
             policy_fn = fn_template % policy_name
             call_str = """
 %(verdict_var)s = %(policy_fn)s(obj.%(object_name)s, ctx)
-            """ % {'verdict_var': self.verdict_variable, 'policy_fn': policy_fn, 'object_name': object_name}
+            """ % {'verdict_var': verdict_var, 'policy_fn': policy_fn, 'object_name': object_name}
 
             call_ast = self.str_to_ast(call_str)
             return call_ast
@@ -266,12 +466,12 @@
 
             assignment_str = """
 %(verdict_var)s = (%(escape_expr)s)
-            """ % {'verdict_var': self.verdict_variable, 'escape_expr': v}
+            """ % {'verdict_var': verdict_var, 'escape_expr': v}
 
             assignment_ast = self.str_to_ast(assignment_str)
             return assignment_ast
         elif k == 'not':
-            top_vvar = self.verdict_variable
+            top_vvar = verdict_var
             self.verdict_next()
             sub_vvar = self.verdict_variable
             block = self.gen_test(fn_template, v, sub_vvar)
@@ -431,8 +631,13 @@
         raise Exception('Could not find policy:', policy)
 
     f2p = FOL2Python()
-    fol = f2p.hoist_constants(fol)
-    a = f2p.gen_test_function(fol, policy, tag='enforcer')
+    fol_reduced = f2p.hoist_outer(fol)
+
+    if fol_reduced in ['True','False'] and fol != fol_reduced:
+        raise TrivialPolicy("Policy %(name)s trivially reduces to %(reduced)s. If this is what you want, replace its contents with %(reduced)s"%{'name':policy, 'reduced':fol_reduced})
+
+    a = f2p.gen_test_function(fol_reduced, policy, tag='enforcer')
+
     return astunparse.unparse(a)
 
 def xproto_fol_to_python_validator(policy, fol, model, message, tag=None):
@@ -440,19 +645,30 @@
         raise Exception('Could not find policy:', policy)
 
     f2p = FOL2Python()
-    fol = f2p.hoist_constants(fol)
-    a = f2p.gen_validation_function(fol, policy, message, tag='validator')
+    fol_reduced = f2p.hoist_outer(fol)
+
+    if fol_reduced in ['True','False'] and fol != fol_reduced:
+        raise TrivialPolicy("Policy %(name)s trivially reduces to %(reduced)s. If this is what you want, replace its contents with %(reduced)s"%{'name':policy, 'reduced':fol_reduced})
+
+    a = f2p.gen_validation_function(fol_reduced, policy, message, tag='validator')
     
     return astunparse.unparse(a)
 
 def main():
     while True:
-        inp = raw_input()
+        inp = ''
+        while True:
+            inp_line = raw_input()
+            if inp_line=='EOF': break
+            else: inp+=inp_line
+            
         fol_lexer = lex.lex(module=FOLLexer())
         fol_parser = yacc.yacc(module=FOLParser(), start='goal')
 
         val = fol_parser.parse(inp, lexer=fol_lexer)
         a = xproto_fol_to_python_test('pol', val, 'output', 'Test')
+        #f2p = FOL2Python()
+        #a = f2p.hoist_outer(val)
         print a