blob: c227e4d7dcb103702b8886bb737ec5e6fb1f754d [file] [log] [blame]
"""
Copyright 2020 The Magma Authors.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from copy import deepcopy
from typing import Any, Iterator, List, MutableMapping, Optional, TypeVar
import redis
import redis_collections
import redis_lock
from common.redis.serializers import RedisSerde
from orc8r.protos.redis_pb2 import RedisState
from redis.lock import Lock
# NOTE: these containers replace the serialization methods exposed by
# the redis-collections objects. Although these methods are named as if
# they were private, the library's docs explicitly encourage replacing them:
# http://redis-collections.readthedocs.io/en/stable/usage-notes.html
# Type of the values stored in a RedisFlatDict (see RedisSerde[T]).
T = TypeVar('T')
class RedisList(redis_collections.List):
    """
    Persistent, list-like container whose elements are stored in Redis.

    Notes:
        - Contents survive across sessions
        - Mutable elements are handled correctly (the underlying collection
          is created with ``writeback=True``)
        - Not expected to be thread safe, though it could be extended to be
    """

    def __init__(self, client, key, serialize, deserialize):
        """
        Create a list backed by the Redis data stored under *key*.

        Args:
            client (redis.Redis): Redis client object
            key (str): Redis key under which this container's elements live
            serialize (function (any) -> bytes):
                converts one element into bytes
            deserialize (function (bytes) -> any):
                converts bytes back into one element

        Returns:
            redis_list (redis_collections.List): persistent list-like interface
        """
        # Install the caller's codecs before the base class initializes;
        # redis-collections documents replacing these hooks.
        self._pickle, self._unpickle = serialize, deserialize
        super().__init__(key=key, redis=client, writeback=True)

    def __copy__(self):
        """Return a shallow snapshot of the elements as a plain ``list``."""
        return list(iter(self))

    def __deepcopy__(self, memo):
        """Return a plain ``list`` holding deep copies of the elements."""
        return [deepcopy(item, memo) for item in iter(self)]
class RedisSet(redis_collections.Set):
    """
    Persistent, set-like container whose elements are stored in Redis.

    Notes:
        - Contents survive across sessions
        - Mutable elements are only partially supported:
            - adding and reading mutable elements is fine
            - mutating an element's contents in place is not safe;
              don't expect such changes to behave correctly
        - Expected to be thread safe, but this has not been tested
    """

    def __init__(self, client, key, serialize, deserialize):
        """
        Create a set backed by the Redis data stored under *key*.

        Args:
            client (redis.Redis): Redis client object
            key (str): Redis key under which this container's elements live
            serialize (function (any) -> bytes):
                converts one element into bytes
            deserialize (function (bytes) -> any):
                converts bytes back into one element

        Returns:
            redis_set (redis_collections.Set): persistent set-like interface
        """
        # redis_collections.Set has no writeback option (unlike the List
        # used by RedisList), which is why in-place updates of mutable
        # elements are problematic.
        self._pickle, self._unpickle = serialize, deserialize
        super().__init__(key=key, redis=client)

    def __copy__(self):
        """Return a shallow snapshot of the elements as a plain ``set``."""
        return set(iter(self))

    def __deepcopy__(self, memo):
        """Return a plain ``set`` holding deep copies of the elements."""
        return {deepcopy(member, memo) for member in iter(self)}
class RedisHashDict(redis_collections.DefaultDict):
    """
    Dict-like interface serializing elements to a Redis datastore. This dict
    utilizes Redis's hashmap functionality: every entry is a field of the
    single Redis hash stored under ``key``.

    Notes:
        - Keys must be string-like and are serialized to plaintext (UTF-8)
        - Values are versioned: each write stores the previous version + 1
        - Provides persistence across sessions
        - Mutable elements handled correctly
        - Not expected to be thread safe, but could be extended
    """

    @staticmethod
    def serialize_key(key):
        """ Serialize key to plaintext (the key is passed through as-is). """
        return key

    @staticmethod
    def deserialize_key(serialized):
        """ Deserialize key from plaintext encoded as UTF-8 bytes. """
        return serialized.decode('utf-8')  # Redis returns bytes

    def __init__(
        self, client, key, serialize, deserialize,
        default_factory=None, writeback=False,
    ):
        """
        Initialize instance.

        Args:
            client (redis.Redis): Redis client object
            key (str): key where this container's elements are stored in Redis
            serialize (function (any, int) -> bytes):
                function called to serialize a value together with its
                version number
            deserialize (function (bytes) -> any):
                function called to deserialize a value
            default_factory: function that provides default value for a
                non-existent key
            writeback (bool): if writeback is set to true, dict maintains a
                local cache of values and the `sync` method can be called to
                store these values. NOTE: only use this option if syncing
                between services is not important.

        Returns:
            redis_dict (redis_collections.Dict): persistent dict-like interface
        """
        # Key serialization (to/from plaintext)
        self._pickle_key = RedisHashDict.serialize_key
        self._unpickle_key = RedisHashDict.deserialize_key
        # Value serialization
        self._pickle_value = serialize
        self._unpickle = deserialize
        super().__init__(
            default_factory, redis=client, key=key, writeback=writeback,
        )

    def __setitem__(self, key, value):
        """Set ``d[key]`` to *value*.

        Overridden in order to increment the stored version on each update.
        """
        version = self.get_version(key)
        pickled_key = self._pickle_key(key)
        pickled_value = self._pickle_value(value, version + 1)
        self.redis.hset(self.key, pickled_key, pickled_value)
        if self.writeback:
            self.cache[key] = value

    def __copy__(self):
        """Return a plain ``dict`` snapshot of the entries."""
        return {key: self[key] for key in self}

    def __deepcopy__(self, memo):
        """Return a plain ``dict`` with deep copies of the values."""
        return {key: deepcopy(self[key], memo) for key in self}

    def get_version(self, key):
        """Return the version of the value for key *key*. Returns 0 if
        key is not in the map.

        The version is always read from Redis: the writeback cache holds
        the deserialized values (see ``__setitem__``), which are not
        ``RedisState`` bytes and carry no version information.
        """
        pickled_key = self._pickle_key(key)
        value = self.redis.hget(self.key, pickled_key)
        if value is None:
            return 0
        proto_wrapper = RedisState()
        proto_wrapper.ParseFromString(value)
        return proto_wrapper.version
class RedisFlatDict(MutableMapping[str, T]):
    """
    Dict-like interface serializing elements to a Redis datastore. This
    dict stores each entry directly under its own Redis key (i.e. without a
    hashmap), named ``<key>:<redis_type>``.

    Notes:
        - Keys must not contain the ':' character
        - Entries can be marked as garbage (`mark_as_garbage`); garbage
          entries are hidden from lookups, membership tests, iteration and
          `len`, and can be removed with `delete_garbage`
        - With `writethrough` enabled, reads are served from a local cache
          that is loaded at construction and afterwards updated only by
          this instance's own writes
    """

    def __init__(
        self, client: redis.Redis, serde: RedisSerde[T],
        writethrough: bool = False,
    ):
        """
        Args:
            client (redis.Redis): Redis client object
            serde (RedisSerde): de/serializer for the stored objects; its
                `redis_type` is used as the suffix of every Redis key
            writethrough (bool): if writethrough is set to true,
                RedisFlatDict maintains a local write-through cache of values,
                populated from Redis on construction.
        """
        super().__init__()
        self._writethrough = writethrough
        self.redis = client
        self.serde = serde
        self.redis_type = serde.redis_type
        # composite key ("<key>:<redis_type>") -> deserialized value
        self.cache = {}
        if self._writethrough:
            self._sync_cache()

    def __len__(self) -> int:
        """Return the number of (non-garbage) items in the dictionary."""
        if self._writethrough:
            return len(self.cache)
        return len(self.keys())

    def __iter__(self) -> Iterator[str]:
        """Return an iterator over the (non-garbage) keys of the dictionary."""
        if self._writethrough:
            # Iterate over a snapshot so the dict can be mutated while it is
            # being iterated without a "dict changed size" RuntimeError.
            for composite_key in list(self.cache):
                yield composite_key.split(":", 1)[0]
            return
        for raw_key in self.redis.keys(pattern=self._get_redis_type_pattern()):
            key = self._decode_key(raw_key).split(":", 1)[0]
            # There could be a delete key in between KEYS and GET, so ignore
            # invalid values for now
            try:
                if self.is_garbage(key):
                    continue
            except KeyError:
                continue
            yield key

    def __contains__(self, key: str) -> bool:
        """Return ``True`` if *key* is present and not garbage,
        else ``False``.
        """
        if self._writethrough:
            return self._make_composite_key(key) in self.cache
        try:
            return not self.is_garbage(key)
        except KeyError:
            # Missing (possibly deleted concurrently by another client).
            return False

    def __getitem__(self, key: str) -> T:
        """Return the item of dictionary with key *key:type*. Raises a
        :exc:`KeyError` if *key:type* is not in the map or the object is
        garbage, and :exc:`ValueError` if *key* contains ':'.
        """
        if ':' in key:
            raise ValueError("Key %s cannot contain ':' char" % key)
        composite_key = self._make_composite_key(key)
        # Membership test rather than truthiness, so falsy cached values
        # are still served from the cache.
        if self._writethrough and composite_key in self.cache:
            return self.cache[composite_key]
        serialized_value = self.redis.get(composite_key)
        if serialized_value is None:
            raise KeyError(composite_key)
        if self._parse_state(serialized_value).is_garbage:
            raise KeyError("Key %s is garbage" % key)
        return self.serde.deserialize(serialized_value)

    def __setitem__(self, key: str, value: T) -> Any:
        """Set ``d[key:type]`` to *value*, bumping its stored version."""
        if ':' in key:
            raise ValueError("Key %s cannot contain ':' char" % key)
        version = self.get_version(key)
        serialized_value = self.serde.serialize(value, version + 1)
        composite_key = self._make_composite_key(key)
        if self._writethrough:
            self.cache[composite_key] = value
        return self.redis.set(composite_key, serialized_value)

    def __delitem__(self, key: str) -> int:
        """Remove ``d[key:type]`` from dictionary.
        Raises a :func:`KeyError` if *key:type* is not in Redis.
        """
        if ':' in key:
            raise ValueError("Key %s cannot contain ':' char" % key)
        composite_key = self._make_composite_key(key)
        if self._writethrough:
            # pop() rather than del: the entry may exist in Redis without
            # being cached (written by another client, or already dropped
            # from the cache by mark_as_garbage).
            self.cache.pop(composite_key, None)
        deleted_count = self.redis.delete(composite_key)
        if not deleted_count:
            raise KeyError(composite_key)
        return deleted_count

    def get(self, key: str, default=None) -> Optional[T]:
        """Get ``d[key:type]`` from dictionary.
        Returns *default* if *key:type* is not in the map, is garbage, or
        *key* is invalid.
        """
        try:
            return self.__getitem__(key)
        except (KeyError, ValueError):
            return default

    def clear(self) -> None:
        """
        Clear all keys in the dictionary. Objects are immediately deleted
        (i.e. not garbage collected)
        """
        # Snapshot the keys *before* dropping the cache: in writethrough
        # mode keys() is served from the cache, so clearing it first would
        # leave every entry in Redis untouched.
        keys = self.keys()
        if self._writethrough:
            self.cache.clear()
        for key in keys:
            self.redis.delete(self._make_composite_key(key))

    def get_version(self, key: str) -> int:
        """Return the version of the value for key *key:type*. Returns 0 if
        key is not in the map
        """
        value = self.redis.get(self._make_composite_key(key))
        if value is None:
            return 0
        return self._parse_state(value).version

    def keys(self) -> List[str]:
        """Return a copy of the dictionary's list of (non-garbage) keys.
        Note: for the redis key *key:type*, only *key* is returned, in both
        cached and uncached modes.
        """
        # Deliberately not list(self): list() may query __len__ for a size
        # hint, and __len__ calls keys() in uncached mode.
        return list(self.__iter__())

    def mark_as_garbage(self, key: str) -> Any:
        """Mark ``d[key:type]`` for garbage collection
        Raises a KeyError if *key:type* is not in the map.
        """
        composite_key = self._make_composite_key(key)
        value = self.redis.get(composite_key)
        if value is None:
            raise KeyError(composite_key)
        proto_wrapper = self._parse_state(value)
        proto_wrapper.is_garbage = True
        if self._writethrough:
            # Garbage entries are hidden from lookups and iteration, so the
            # cached copy must go too to keep both modes consistent.
            self.cache.pop(composite_key, None)
        return self.redis.set(composite_key, proto_wrapper.SerializeToString())

    def is_garbage(self, key: str) -> bool:
        """Return if d[key:type] has been marked for garbage collection.
        Raises a KeyError if *key:type* is not in the map.
        """
        composite_key = self._make_composite_key(key)
        value = self.redis.get(composite_key)
        if value is None:
            raise KeyError(composite_key)
        return self._parse_state(value).is_garbage

    def garbage_keys(self) -> List[str]:
        """Return a copy of the dictionary's list of keys that are garbage
        Note: for the redis key *key:type*, only *key* is returned
        """
        garbage_keys = []
        for raw_key in self.redis.keys(pattern=self._get_redis_type_pattern()):
            key = self._decode_key(raw_key).split(":", 1)[0]
            # There could be a delete key in between KEYS and GET, so ignore
            # invalid values for now
            try:
                if self.is_garbage(key):
                    garbage_keys.append(key)
            except KeyError:
                continue
        return garbage_keys

    def delete_garbage(self, key) -> bool:
        """Remove ``d[key:type]`` from dictionary iff the object is garbage
        Returns False if *key:type* is not in the map or is not garbage.
        """
        try:
            if not self.is_garbage(key):
                return False
            count = self.__delitem__(key)
        except KeyError:
            # Missing, or deleted concurrently between the GET and the DEL.
            return False
        return count > 0

    def lock(self, key: str) -> redis_lock.Lock:
        """Return a Redis-backed lock for key *key* (not yet acquired).

        The lock expires after 60 seconds unless renewed; auto-renewal is
        enabled.
        """
        return redis_lock.Lock(
            self.redis,
            name=self._make_composite_key(key) + ":lock",
            expire=60,
            auto_renewal=True,
            strict=False,
        )

    def _sync_cache(self):
        """
        Load every live (non-garbage) entry of this type from Redis into the
        write-through cache.
        """
        for raw_key in self.redis.keys(pattern=self._get_redis_type_pattern()):
            composite_key = self._decode_key(raw_key)
            serialized_value = self.redis.get(composite_key)
            # Skip keys deleted between KEYS and GET, and garbage entries,
            # which uncached lookups hide as well.
            if serialized_value is None:
                continue
            if self._parse_state(serialized_value).is_garbage:
                continue
            self.cache[composite_key] = self.serde.deserialize(serialized_value)

    @staticmethod
    def _decode_key(raw_key) -> str:
        """Return a Redis key as ``str``; keys may come back as bytes or str
        (presumably depending on client configuration).
        """
        try:
            return raw_key.decode('utf-8')
        except AttributeError:
            return raw_key

    @staticmethod
    def _parse_state(serialized_value):
        """Parse the ``RedisState`` wrapper around a stored value."""
        proto_wrapper = RedisState()
        proto_wrapper.ParseFromString(serialized_value)
        return proto_wrapper

    def _get_redis_type_pattern(self):
        """Return the KEYS glob matching every entry of this type."""
        return "*:" + self.redis_type

    def _make_composite_key(self, key):
        """Return the Redis key ``<key>:<redis_type>`` for *key*."""
        return key + ":" + self.redis_type