blob: 61c2f1f7c1a3ed37ffcb1a743d11f206f2379494 [file] [log] [blame]
# Copyright 2017-present Open Networking Foundation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock
import random
def mock_enumerator(items):
    """Build a stand-in enumerator object around ``items``.

    The returned object is a callable (calling it returns ``None``) that
    carries an ``all`` attribute; ``all()`` hands back the very same
    ``items`` object that was passed in, without copying it.
    """
    def enumerator():
        return None

    def fetch_all():
        return items

    # Functions accept arbitrary attributes, which is what lets a bare
    # callable double as an object exposing ``.all()``.
    enumerator.all = fetch_all
    return enumerator
# A list of all mock object stores that have been created.
# get_MockObjectStore() appends every store it instantiates to this list,
# and ModelAccessor.reset_all_object_stores() iterates over it.
AllMockObjectStores = []
class MockObjectList:
    """In-memory list of objects exposing a small query-style interface.

    Items are kept in ``item_list``; every accessor goes through
    ``get_items()``.
    """

    # Class-level default. __init__ swaps a None value for a fresh list on
    # the instance, so instances do not share one list through the class.
    item_list = None

    def __init__(self, initial=None):
        """Create the list, optionally backed by ``initial``.

        A truthy ``initial`` is used directly (not copied). Otherwise an
        empty list is created unless ``item_list`` is already non-None.
        """
        self.id_counter = 0
        if initial:
            self.item_list = initial
        elif self.item_list is None:
            self.item_list = []

    def get_items(self):
        """Return the underlying list itself (not a copy)."""
        return self.item_list

    def count(self):
        """Return the number of stored items."""
        return len(self.get_items())

    def first(self):
        """Return the first item; raises IndexError when there are none."""
        return self.get_items()[0]

    def all(self):
        """Return every stored item (the underlying list)."""
        return self.get_items()

    def filter(self, **kwargs):
        """Return the items whose attributes equal every given keyword.

        Criteria are applied one after another, each narrowing the result
        of the previous one. With no criteria the underlying list itself is
        returned.
        """
        matches = self.get_items()
        for attr, wanted in kwargs.items():
            kept = []
            for candidate in matches:
                if getattr(candidate, attr) == wanted:
                    kept.append(candidate)
            matches = kept
        return matches

    def get(self, **kwargs):
        """Return the first item matching ``kwargs``.

        Raises:
            IndexError: if no item matches.
        """
        found = self.filter(**kwargs)
        if found:
            return found[0]
        raise IndexError("No objects matching %s" % str(kwargs))
class MockObjectStore(MockObjectList):
    """A MockObjectList that can also persist objects through ``save``."""

    def save(self, o):
        """Store ``o``, assigning it an id first when it does not have one.

        An object is treated as not having an id when it has no ``id``
        attribute, when ``id`` is falsy (``None`` or ``0``), or when ``id``
        is 98052, the placeholder value MockObject.__init__ assigns.

        If an item with the same id is already stored, that entry is
        replaced by ``o``; otherwise ``o`` is appended.
        """
        items = self.get_items()

        if (not hasattr(o, "id")) or (not o.id) or (o.id == 98052):
            # Move the counter past every id already in the store so the
            # newly assigned id cannot collide with an existing item.
            for item in items:
                if item.id >= self.id_counter:
                    self.id_counter = item.id + 1
            o.id = self.id_counter
            self.id_counter += 1

        for index, item in enumerate(items):
            if item.id == o.id:
                # Assign through the index: rebinding the loop variable
                # would leave the stored entry untouched.
                items[index] = o
                break
        else:
            # The loop finished without a match, so this is a new entry.
            items.append(o)
