blob: 4586a718972469179ede60d4f715419539f8e54c [file] [log] [blame]
Matteo Scandoloaf3c9942018-06-27 14:03:12 -07001
2# Copyright 2017-present Open Networking Foundation
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16
17import unittest
18import ipaddress
19from mock import patch, call, Mock, PropertyMock
20
21import os, sys
22
# Directory containing this test file, with symlinks resolved.
test_path = os.path.abspath(os.path.dirname(os.path.realpath(__file__)))
# Default layout: services tree four levels up, XOS core three levels up.
# NOTE: this must be named ``services_dir`` -- get_models_fn() and setUp() read
# that name, and it was previously only bound inside the branch below, causing a
# NameError whenever a sibling ``new_base`` directory existed.
services_dir = os.path.join(test_path, "../../../..")
# Backward-compatible alias for the old (misspelled) name; keeps its original value.
service_dir = services_dir
xos_dir = os.path.join(test_path, "../../..")
if not os.path.exists(os.path.join(test_path, "new_base")):
    # Alternate layout: locate XOS and the services relative to an
    # orchestration/xos checkout instead.
    xos_dir = os.path.join(test_path, "../../../../../../orchestration/xos/xos")
    services_dir = os.path.join(xos_dir, "../../xos_services")
29
# Moving from static to dynamic model loading changed where a service's xproto
# files live, so both the old and the new location are probed.
def get_models_fn(service_name, xproto_name):
    """Return the path of ``xproto_name`` relative to the module-level ``services_dir``.

    Locations are tried in this order:
      1. ``<service_name>/xos/<xproto_name>``
      2. ``<service_name>/xos/synchronizer/models/<xproto_name>``

    Raises:
        Exception: if the file is found in neither location.
    """
    candidates = (
        os.path.join(service_name, "xos", xproto_name),
        os.path.join(service_name, "xos", "synchronizer", "models", xproto_name),
    )
    for relative in candidates:
        if os.path.exists(os.path.join(services_dir, relative)):
            return relative
    raise Exception("Unable to find service=%s xproto=%s" % (service_name, xproto_name))
