blob: c227e4d7dcb103702b8886bb737ec5e6fb1f754d [file] [log] [blame]
Wei-Yu Chen49950b92021-11-08 19:19:18 +08001"""
2Copyright 2020 The Magma Authors.
3
4This source code is licensed under the BSD-style license found in the
5LICENSE file in the root directory of this source tree.
6
7Unless required by applicable law or agreed to in writing, software
8distributed under the License is distributed on an "AS IS" BASIS,
9WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10See the License for the specific language governing permissions and
11limitations under the License.
12"""
13from copy import deepcopy
14from typing import Any, Iterator, List, MutableMapping, Optional, TypeVar
15
16import redis
17import redis_collections
18import redis_lock
19from common.redis.serializers import RedisSerde
20from orc8r.protos.redis_pb2 import RedisState
21from redis.lock import Lock
22
# NOTE: these containers replace the serialization methods exposed by
# the redis-collection objects (``_pickle`` / ``_unpickle``, and for the
# hash dict also the key/value variants). Although the methods are hinted
# to be privately scoped, the method replacement is encouraged in the
# library's docs: http://redis-collections.readthedocs.io/en/stable/usage-notes.html

# Type of the values held by a RedisFlatDict (and its RedisSerde[T]).
T = TypeVar('T')
29
30
class RedisList(redis_collections.List):
    """
    Redis-backed list: every element is serialized into a Redis datastore.

    Notes:
        - Contents persist across sessions.
        - Mutable elements are handled correctly (the underlying
          redis_collections.List is created with ``writeback=True``).
        - Not expected to be thread safe, but could be extended.
    """

    def __init__(self, client, key, serialize, deserialize):
        """
        Create a list persisted under *key*.

        Args:
            client (redis.Redis): Redis client object
            key (str): Redis key holding this container's elements
            serialize (function (any) -> bytes):
                called to serialize an element
            deserialize (function (bytes) -> any):
                called to deserialize an element
        """
        # Install the custom (de)serializers before the base initializer
        # runs, so they are in place for anything it does.
        self._unpickle = deserialize
        self._pickle = serialize
        super().__init__(key=key, redis=client, writeback=True)

    def __copy__(self):
        """Return a shallow copy as a plain Python ``list``."""
        return list(iter(self))

    def __deepcopy__(self, memo):
        """Return a deep copy as a plain Python ``list``."""
        return list(deepcopy(item, memo) for item in self)
64
65
class RedisSet(redis_collections.Set):
    """
    Redis-backed set: every element is serialized into a Redis datastore.

    Notes:
        - Contents persist across sessions.
        - Mutable elements are NOT handled correctly:
            - storing and retrieving mutable elements is supported
            - mutating an element in place and expecting the change to
              reach Redis will not work
        - Expected to be thread safe, but not tested.
    """

    def __init__(self, client, key, serialize, deserialize):
        """
        Create a set persisted under *key*.

        Args:
            client (redis.Redis): Redis client object
            key (str): Redis key holding this container's elements
            serialize (function (any) -> bytes):
                called to serialize an element
            deserialize (function (bytes) -> any):
                called to deserialize an element
        """
        # NOTE: redis_collections.Set has no writeback option, which is why
        # in-place updates of mutable elements are problematic.
        self._unpickle = deserialize
        self._pickle = serialize
        super().__init__(key=key, redis=client)

    def __copy__(self):
        """Return a shallow copy as a plain Python ``set``."""
        return set(self)

    def __deepcopy__(self, memo):
        """Return a deep copy as a plain Python ``set``."""
        return set(deepcopy(member, memo) for member in self)
