CORD-1568: Support sharing policies by letting one policy invoke another

Change-Id: I87fdb80a65d86b61364abd3ec35961668924c8f5
diff --git a/lib/xos-genx/tests/general_security_test.py b/lib/xos-genx/tests/general_security_test.py
index d5f3e79..1201baa 100644
--- a/lib/xos-genx/tests/general_security_test.py
+++ b/lib/xos-genx/tests/general_security_test.py
@@ -15,12 +15,16 @@
 """
 class XProtoSecurityTest(unittest.TestCase):
     def setUp(self):
-        self.target = XProtoTestHelpers.write_tmp_target("{{ xproto_fol_to_python_test('output', proto.policies.test_policy, None, '0') }}")
+        self.target = XProtoTestHelpers.write_tmp_target("""
+{% for name, policy in proto.policies.items() %}
+{{ xproto_fol_to_python_test(name, policy, None, '0') }}
+{% endfor %}
+""")
 
     def test_constant(self):
         xproto = \
 """
-    policy test_policy < True >
+    policy output < True >
 """
         args = FakeArgs()
         args.inputs = xproto
@@ -42,7 +46,7 @@
     def test_equal(self):
         xproto = \
 """
-    policy test_policy < ctx.user = obj.user >
+    policy output < ctx.user = obj.user >
 """
 
         args = FakeArgs()
@@ -66,10 +70,46 @@
 
         verdict = policy_output_enforcer(obj, ctx)
 
+    def test_call_policy(self):
+        xproto = \
+"""
+    policy sub_policy < ctx.user = obj.user >
+    policy output < *sub_policy(child) >
+"""
+
+        args = FakeArgs()
+        args.inputs = xproto
+        args.target = self.target
+
+        output = XOSGenerator.generate(args)
+
+        exec(output,globals()) # This loads the generated function, which should look like this:
+
+        """
+        def policy_sub_policy_enforcer(obj, ctx):
+            i1 = (ctx.user == obj.user)
+    	    return i1
+
+	def policy_output_enforcer(obj, ctx):
+	    i1 = policy_sub_policy_enforcer(obj.child, ctx)
+	    return i1
+        """
+
+        obj = FakeArgs()
+        obj.child = FakeArgs()
+	obj.child.user = 1
+
+        ctx = FakeArgs()
+	ctx.user = 1
+
+        verdict = policy_output_enforcer(obj, ctx)
+        self.assertTrue(verdict)
+
+
     def test_bin(self):
         xproto = \
 """
-    policy test_policy < ctx.is_admin = True | obj.empty = True>
+    policy output < ctx.is_admin = True | obj.empty = True>
 """
 
         args = FakeArgs()
@@ -101,7 +141,7 @@
     def test_exists(self):
         xproto = \
 """
-    policy test_policy < exists Privilege: Privilege.object_id = obj.id >
+    policy output < exists Privilege: Privilege.object_id = obj.id >
 """
 	args = FakeArgs()
         args.inputs = xproto
@@ -121,7 +161,7 @@
     def test_python(self):
         xproto = \
 """
-    policy test_policy < {{ "jack" in ["the", "box"] }} = False >
+    policy output < {{ "jack" in ["the", "box"] }} = False >
 """
 	args = FakeArgs()
         args.inputs = xproto
@@ -142,7 +182,7 @@
         # This one we only parse
         xproto = \
 """
-    policy test_policy < forall Credential: Credential.obj_id = obj_id >
+    policy output < forall Credential: Credential.obj_id = obj_id >
 """
 
         args = FakeArgs()
diff --git a/lib/xos-genx/tests/general_validation_test.py b/lib/xos-genx/tests/general_validation_test.py
index 4134fe1..c1e6820 100644
--- a/lib/xos-genx/tests/general_validation_test.py
+++ b/lib/xos-genx/tests/general_validation_test.py
@@ -16,12 +16,16 @@
 """
 class XProtoGeneralValidationTest(unittest.TestCase):
     def setUp(self):
-        self.target = XProtoTestHelpers.write_tmp_target("{{ xproto_fol_to_python_validator('output', proto.policies.test_policy, None, 'Necessary Failure') }}")
+        self.target = XProtoTestHelpers.write_tmp_target("""
+{% for name, policy in proto.policies.items() %}
+{{ xproto_fol_to_python_validator(name, policy, None, 'Necessary Failure') }}
+{% endfor %}
+""")
 
     def test_constant(self):
         xproto = \
 """
-    policy test_policy < False >
+    policy output < False >
 """
         args = FakeArgs()
         args.inputs = xproto
@@ -44,7 +48,7 @@
     def test_equal(self):
         xproto = \
 """
-    policy test_policy < not (ctx.user = obj.user) >
+    policy output < not (ctx.user = obj.user) >
 """
 
         args = FakeArgs()