class MockObject(object):
    """Base class for mock model objects with simple change tracking.

    Every attribute assignment is recorded in ``is_set``. ``_dict`` is a
    snapshot of the fields named in ``field_names`` that have been set, and
    ``diff`` compares the current snapshot with the one taken at
    construction time (or at the last ``recompute_initial()``).
    """

    objects = None
    id = None
    deleted = False
    field_names = []

    def __init__(self, **kwargs):
        """Initialise defaults, then apply ``kwargs`` as attributes."""
        # is_set has to exist before any ordinary assignment, because
        # __setattr__ writes into it; install it without going through
        # __setattr__.
        object.__setattr__(self, 'is_set', {})
        self.backend_code = 0
        self.id = 98052
        self.pk = random.randint(0, 1 << 30)
        self.leaf_model = self

        # reset is_set, so the defaults assigned above do not count as set
        self.is_set = {}
        for (k, v) in kwargs.items():
            setattr(self, k, v)
        self.is_new = True
        self._initial = self._dict

    def __setattr__(self, name, value):
        """Set the attribute and remember that ``name`` has been assigned."""
        self.is_set[name] = True
        object.__setattr__(self, name, value)

    @property
    def self_content_type_id(self):
        """Name of the object's class."""
        return self.__class__.__name__

    def save(self, update_fields=None, always_update_timestamp=False):
        """Save this object into ``objects`` when a store is attached.

        ``update_fields`` and ``always_update_timestamp`` are accepted but
        not used by this implementation.
        """
        if self.objects:
            self.objects.save(self)

    def delete(self):
        """Do nothing."""
        pass

    def tologdict(self):
        """Return an empty dict."""
        return {}

    @property
    def _dict(self):
        """Dict of every field in ``field_names`` that has been assigned."""
        return {
            name: getattr(self, name)
            for name in self.field_names
            if self.is_set.get(name, False)
        }

    @property
    def diff(self):
        """Map each changed field name to ``(initial_value, current_value)``.

        A field missing from either snapshot is reported as ``None`` on
        that side.
        """
        before = self._initial
        after = self._dict
        changes = {}
        for name in self.field_names:
            old = before.get(name, None)
            new = after.get(name, None)
            if old != new:
                changes[name] = (old, new)
        return changes

    @property
    def has_changed(self):
        """True when at least one field differs from the initial snapshot."""
        return bool(self.diff)

    @property
    def changed_fields(self):
        """List of changed field names; every set field when ``is_new``.

        A real list is returned rather than a dict view, because
        save_changed_fields() appends to the result.
        """
        if self.is_new:
            return list(self._dict)
        return list(self.diff)

    def has_field_changed(self, field_name):
        """Return True if ``field_name`` differs from its initial value."""
        return field_name in self.diff

    def get_field_diff(self, field_name):
        """Return ``(initial, current)`` for ``field_name``, or None."""
        return self.diff.get(field_name, None)

    def recompute_initial(self):
        """Take the current field values as the new baseline for ``diff``."""
        self._initial = self._dict

    def save_changed_fields(self, always_update_timestamp=False):
        """Call ``save`` with the sorted changed fields, if anything changed.

        When ``always_update_timestamp`` is true, "updated" is added to the
        list of fields passed to ``save``.
        """
        if self.has_changed:
            update_fields = self.changed_fields
            if always_update_timestamp and "updated" not in update_fields:
                update_fields.append("updated")
            self.save(update_fields=sorted(update_fields),
                      always_update_timestamp=always_update_timestamp)
def get_MockObjectStore(x):
    """Instantiate the ``Mock<x>Objects`` store class and register it.

    The class is looked up by name in this module's globals. The new
    instance is appended to ``AllMockObjectStores`` unless an equal entry
    is already present, and is then returned.
    """
    store_class = globals()["Mock%sObjects" % x]
    store = store_class()
    if store in AllMockObjectStores:
        return store
    AllMockObjectStores.append(store)
    return store
