blob: 635310494c8cc539546ac3194f92dd4e21bfe2ae [file] [log] [blame]
# Copyright 2017-present Open Networking Foundation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15import unittest
16from mock import patch, Mock, MagicMock
17
18from io import StringIO
19import functools
20import os
21import sys
22
# Absolute directory containing this test file (symlinks resolved); used to
# locate test fixtures such as test_config.yaml.
test_path = os.path.abspath(
    os.path.dirname(
        os.path.realpath(__file__)
    )
)
24
25
def mock_listdir(dir_map, dir):
    """Stand-in for os.listdir(): return the entries recorded for *dir*.

    Directories that have no entry in *dir_map* are reported as empty.
    """
    if dir in dir_map:
        return dir_map[dir]
    return []
29
30
def mock_exists(file_map, fn):
    """Stand-in for os.path.exists(): *fn* exists iff it is a key of *file_map*."""
    is_known = fn in file_map
    return is_known
34
35
def mock_open(orig_open, file_map, fn, *args, **kwargs):
    """Stand-in for the builtin open().

    A path listed in *file_map* is served from memory as a StringIO over its
    mapped contents (any mode arguments are ignored). Every other path is
    delegated to *orig_open* with the caller's original arguments.
    """
    if fn not in file_map:
        return orig_open(fn, *args, **kwargs)
    return StringIO(file_map[fn])
42
43
class ItemList(object):
    """Mock one of the list-valued fields of a LoadModelsRequest protobuf.

    Provides an ``add()`` method (presumably mirroring the protobuf
    repeated-field API used by the code under test). Each call creates a new
    Mock, records it in ``items`` and hands it back so the caller can set
    attributes on it.
    """

    def __init__(self):
        # Mocks created by add(), in creation order.
        self.items = []

    def add(self):
        """Create a new Mock element, record it, and return it."""
        self.items.append(Mock())
        return self.items[-1]
54
55
class MockLoadModelsRequest(object):
    """Mock a LoadModelsRequest protobuf message.

    Keyword arguments become plain attributes. The list-valued fields
    (xprotos, decls, attics, convenience_methods, migrations) are then set to
    fresh, empty ItemList instances, overriding any keyword of the same name.
    """

    def __init__(self, *args, **kwargs):
        # Positional arguments are accepted but ignored.
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)
        # Assigned after the keywords so these always win.
        for field_name in ("xprotos", "decls", "attics", "convenience_methods", "migrations"):
            setattr(self, field_name, ItemList())
67
68
class TestLoadModels(unittest.TestCase):
    """Tests for ModelLoadClient.upload_models() from xossynchronizer.loadmodels."""

    def setUp(self):
        # Copy, not alias: keeping a reference to the live sys.path list would
        # let in-place mutations (sys.path.append/insert) during a test change
        # the saved value too, making the restore in tearDown a no-op.
        self.sys_path_save = list(sys.path)
        self.cwd_save = os.getcwd()

        config = os.path.join(test_path, "test_config.yaml")
        from xosconfig import Config

        Config.clear()
        Config.init(config, "synchronizer-config-schema.yaml")

        # Imported only after Config is initialised. NOTE(review): presumably
        # loadmodels consults Config when imported; confirm.
        from xossynchronizer import loadmodels
        from xossynchronizer.loadmodels import ModelLoadClient

        self.loadmodels = loadmodels

        # The whole API is a MagicMock except the request message class. That
        # class is replaced by MockLoadModelsRequest so the test can inspect
        # the fields upload_models() fills in.
        self.api = MagicMock()
        self.api.dynamicload_pb2.LoadModelsRequest = MockLoadModelsRequest
        self.loader = ModelLoadClient(self.api)

    def tearDown(self):
        # Restore in place so every holder of the sys.path list sees it.
        sys.path[:] = self.sys_path_save
        os.chdir(self.cwd_save)

    def _assert_items(self, item_list, expected):
        """Assert that *item_list* (an ItemList) holds exactly *expected*.

        *expected* is a sequence of (filename, contents) pairs, compared in
        order against the recorded items.
        """
        self.assertEqual(len(item_list.items), len(expected))
        for item, (filename, contents) in zip(item_list.items, expected):
            self.assertEqual(item.filename, filename)
            self.assertEqual(item.contents, contents)

    def test_upload_models(self):
        """upload_models() places each discovered file in the matching request field."""
        # Directory listings served by the mocked os.listdir().
        dir_map = {
            "models_dir": ["models.xproto", "models.py"],
            "models_dir/convenience": ["convenience1.py"],
            "models_dir/../migrations": ["migration1.py", "migration2.py"],
        }

        # In-memory file contents keyed by path. The directory entries are
        # presumably here so the mocked os.path.exists() reports them present.
        file_map = {
            "models_dir/models.xproto": u"some xproto",
            "models_dir/models.py": u"print `python models file`",
            "models_dir/convenience": u"directory",
            "models_dir/convenience/convenience1.py": u"print `python convenience file`",
            "models_dir/../migrations": u"directory",
            "models_dir/../migrations/migration1.py": u"print `first migration`",
            "models_dir/../migrations/migration2.py": u"print `second migration`",
        }

        # The builtins module is "__builtin__" on Python 2 and "builtins" on
        # Python 3. Patching a non-existent module fails with ImportError.
        builtins_name = "__builtin__" if sys.version_info[0] < 3 else "builtins"

        # Capture the real open() before patching so paths absent from
        # file_map still reach the filesystem.
        orig_open = open
        listdir_patch = patch(
            "os.listdir", side_effect=functools.partial(mock_listdir, dir_map))
        exists_patch = patch(
            "os.path.exists", side_effect=functools.partial(mock_exists, file_map))
        open_patch = patch(
            builtins_name + ".open",
            side_effect=functools.partial(mock_open, orig_open, file_map))
        with listdir_patch, exists_patch, open_patch:
            self.loader.upload_models("myservice", "models_dir", "1.2")

        call_args = self.api.dynamicload.LoadModels.call_args
        self.assertIsNotNone(call_args, "LoadModels was never called")
        request = call_args[0][0]
        self.assertEqual(request.name, "myservice")
        self.assertEqual(request.version, "1.2")

        self._assert_items(request.xprotos, [
            ("models.xproto", u"some xproto"),
        ])
        self._assert_items(request.decls, [
            ("models.py", u"print `python models file`"),
        ])
        self._assert_items(request.convenience_methods, [
            ("convenience1.py", u"print `python convenience file`"),
        ])
        self._assert_items(request.migrations, [
            ("migration1.py", u"print `first migration`"),
            ("migration2.py", u"print `second migration`"),
        ])
132
133
# Allow running this test module directly as a script.
if "__main__" == __name__:
    unittest.main()