@@ -74,7 +78,7 @@
     def test_equal(self):
         xproto = \
 """
-    policy test_policy < not (ctx.user = obj.user) >
+    policy output < not (ctx.user = obj.user) >
 """
 
         args = FakeArgs()
@@ -104,7 +108,7 @@
     def test_bin(self):
         xproto = \
 """
-    policy test_policy < (ctx.is_admin = True | obj.empty = True) & False>
+    policy output < (ctx.is_admin = True | obj.empty = True) & False>
 """
 
         args = FakeArgs()
@@ -136,7 +140,7 @@
     def test_exists(self):
         xproto = \
 """
-    policy test_policy < exists Privilege: Privilege.object_id = obj.id >
+    policy output < exists Privilege: Privilege.object_id = obj.id >
 """
 	args = FakeArgs()
         args.inputs = xproto
@@ -157,7 +161,7 @@
     def test_python(self):
         xproto = \
 """
-    policy test_policy < {{ "jack" in ["the", "box"] }} = True >
+    policy output < {{ "jack" in ["the", "box"] }} = True >
 """
 	args = FakeArgs()
         args.inputs = xproto
@@ -176,11 +180,48 @@
         with self.assertRaises(Exception):
             self.assertTrue(policy_output_validator({}, {}) is True)
 
+    def test_call_policy(self):
+        xproto = \
+"""
+    policy sub_policy < ctx.user = obj.user >
+    policy output < *sub_policy(child) >
+"""
+
+        args = FakeArgs()
+        args.inputs = xproto
+        args.target = self.target
+
+        output = XOSGenerator.generate(args)
+
+        exec(output,globals()) # This loads the generated function, which should look like this:
+
+        """
+        def policy_sub_policy_validator(obj, ctx):
+            i1 = (ctx.user == obj.user)
+            if (not i1):
+                raise ValidationError('Necessary Failure')
+
+        def policy_output_validator(obj, ctx):
+            i1 = policy_sub_policy_validator(obj.child, ctx)
+            if (not i1):
+                raise ValidationError('Necessary Failure')
+        """
+
+        obj = FakeArgs()
+        obj.child = FakeArgs()
+	obj.child.user = 1
+
+        ctx = FakeArgs()
+	ctx.user = 1
+
+        with self.assertRaises(Exception):
+            verdict = policy_output_enforcer(obj, ctx)
+
     def test_forall(self):
         # This one we only parse
         xproto = \
 """
-    policy test_policy < forall Credential: Credential.obj_id = obj_id >
+    policy output < forall Credential: Credential.obj_id = obj_id >
 """
 
         args = FakeArgs()
diff --git a/lib/xos-genx/tests/policy_test.py b/lib/xos-genx/tests/policy_test.py
index 25c307e..5bc3511 100644
--- a/lib/xos-genx/tests/policy_test.py
+++ b/lib/xos-genx/tests/policy_test.py
@@ -215,6 +215,41 @@
 
         self.assertTrue(eval(expr))
 
+    def test_policy_function(self):
+        xproto = \
+"""
+    policy slice_policy < exists Privilege: Privilege.object_id = obj.id >
+    policy network_slice_policy < *slice_policy(slice) >
+"""
+
+        target = XProtoTestHelpers.write_tmp_target("{{ proto.policies.network_slice_policy }} ")
+        args = FakeArgs()
+        args.inputs = xproto
+        args.target = target
+
+        output = XOSGenerator.generate(args)
+        
+        (op, operands), = eval(output).items()
+
+        self.assertIn('slice_policy', operands)
+        self.assertIn('slice', operands)
+
+    def test_policy_missing_function(self):
+        xproto = \
+"""
+    policy slice_policy < exists Privilege: Privilege.object_id = obj.id >
+    policy network_slice_policy < *slice_policyX(slice) >
+"""
+
+        target = XProtoTestHelpers.write_tmp_target("{{ proto.policies.network_slice_policy }} ")
+        args = FakeArgs()
+        args.inputs = xproto
+        args.target = target
+
+        with self.assertRaises(Exception):
+            output = XOSGenerator.generate(args)
+        
+
     def test_forall(self):
         # This one we only parse
         xproto = \
diff --git a/lib/xos-genx/xosgenx/jinja2_extensions/fol2.py b/lib/xos-genx/xosgenx/jinja2_extensions/fol2.py
index ca39b17..f7af1dc 100644
--- a/lib/xos-genx/xosgenx/jinja2_extensions/fol2.py
+++ b/lib/xos-genx/xosgenx/jinja2_extensions/fol2.py
@@ -126,7 +126,7 @@
             k = 'term'
             v = fol
 
-        if k == 'python':
+        if k in ['python', 'policy'] :
             # Tainted, don't optimize
             if var:
                 return {'hoist': []}
@@ -202,8 +202,8 @@
         if not tag:
             tag = gen_random_string()
 
