"""
Copyright 2020 The Magma Authors.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import abc
import logging
import threading
import time
from typing import Any, List

import grpc
import snowflake
from google.protobuf import any_pb2
from common import serialization_utils
from common.metrics import STREAMER_RESPONSES
from common.service_registry import ServiceRegistry
from configuration.service_configs import get_service_config_value
from orc8r.protos.streamer_pb2 import DataUpdate, StreamRequest
from orc8r.protos.streamer_pb2_grpc import StreamerStub


class StreamerClient(threading.Thread):
    """
    StreamerClient provides an interface to communicate with the Streamer
    service in the cloud to get updates for a stream.

    The StreamerClient spawns a thread which listens for updates and
    schedules a callback in the asyncio event loop when an update
    is received from the cloud.

    If the connection to the cloud is terminated, the StreamerClient
    retries (TBD: with exponential backoff) connecting to the cloud.
    """

    class Callback:

        @abc.abstractmethod
        def get_request_args(self, stream_name: str) -> Any:
            """
            This is called before every stream request to collect any extra
            arguments to send up to the cloud streamer service.

            Args:
                stream_name:
                    Name of the stream the request args will be sent to

            Returns: A protobuf message
            """
            pass

        @abc.abstractmethod
        def process_update(
            self, stream_name: str, updates: List[DataUpdate],
            resync: bool,
        ):
            """
            Called when we get an update from the cloud. This method will
            be called in the event loop provided to the StreamerClient.

            Args:
                stream_name: Name of the stream
                updates: Array of updates
                resync: if true, the application can clear the
                    contents before applying the updates
            """
            raise NotImplementedError()

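    # For reference, a minimal concrete Callback might look like the sketch
    # below (`LoggingCallback` is a hypothetical name, not part of this
    # module): it sends no extra request args and just logs each update.
    #
    #     class LoggingCallback(StreamerClient.Callback):
    #         def get_request_args(self, stream_name: str) -> Any:
    #             return None
    #
    #         def process_update(self, stream_name, updates, resync):
    #             for update in updates:
    #                 logging.info('%s: got update %s (resync=%s)',
    #                              stream_name, update.key, resync)
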
    def __init__(self, stream_callbacks, loop):
        """
        Args:
            stream_callbacks ({string: Callback}): Mapping of stream names to
                callbacks to subscribe to.
            loop: asyncio event loop to schedule the callback
        """
        threading.Thread.__init__(self)
        self._stream_callbacks = stream_callbacks
        self._loop = loop
        # Set this thread as a daemon thread. We can kill this background
        # thread abruptly since we handle all updates (and database
        # transactions) in the asyncio event loop.
        self.daemon = True

        # Don't allow a stream update rate faster than every 5 seconds
        self._reconnect_pause = get_service_config_value(
            'streamer', 'reconnect_sec', 60,
        )
        self._reconnect_pause = max(5, self._reconnect_pause)
        logging.info("Streamer reconnect pause: %d", self._reconnect_pause)
        self._stream_timeout = get_service_config_value(
            'streamer', 'stream_timeout', 150,
        )
        logging.info("Streamer timeout: %d", self._stream_timeout)

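    # `reconnect_sec` and `stream_timeout` are read from the streamer service
    # config; a hypothetical streamer.yml carrying the defaults above would
    # look like:
    #
    #     reconnect_sec: 60
    #     stream_timeout: 150
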
    def run(self):
        while True:
            try:
                channel = ServiceRegistry.get_rpc_channel(
                    'streamer', ServiceRegistry.CLOUD,
                )
                client = StreamerStub(channel)
                self.process_all_streams(client)
            except Exception as exp:  # pylint: disable=broad-except
                logging.error("Error with streamer: %s", exp)

            # If the connection is terminated, wait for a period of time
            # before connecting back to the cloud.
            # TODO: make this more intelligent (exponential backoffs, etc.)
            time.sleep(self._reconnect_pause)

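    # A sketch of the exponential backoff suggested by the TODO above (not
    # part of the current implementation; MAX_PAUSE is a hypothetical cap):
    #
    #     pause = self._reconnect_pause
    #     while True:
    #         try:
    #             ...connect and stream as above...
    #             pause = self._reconnect_pause  # reset after a good run
    #         except Exception:  # pylint: disable=broad-except
    #             pause = min(2 * pause, MAX_PAUSE)
    #         time.sleep(pause)
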
    def process_all_streams(self, client):
        for stream_name, callback in self._stream_callbacks.items():
            try:
                self.process_stream_updates(client, stream_name, callback)

                STREAMER_RESPONSES.labels(result='Success').inc()
            except grpc.RpcError as err:
                logging.error(
                    "Error! Streaming from the cloud failed! [%s] %s",
                    err.code(), err.details(),
                )
                STREAMER_RESPONSES.labels(result='RpcError').inc()
            except ValueError as err:
                logging.error("Error! Streaming from cloud failed! %s", err)
                STREAMER_RESPONSES.labels(result='ValueError').inc()

    def process_stream_updates(self, client, stream_name, callback):
        extra_args = self._get_extra_args_any(callback, stream_name)
        request = StreamRequest(
            gatewayId=snowflake.snowflake(),
            stream_name=stream_name,
            extra_args=extra_args,
        )
        for update_batch in client.GetUpdates(
            request, timeout=self._stream_timeout,
        ):
            self._loop.call_soon_threadsafe(
                callback.process_update,
                stream_name,
                update_batch.updates,
                update_batch.resync,
            )

    @staticmethod
    def _get_extra_args_any(callback, stream_name):
        extra_args = callback.get_request_args(stream_name)
        if extra_args is None:
            return None
        else:
            extra_any = any_pb2.Any()
            extra_any.Pack(extra_args)
            return extra_any


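# Typical wiring, as a sketch (the stream name and callback below are
# hypothetical): map stream names to Callback instances, then run the client
# as a daemon thread next to the service's asyncio event loop.
#
#     loop = asyncio.get_event_loop()
#     client = StreamerClient({'configs': MyConfigCallback()}, loop)
#     client.start()
#     loop.run_forever()
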
def get_stream_serialize_filename(stream_name):
    return '/var/opt/magma/streams/{}'.format(stream_name)


class SerializingStreamCallback(StreamerClient.Callback):
    """
    Streamer client callback which decodes each stream update as a string and
    writes it to a file, overwriting the previous contents of that file. The
    file location is defined by get_stream_serialize_filename.

    This callback will only save the newest update, with each successive
    update overwriting the previous one.
    """

    def get_request_args(self, stream_name: str) -> Any:
        return None

    def process_update(self, stream_name, updates, resync):
        if not updates:
            return
        # For now, we only care about the last (newest) update
        for update in updates[:-1]:
            logging.info('Ignoring update %s', update.key)

        logging.info('Serializing stream update %s', updates[-1].key)
        filename = get_stream_serialize_filename(stream_name)
        serialization_utils.write_to_file_atomically(
            filename,
            updates[-1].value.decode(),
        )
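
# Example registration, as a sketch ('subscriberdb' is an illustrative stream
# name and `loop` is the service's asyncio event loop): with this callback,
# the newest value in each update batch is written atomically to
# /var/opt/magma/streams/subscriberdb.
#
#     callbacks = {'subscriberdb': SerializingStreamCallback()}
#     StreamerClient(callbacks, loop).start()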