40 raise Exception("Unable to find service=%s xproto=%s" % (service_name, xproto_name))
41
class TestComputeNodePolicy(unittest.TestCase):
    """Tests for ``ComputeNodePolicy`` from ``model_policy_compute_nodes``.

    ``setUp`` builds a mock model accessor from ``fabric.xproto``, imports the
    policy against it, and resets all object stores, so every test starts from
    an empty world.
    """

    def setUp(self):
        """Build the mock model accessor and import the policy under test."""
        global ComputeNodePolicy, MockObjectList

        # Snapshot a *copy* of sys.path. Saving the list object itself would
        # alias the list mutated by the appends below, so tearDown would
        # "restore" the already-extended path and it would grow on every test.
        self.sys_path_save = list(sys.path)
        sys.path.append(xos_dir)
        sys.path.append(os.path.join(xos_dir, 'synchronizers', 'new_base'))

        config = os.path.join(test_path, "../test_config.yaml")
        from xosconfig import Config
        Config.clear()
        Config.init(config, 'synchronizer-config-schema.yaml')

        # Generate mock model classes from the fabric service's xproto.
        from synchronizers.new_base.mock_modelaccessor_build import build_mock_modelaccessor
        build_mock_modelaccessor(xos_dir, services_dir, [get_models_fn("fabric", "fabric.xproto")])

        # Not referenced below; presumably imported for its import-time side
        # effects. Verify before removing.
        import synchronizers.new_base.modelaccessor

        from model_policy_compute_nodes import ComputeNodePolicy, model_accessor

        from mock_modelaccessor import MockObjectList

        # Publish every model class as a module-level global. This is where
        # names such as PortInterface, used by the tests below, come from.
        for (k, v) in model_accessor.all_model_classes.items():
            globals()[k] = v

        # Code under test can have side effects on the object stores, so reset
        # all of them to give each test a clean world.
        model_accessor.reset_all_object_stores()

        self.policy = ComputeNodePolicy
        self.model = Mock()

    def tearDown(self):
        """Restore the sys.path snapshot taken in setUp."""
        sys.path = self.sys_path_save

    def test_getLastAddress(self):
        """The last address returned for 10.6.1.0/24 is 10.6.1.254 with its /24 prefix."""
        # u"" literals keep the ipaddress backport on Python 2 happy (it needs
        # unicode) and remain valid on Python 3.
        dataPlaneIp = u"10.6.1.2/24"
        interface = ipaddress.ip_interface(dataPlaneIp)
        subnet = ipaddress.ip_network(interface.network)
        last_ip = self.policy.getLastAddress(subnet)
        self.assertEqual(str(last_ip), "10.6.1.254/24")

    def test_generateVlan(self):
        """With every id in 16..4092 used except 1000, generateVlan returns 1000."""
        # list() so .remove works on Python 3, where range is not a list.
        used_vlans = list(range(16, 4093))
        used_vlans.remove(1000)

        vlan = self.policy.generateVlan(used_vlans)

        self.assertEqual(vlan, 1000)

    def test_generateVlanFail(self):
        """generateVlan raises "No VLANs left" when every id in 16..4092 is used."""
        used_vlans = list(range(16, 4093))

        with self.assertRaises(Exception) as ctx:
            self.policy.generateVlan(used_vlans)

        # str() instead of the Python-2-only ``.message`` attribute.
        self.assertEqual(str(ctx.exception), "No VLANs left")

    def test_getVlanByCidr_same_subnet(self):
        """A PortInterface in the same subnet supplies its untagged VLAN."""
        mock_pi_ip = u"10.6.1.2/24"

        mock_pi = Mock()
        mock_pi.vlanUntagged = 1234
        mock_pi.ips = str(self.policy.getPortCidrByIp(mock_pi_ip))

        test_ip = u"10.6.1.1/24"
        test_subnet = self.policy.getPortCidrByIp(test_ip)

        with patch.object(PortInterface.objects, "get_items") as get_pi:
            get_pi.return_value = [mock_pi]
            vlan = self.policy.getVlanByCidr(test_subnet)

        self.assertEqual(vlan, mock_pi.vlanUntagged)

    def test_getVlanByCidr_different_subnet(self):
        """A PortInterface in a different subnet does not supply the VLAN."""
        mock_pi_ip = u"10.6.1.2/24"
        mock_pi = Mock()
        mock_pi.vlanUntagged = 1234
        mock_pi.ips = str(self.policy.getPortCidrByIp(mock_pi_ip))

        test_ip = u"192.168.1.1/24"
        test_subnet = self.policy.getPortCidrByIp(test_ip)

        with patch.object(PortInterface.objects, "get_items") as get_pi:
            get_pi.return_value = [mock_pi]
            vlan = self.policy.getVlanByCidr(test_subnet)

        self.assertNotEqual(vlan, mock_pi.vlanUntagged)

    def test_handle_create(self):
        """handle_create delegates to handle_update with the same model."""
        policy = self.policy()
        with patch.object(policy, "handle_update") as handle_update:
            policy.handle_create(self.model)
            handle_update.assert_called_with(self.model)

    def test_handle_update_do_nothing(self):
        """No PortInterface is saved when a matching one (same port id, name and ips) already exists."""
        mock_pi_ip = u"10.6.1.2/24"
        mock_pi = Mock()
        mock_pi.port_id = 1
        mock_pi.name = "test_interface"
        mock_pi.ips = str(self.policy.getPortCidrByIp(mock_pi_ip))

        policy = self.policy()

        self.model.port.id = 1
        self.model.node.dataPlaneIntf = "test_interface"

        with patch.object(PortInterface.objects, "get_items") as get_pi, \
                patch.object(self.policy, "getPortCidrByIp") as get_subnet, \
                patch.object(PortInterface, 'save') as mock_save:
            get_pi.return_value = [mock_pi]
            get_subnet.return_value = mock_pi.ips

            policy.handle_update(self.model)

            mock_save.assert_not_called()

    def test_handle_update(self):
        """With no existing PortInterface, handle_update saves exactly one new one.

        The new PortInterface uses the node's data-plane interface name, the
        port id, the VLAN from getVlanByCidr, and the subnet's last address.
        """
        policy = self.policy()

        self.model.port.id = 1
        self.model.node.dataPlaneIntf = "test_interface"
        self.model.node.dataPlaneIp = u"10.6.1.2/24"

        with patch.object(PortInterface.objects, "get_items") as get_pi, \
                patch.object(self.policy, "getVlanByCidr") as get_vlan, \
                patch.object(PortInterface, "save", autospec=True) as mock_save:
            get_pi.return_value = []
            get_vlan.return_value = "1234"

            policy.handle_update(self.model)

            self.assertEqual(mock_save.call_count, 1)
            # With autospec=True the first positional argument is ``self``,
            # i.e. the PortInterface instance being saved.
            pi = mock_save.call_args[0][0]

            self.assertEqual(pi.name, self.model.node.dataPlaneIntf)
            self.assertEqual(pi.port_id, self.model.port.id)
            self.assertEqual(pi.vlanUntagged, "1234")
            self.assertEqual(pi.ips, "10.6.1.254/24")
195
if __name__ == "__main__":
    # Executed directly (not imported): run this module's test cases.
    unittest.main()
198