-        policy_function_name = 'policy_%(policy_name)s_%(random_string)s' % {
-            'policy_name': policy_name, 'random_string': tag}
+        policy_function_name_template = 'policy_%s_' + '%(random_string)s' % {'random_string': tag}
+        policy_function_name = policy_function_name_template % policy_name
         self.verdict_next()
         function_str = """
 def %(fn_name)s(obj, ctx):
@@ -211,7 +211,7 @@
         """ % {'fn_name': policy_function_name, 'vvar': self.verdict_variable, 'message': message}
 
         function_ast = self.str_to_ast(function_str)
-        policy_code = self.gen_test(fol, self.verdict_variable)
+        policy_code = self.gen_test(policy_function_name_template, fol, self.verdict_variable)
 
 
         function_ast.body = [policy_code] + function_ast.body
@@ -222,8 +222,9 @@
         if not tag:
             tag = gen_random_string()
 
-        policy_function_name = 'policy_%(policy_name)s_%(random_string)s' % {
-            'policy_name': policy_name, 'random_string': tag}
+        policy_function_name_template = 'policy_%s_' + '%(random_string)s' % {'random_string': tag}
+        policy_function_name = policy_function_name_template % policy_name
+
         self.verdict_next()
         function_str = """
 def %(fn_name)s(obj, ctx):
@@ -231,18 +232,28 @@
         """ % {'fn_name': policy_function_name, 'vvar': self.verdict_variable}
 
         function_ast = self.str_to_ast(function_str)
-        policy_code = self.gen_test(fol, self.verdict_variable)
+        policy_code = self.gen_test(policy_function_name_template, fol, self.verdict_variable)
 
         function_ast.body = [policy_code] + function_ast.body
 
         return function_ast
 
-    def gen_test(self, fol, verdict_var, bindings=None):
+    def gen_test(self, fn_template, fol, verdict_var, bindings=None):
         if isinstance(fol, str):
             return self.str_to_ast('%(verdict_var)s = %(constant)s' % {'verdict_var': verdict_var, 'constant': fol})
 
         (k, v), = fol.items()
 
+        if k == 'policy':
+            policy_name, object_name = v
+
+            policy_fn = fn_template % policy_name
+            call_str = """
+%(verdict_var)s = %(policy_fn)s(obj.%(object_name)s, ctx)
+            """ % {'verdict_var': self.verdict_variable, 'policy_fn': policy_fn, 'object_name': object_name}
+
+            call_ast = self.str_to_ast(call_str)
+            return call_ast
         if k == 'python':
             try:
                 expr_ast = self.str_to_ast(v)
@@ -263,7 +274,7 @@
             top_vvar = self.verdict_variable
             self.verdict_next()
             sub_vvar = self.verdict_variable
-            block = self.gen_test(v, sub_vvar)
+            block = self.gen_test(fn_template, v, sub_vvar)
             assignment_str = """
 %(verdict_var)s = not (%(subvar)s)
                     """ % {'verdict_var': top_vvar, 'subvar': sub_vvar}
@@ -333,8 +344,8 @@
             self.verdict_next()
             rvar = self.verdict_variable
 
-            lblock = self.gen_test(lhs, lvar)
-            rblock = self.gen_test(rhs, rvar)
+            lblock = self.gen_test(fn_template, lhs, lvar)
+            rblock = self.gen_test(fn_template, rhs, rvar)
 
             invert = ''
             if k == '&':
diff --git a/lib/xos-genx/xosgenx/xos2jinja.py b/lib/xos-genx/xosgenx/xos2jinja.py
index edd80ef..efcb995 100644
--- a/lib/xos-genx/xosgenx/xos2jinja.py
+++ b/lib/xos-genx/xosgenx/xos2jinja.py
@@ -7,7 +7,23 @@
 import jinja2
 import os
 import copy
+import pdb
 
+class MissingPolicyException(Exception):
+    pass
+
+def find_missing_policy_calls(name, policies, policy):
+    if isinstance(policy, dict):
+        (k, lst), = policy.items()
+        if k=='policy':
+            policy_name = lst[0]
+            if policy_name not in policies: raise MissingPolicyException("Policy %s invoked missing policy %s"%(name, policy_name))
+        else:
+            for p in lst:
+                find_missing_policy_calls(name, policies, p)
+    elif isinstance(policy, list):
+        for p in lst:
+            find_missing_policy_calls(name, policies, p)
 
 def dotname_to_fqn(dotname):
     b_names = [part.pval for part in dotname]
@@ -121,16 +137,7 @@
             pname = obj.name.value.pval
 
         self.policies[pname] = obj.body
-
-        return True
-
-    def visit_PolicyDefinition(self, obj):
-        if self.package:
-            pname = '.'.join([self.package, obj.name.value.pval])
-        else:
-            pname = obj.name.value.pval
-
-        self.policies[pname] = obj.body
+        find_missing_policy_calls(pname, self.policies, obj.body)
 
         return True