103 return {deepcopy(elt, memo) for elt in self}
104
105
class RedisHashDict(redis_collections.DefaultDict):
    """
    Dict-like interface serializing elements to a Redis datastore. This dict
    utilizes Redis's hashmap functionality: all entries live in a single
    Redis hash stored under ``key``.

    Notes:
        - Keys must be string-like and are serialized in plaintext; keys
          read back from Redis are decoded as UTF-8
        - Values are serialized together with a version number that is
          incremented on every write (see ``get_version``)
        - Provides persistence across sessions
        - Mutable elements handled correctly
        - Not expected to be thread safe, but could be extended (the version
          bump in ``__setitem__`` is a separate read and write)
    """

    @staticmethod
    def serialize_key(key):
        """ Serialize key to plaintext (the key is passed through as-is). """
        return key

    @staticmethod
    def deserialize_key(serialized):
        """ Deserialize key from plaintext encoded as UTF-8 bytes. """
        return serialized.decode('utf-8')  # Redis returns bytes

    def __init__(
        self, client, key, serialize, deserialize,
        default_factory=None, writeback=False,
    ):
        """
        Initialize instance.

        Args:
            client (redis.Redis): Redis client object
            key (str): key where this container's elements are stored in Redis
            serialize (function (any, int) -> bytes):
                function called to serialize a value together with its
                version number
            deserialize (function (bytes) -> any):
                function called to deserialize a value
            default_factory: function that provides default value for a
                non-existent key
            writeback (bool): if writeback is set to true, dict maintains a
                local cache of values and the `sync` method can be called to
                store these values. NOTE: only use this option if syncing
                between services is not important.
        """
        # Key serialization (to/from plaintext)
        self._pickle_key = RedisHashDict.serialize_key
        self._unpickle_key = RedisHashDict.deserialize_key
        # Value serialization
        self._pickle_value = serialize
        self._unpickle = deserialize
        super().__init__(
            default_factory, redis=client, key=key, writeback=writeback,
        )

    def __setitem__(self, key, value):
        """Set ``d[key]`` to *value*.

        Overridden in order to increment the stored version on each update.
        Redis is written immediately; with writeback enabled the value is
        also kept in the local cache.
        """
        version = self.get_version(key)
        pickled_key = self._pickle_key(key)
        pickled_value = self._pickle_value(value, version + 1)
        self.redis.hset(self.key, pickled_key, pickled_value)

        if self.writeback:
            self.cache[key] = value

    def __copy__(self):
        """Return a shallow copy as a plain Python ``dict``."""
        return {key: self[key] for key in self}

    def __deepcopy__(self, memo):
        """Return a deep copy as a plain Python ``dict``."""
        return {key: deepcopy(self[key], memo) for key in self}

    def get_version(self, key):
        """Return the version of the value for key *key*. Returns 0 if
        key is not in the map.

        The version is always read from Redis. With writeback enabled,
        ``self.cache`` holds deserialized values (see ``__setitem__``), not
        the serialized ``RedisState`` bytes, so it cannot be parsed for a
        version. ``__setitem__`` writes to Redis synchronously, so Redis
        holds the current version.
        """
        pickled_key = self._pickle_key(key)
        serialized_value = self.redis.hget(self.key, pickled_key)
        if serialized_value is None:
            return 0

        proto_wrapper = RedisState()
        proto_wrapper.ParseFromString(serialized_value)
        return proto_wrapper.version