class ModelAccessor(object):
    """Mock accessor giving access to the model classes in this module.

    Attribute access that does not resolve normally is redirected to the
    module globals, so a model class can be fetched as
    ``accessor.<ClassName>``.
    """

    def check_db_connection_okay(self):
        """Always report the connection as healthy."""
        return True

    def connection_close(self):
        """Do nothing."""
        pass

    def journal_object(self, *args, **kwargs):
        """Accept any arguments and do nothing."""
        pass

    def obj_exists(self, o):
        """Return True when ``o.id`` is neither ``None`` nor ``0``."""
        # gRPC will default id to '0' for uninitialized objects
        return (o.id is not None) and (o.id != 0)

    def fetch_pending(self, model, deleted=False):
        """Return between 1 and 5 freshly constructed ``model`` instances.

        Args:
            model: a class callable with no arguments, or a list whose
                first element is such a class.
            deleted: accepted but not used by this implementation.

        Returns:
            A list of new objects, each with ``name`` set to
            ``"Opinionated Berry <index>"``.

        Any exception raised while constructing an object propagates to
        the caller.
        """
        num = random.randint(1, 5)
        if isinstance(model, list):
            model = model[0]

        object_list = []
        for i in range(num):
            obj = model()
            obj.name = "Opinionated Berry %d" % i
            object_list.append(obj)
        return object_list

    def reset_all_object_stores(self):
        """Empty every store registered in ``AllMockObjectStores``.

        Each store's ``item_list`` is replaced with a new empty list; that
        is the attribute the stores read their contents from.
        """
        for store in AllMockObjectStores:
            store.item_list = []

    def get_model_class(self, classname):
        """Return the module-level global named ``classname``.

        Raises:
            KeyError: if there is no such global.
        """
        return globals()[classname]

    def has_model_class(self, classname):
        """Return True if a module-level global named ``classname`` exists."""
        return classname in globals()

    def __getattr__(self, name):
        """ Wrapper for getattr to facilitate retrieval of classes """
        # __getattr__ only runs after normal lookup has failed. The helpers
        # are fetched through __getattribute__ so that looking them up
        # cannot re-enter this method.
        has_model_class = self.__getattribute__("has_model_class")
        get_model_class = self.__getattribute__("get_model_class")
        if has_model_class(name):
            return get_model_class(name)
        # Default behaviour (raises AttributeError for unknown names)
        return self.__getattribute__(name)
# Module-level ModelAccessor instance.
model_accessor = ModelAccessor()
class ObjectSet(object):
    """Thin wrapper that exposes a collection through ``all()``."""

    def __init__(self, objects):
        """Keep a reference to ``objects``; no copy is made."""
        self.objects = objects

    def all(self):
        """Return the wrapped collection exactly as it was given."""
        wrapped = self.objects
        return wrapped
#####
# DO NOT MODIFY THE CLASSES BELOW. THEY ARE AUTOGENERATED.
#
# For every message in proto.messages the template loop below emits:
#   * a Mock<Name>Objects store class (an empty MockObjectStore subclass);
#   * a <Name> model class derived from MockObject, whose `objects` is the
#     store returned by get_MockObjectStore("<Name>").
# Each base field and message field becomes a class attribute; a field with
# a link also gets a companion `<field>_id = None` attribute.
# `field_names` lists "id" followed by those same field names.
# NOTE(review): the attribute default presumably falls back to None when the
# field declares no default option -- confirm against xproto_first_non_empty.
{% for m in proto.messages -%}
class Mock{{ m.name }}Objects(MockObjectStore): pass
class {{ m.name }}(MockObject):
objects = get_MockObjectStore("{{ m.name }}")
{% for f in xproto_base_fields(m, proto.message_table) + m.fields -%}
{{ f.name }} = {{ xproto_first_non_empty([f.options.default, "None"]) }}
{% if f.link -%}{{ f.name }}_id = None{% endif %}
{% endfor %}
leaf_model_name = "{{ m.name }}"
field_names = ["id", \
{% for f in xproto_base_fields(m, proto.message_table) + m.fields -%}
"{{ f.name }}",
{% endfor %}
]
{% endfor %}