196 return proto_wrapper.version
197
198
class RedisFlatDict(MutableMapping[str, T]):
    """
    Dict-like interface serializing elements to a Redis datastore. This
    dict stores keys directly (i.e. without a hashmap): each entry lives
    under its own Redis key ``<key>:<redis_type>``, where ``redis_type``
    is taken from the serde.

    Stored values are parsed as ``RedisState`` protos, which carry a
    version (incremented on every write) and an ``is_garbage`` flag.
    Entries marked as garbage are hidden from iteration, ``len``,
    membership tests and item access until they are deleted.

    In write-through mode, reads of keys, length and membership are served
    from a local cache. That cache is filled from Redis at construction
    time and then kept up to date with changes made through this instance.

    Keys must not contain the ':' character.
    """

    def __init__(
        self, client: redis.Redis, serde: RedisSerde[T],
        writethrough: bool = False,
    ):
        """
        Args:
            client (redis.Redis): Redis client object
            serde (RedisSerde): de/serializer for the stored objects; its
                ``redis_type`` is used as the suffix of every Redis key
            writethrough (bool): if writethrough is set to true,
                RedisFlatDict maintains a local write-through cache of
                values, populated from Redis when the dict is created.
        """
        super().__init__()
        self._writethrough = writethrough
        self.redis = client
        self.serde = serde
        self.redis_type = serde.redis_type
        # Composite key ("<key>:<redis_type>") -> deserialized value.
        # Only populated in write-through mode; never holds garbage entries.
        self.cache = {}
        if self._writethrough:
            self._sync_cache()

    def __len__(self) -> int:
        """Return the number of (non-garbage) items in the dictionary."""
        if self._writethrough:
            return len(self.cache)

        return len(self.keys())

    def __iter__(self) -> Iterator[str]:
        """Return an iterator over the (non-garbage) keys of the dictionary."""
        if self._writethrough:
            # Iterate over a snapshot so callers may modify the dict while
            # iterating.
            for composite_key in list(self.cache):
                key, _ = composite_key.split(":", 1)
                yield key
            return

        for key in self._scan_keys():
            # There could be a delete key in between KEYS and GET, so ignore
            # entries that vanished in the meantime.
            try:
                if self.is_garbage(key):
                    continue
            except KeyError:
                continue
            yield key

    def __contains__(self, key: str) -> bool:
        """Return ``True`` if *key* is present and not garbage,
        else ``False``.
        """
        composite_key = self._make_composite_key(key)

        if self._writethrough:
            return composite_key in self.cache

        return bool(self.redis.exists(composite_key)) and \
            not self.is_garbage(key)

    def __getitem__(self, key: str) -> T:
        """Return the item of dictionary with key *key:type*.

        Raises a :exc:`KeyError` if *key:type* is not in the map or the
        object is garbage, and a :exc:`ValueError` if *key* contains ':'.
        """
        if ':' in key:
            raise ValueError("Key %s cannot contain ':' char" % key)
        composite_key = self._make_composite_key(key)

        # Membership test (not truthiness) so falsy cached values are also
        # served from the cache.
        if self._writethrough and composite_key in self.cache:
            return self.cache[composite_key]

        serialized_value = self.redis.get(composite_key)
        if serialized_value is None:
            raise KeyError(composite_key)

        if self._parse_state(serialized_value).is_garbage:
            raise KeyError("Key %s is garbage" % key)

        return self.serde.deserialize(serialized_value)

    def __setitem__(self, key: str, value: T) -> Any:
        """Set ``d[key:type]`` to *value*, incrementing its version.

        Raises a :exc:`ValueError` if *key* contains ':'. Returns the result
        of the Redis SET call.
        """
        if ':' in key:
            raise ValueError("Key %s cannot contain ':' char" % key)
        version = self.get_version(key)
        serialized_value = self.serde.serialize(value, version + 1)
        composite_key = self._make_composite_key(key)
        if self._writethrough:
            self.cache[composite_key] = value
        return self.redis.set(composite_key, serialized_value)

    def __delitem__(self, key: str) -> int:
        """Remove ``d[key:type]`` from dictionary.

        Raises a :exc:`KeyError` if *key:type* is not in Redis, and a
        :exc:`ValueError` if *key* contains ':'. Returns the number of
        Redis keys deleted.
        """
        if ':' in key:
            raise ValueError("Key %s cannot contain ':' char" % key)
        composite_key = self._make_composite_key(key)
        if self._writethrough:
            # The cache may legitimately lack the key (e.g. garbage entries
            # are never cached), so a cache miss must not prevent the Redis
            # delete below.
            self.cache.pop(composite_key, None)
        deleted_count = self.redis.delete(composite_key)
        if not deleted_count:
            raise KeyError(composite_key)
        return deleted_count

    def get(self, key: str, default=None) -> Optional[T]:
        """Get ``d[key:type]`` from dictionary.
        Returns *default* if *key:type* is not in the map, is garbage, or
        *key* is invalid.
        """
        try:
            return self.__getitem__(key)
        except (KeyError, ValueError):
            return default

    def clear(self) -> None:
        """
        Clear all (non-garbage) keys in the dictionary. Objects are
        immediately deleted (i.e. not garbage collected).
        """
        # Snapshot the keys first: in write-through mode keys() reads the
        # cache, so clearing the cache first would leave Redis untouched.
        keys = self.keys()
        if self._writethrough:
            self.cache.clear()
        for key in keys:
            self.redis.delete(self._make_composite_key(key))

    def get_version(self, key: str) -> int:
        """Return the version of the value for key *key:type*. Returns 0 if
        key is not in the map.
        """
        value = self.redis.get(self._make_composite_key(key))
        if value is None:
            return 0
        return self._parse_state(value).version

    def keys(self) -> List[str]:
        """Return a copy of the dictionary's list of (non-garbage) keys.

        Note: the plain *key* part of the Redis key *key:type* is returned,
        in both cached and uncached modes.
        """
        # Must not be list(self): that would consult __len__, which calls
        # keys() again in uncached mode.
        return list(self.__iter__())

    def mark_as_garbage(self, key: str) -> Any:
        """Mark ``d[key:type]`` for garbage collection.

        Raises a KeyError if *key:type* is not in the map. Returns the
        result of the Redis SET call.
        """
        composite_key = self._make_composite_key(key)
        value = self.redis.get(composite_key)
        if value is None:
            raise KeyError(composite_key)

        proto_wrapper = self._parse_state(value)
        proto_wrapper.is_garbage = True
        garbage_serialized = proto_wrapper.SerializeToString()
        result = self.redis.set(composite_key, garbage_serialized)
        if self._writethrough:
            # Garbage entries are hidden from the dict; drop the cached copy.
            self.cache.pop(composite_key, None)
        return result

    def is_garbage(self, key: str) -> bool:
        """Return if d[key:type] has been marked for garbage collection.
        Raises a KeyError if *key:type* is not in the map.
        """
        composite_key = self._make_composite_key(key)
        value = self.redis.get(composite_key)
        if value is None:
            raise KeyError(composite_key)
        return self._parse_state(value).is_garbage

    def garbage_keys(self) -> List[str]:
        """Return a copy of the dictionary's list of keys that are garbage.

        Note: the plain *key* part of the Redis key *key:type* is returned.
        Always read from Redis, regardless of write-through mode.
        """
        garbage_keys = []
        for key in self._scan_keys():
            # There could be a delete key in between KEYS and GET, so ignore
            # entries that vanished in the meantime.
            try:
                if self.is_garbage(key):
                    garbage_keys.append(key)
            except KeyError:
                continue
        return garbage_keys

    def delete_garbage(self, key: str) -> bool:
        """Remove ``d[key:type]`` from dictionary iff the object is garbage.

        Returns True if it was deleted. Returns False if *key:type* is not
        garbage or is not in the map.
        """
        try:
            if not self.is_garbage(key):
                return False
            return self.__delitem__(key) > 0
        except KeyError:
            # Missing, or deleted concurrently: documented as False.
            return False

    def lock(self, key: str) -> redis_lock.Lock:
        """Return a (not yet acquired) Redis lock for key *key*.

        The lock is named ``<key>:<redis_type>:lock`` and created with
        ``expire=60``, ``auto_renewal=True`` and ``strict=False``.
        """
        return redis_lock.Lock(
            self.redis,
            name=self._make_composite_key(key) + ":lock",
            expire=60,
            auto_renewal=True,
            strict=False,
        )

    def _sync_cache(self) -> None:
        """
        Populate the write-through cache from the entries stored in Redis.

        Entries marked as garbage, and keys deleted between the KEYS and
        GET calls, are skipped.
        """
        for raw_key in self.redis.keys(pattern=self._get_redis_type_pattern()):
            composite_key = self._decode_key(raw_key)
            serialized_value = self.redis.get(composite_key)
            if serialized_value is None:
                continue
            if self._parse_state(serialized_value).is_garbage:
                continue
            self.cache[composite_key] = self.serde.deserialize(
                serialized_value,
            )

    def _scan_keys(self) -> Iterator[str]:
        """Yield the plain keys of all Redis entries matching this dict's
        type pattern, including garbage entries."""
        for raw_key in self.redis.keys(pattern=self._get_redis_type_pattern()):
            yield self._decode_key(raw_key).split(":", 1)[0]

    @staticmethod
    def _decode_key(raw_key) -> str:
        """Return *raw_key* as ``str``; keys from the client may be bytes or
        str depending on its configuration."""
        if isinstance(raw_key, bytes):
            return raw_key.decode('utf-8')
        return raw_key

    @staticmethod
    def _parse_state(serialized_value) -> RedisState:
        """Parse a stored value into its ``RedisState`` wrapper."""
        proto_wrapper = RedisState()
        proto_wrapper.ParseFromString(serialized_value)
        return proto_wrapper

    def _get_redis_type_pattern(self) -> str:
        """Return the Redis KEYS glob matching every entry of this type."""
        return "*:" + self.redis_type

    def _make_composite_key(self, key: str) -> str:
        """Return the Redis key ``<key>:<redis_type>`` for *key*."""
        return key + ":" + self.redis_type
444 return key + ":" + self.redis_type