Initial commit for standalone enodebd
Change-Id: I88eeef5135dd7ba8551ddd9fb6a0695f5325337b
diff --git a/common/__init__.py b/common/__init__.py
new file mode 100644
index 0000000..5c6cb64
--- /dev/null
+++ b/common/__init__.py
@@ -0,0 +1,12 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
diff --git a/common/cert_utils.py b/common/cert_utils.py
new file mode 100644
index 0000000..3be5284
--- /dev/null
+++ b/common/cert_utils.py
@@ -0,0 +1,160 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import base64
+
+from cryptography import x509
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import hashes, serialization
+from cryptography.x509.oid import NameOID
+from common.serialization_utils import write_to_file_atomically
+
+
+def load_key(key_file):
+ """Load a private key encoded in PEM format
+
+ Args:
+ key_file: path to the key file
+
+ Returns:
+ RSAPrivateKey or EllipticCurvePrivateKey depending on the contents of key_file
+
+ Raises:
+ IOError: If file cannot be opened
+ ValueError: If the file content cannot be decoded successfully
+ TypeError: If the key_file is encrypted
+ """
+ with open(key_file, 'rb') as f:
+ key_bytes = f.read()
+ return serialization.load_pem_private_key(
+ key_bytes, None, default_backend(),
+ )
+
+
+def write_key(key, key_file):
+ """Write key object to file in PEM format atomically
+
+ Args:
+ key: RSAPrivateKey or EllipticCurvePrivateKey object
+ key_file: path to the key file
+ """
+ key_pem = key.private_bytes(
+ serialization.Encoding.PEM,
+ serialization.PrivateFormat.TraditionalOpenSSL,
+ serialization.NoEncryption(),
+ )
+ write_to_file_atomically(key_file, key_pem.decode("utf-8"))
+
+
+def load_public_key_to_base64der(key_file):
+ """Load the public key of a private key and convert to base64 encoded DER
+ The return value can be used directly for device registration.
+
+ Args:
+ key_file: path to the private key file, pem encoded
+
+ Returns:
+ base64 encoded public key in DER format
+
+ Raises:
+ IOError: If file cannot be opened
+ ValueError: If the file content cannot be decoded successfully
+ TypeError: If the key_file is encrypted
+ """
+
+ key = load_key(key_file)
+ pub_key = key.public_key()
+ pub_bytes = pub_key.public_bytes(
+ encoding=serialization.Encoding.DER,
+ format=serialization.PublicFormat.SubjectPublicKeyInfo,
+ )
+ encoded = base64.b64encode(pub_bytes)
+ return encoded
+
+
+def create_csr(
+ key, common_name,
+ country=None, state=None, city=None, org=None,
+ org_unit=None, email_address=None,
+):
+ """Create csr and sign it with key.
+
+ Args:
+ key: RSAPrivateKey or EllipticCurvePrivateKey object
+ common_name: common name
+ country: country
+ state: state or province
+ city: city
+ org: organization
+ org_unit: organizational unit
+ email_address: email address
+
+ Returns:
+ csr: x509.CertificateSigningRequest
+ """
+ name_attrs = [x509.NameAttribute(NameOID.COMMON_NAME, common_name)]
+ if country:
+ name_attrs.append(x509.NameAttribute(NameOID.COUNTRY_NAME, country))
+ if state:
+ name_attrs.append(
+ x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, state),
+ )
+ if city:
+ name_attrs.append(x509.NameAttribute(NameOID.LOCALITY_NAME, city))
+ if org:
+ name_attrs.append(x509.NameAttribute(NameOID.ORGANIZATION_NAME, org))
+ if org_unit:
+ name_attrs.append(
+ x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, org_unit),
+ )
+ if email_address:
+ name_attrs.append(
+ x509.NameAttribute(NameOID.EMAIL_ADDRESS, email_address),
+ )
+
+ csr = x509.CertificateSigningRequestBuilder().subject_name(
+ x509.Name(name_attrs),
+ ).sign(key, hashes.SHA256(), default_backend())
+
+ return csr
+
+
+def load_cert(cert_file):
+ """Load certificate from a file
+
+ Args:
+ cert_file: path to file storing the cert in PEM format
+
+ Returns:
+ cert: an instance of x509.Certificate
+
+ Raises:
+ IOError: If file cannot be opened
+ ValueError: If the file content cannot be decoded successfully
+ """
+ with open(cert_file, 'rb') as f:
+ cert_pem = f.read()
+ cert = x509.load_pem_x509_certificate(cert_pem, default_backend())
+ return cert
+
+
+def write_cert(cert_der, cert_file):
+ """Write DER encoded cert to file in PEM format
+
+ Args:
+ cert_der: certificate encoded in DER format
+ cert_file: path to certificate
+ """
+ cert = x509.load_der_x509_certificate(cert_der, default_backend())
+ cert_pem = cert.public_bytes(serialization.Encoding.PEM)
+ write_to_file_atomically(cert_file, cert_pem.decode("utf-8"))
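
A minimal usage sketch for common/cert_utils.py, assuming the cryptography package is available; the key path and common name below are hypothetical:

    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import ec

    from common.cert_utils import create_csr, load_key, write_key

    # Generate an EC key, persist it atomically, then build a CSR from it
    key = ec.generate_private_key(ec.SECP384R1(), default_backend())
    write_key(key, '/tmp/gateway.key')
    csr = create_csr(load_key('/tmp/gateway.key'), common_name='gw-1234')
    print(csr.public_bytes(serialization.Encoding.PEM).decode())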
diff --git a/common/cert_validity.py b/common/cert_validity.py
new file mode 100644
index 0000000..83a765c
--- /dev/null
+++ b/common/cert_validity.py
@@ -0,0 +1,111 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+Util module to distinguish between the reasons check-ins stop working: the
+network is down or the cert is invalid.
+"""
+
+import asyncio
+import logging
+import os
+import ssl
+
+
+class TCPClientProtocol(asyncio.Protocol):
+ """
+ Implementation of TCP Protocol to create and immediately close the
+ connection
+ """
+
+ def connection_made(self, transport):
+ transport.close()
+
+
+@asyncio.coroutine
+def create_tcp_connection(host, port, loop):
+ """
+ Creates tcp connection
+ """
+ tcp_conn = yield from loop.create_connection(
+ TCPClientProtocol,
+ host,
+ port,
+ )
+ return tcp_conn
+
+
+@asyncio.coroutine
+def create_ssl_connection(host, port, certfile, keyfile, loop):
+ """
+ Creates ssl connection.
+ """
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context.load_cert_chain(
+ certfile,
+ keyfile=keyfile,
+ )
+
+ ssl_conn = yield from loop.create_connection(
+ TCPClientProtocol,
+ host,
+ port,
+ ssl=context,
+ )
+ return ssl_conn
+
+
+@asyncio.coroutine
+def cert_is_invalid(host, port, certfile, keyfile, loop):
+ """
+ Asynchronously test if both a TCP and SSL connection can be made to host
+ on port. If the TCP connection is successful, but the SSL connection fails,
+ we assume this is due to an invalid cert.
+
+ Args:
+ host: host to connect to
+ port: port to connect to on host
+ certfile: path to a PEM encoded certificate
+ keyfile: path to the corresponding key to the certificate
+ loop: asyncio event loop
+ Returns:
+ True if the cert is invalid
+ False otherwise
+ """
+ # Create connections
+ tcp_coro = create_tcp_connection(host, port, loop)
+ ssl_coro = create_ssl_connection(host, port, certfile, keyfile, loop)
+
+ coros = tcp_coro, ssl_coro
+ asyncio.set_event_loop(loop)
+ res = yield from asyncio.gather(*coros, return_exceptions=True)
+ tcp_res, ssl_res = res
+
+ if isinstance(tcp_res, Exception):
+ logging.error(
+ 'Error making TCP connection: %s, %s',
+ 'errno==None' if tcp_res.errno is None
+ else os.strerror(tcp_res.errno),
+ tcp_res,
+ )
+ return False
+
+ # Invalid cert only when tcp succeeds and ssl fails
+ if isinstance(ssl_res, Exception):
+ logging.error(
+ 'Error making SSL connection: %s, %s',
+ 'errno==None' if ssl_res.errno is None
+ else os.strerror(ssl_res.errno),
+ ssl_res,
+ )
+ return True
+
+ return False
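
A minimal sketch of how cert_is_invalid might be driven; the host, port, and certificate paths are hypothetical:

    import asyncio

    from common.cert_validity import cert_is_invalid

    loop = asyncio.get_event_loop()
    invalid = loop.run_until_complete(
        cert_is_invalid(
            host='controller.example.com', port=443,
            certfile='/var/opt/magma/certs/gateway.crt',
            keyfile='/var/opt/magma/certs/gateway.key',
            loop=loop,
        ),
    )
    # True only when the plain TCP connection works but the TLS handshake fails
    print('cert invalid:', invalid)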
diff --git a/common/grpc_client_manager.py b/common/grpc_client_manager.py
new file mode 100644
index 0000000..dbaff76
--- /dev/null
+++ b/common/grpc_client_manager.py
@@ -0,0 +1,77 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import logging
+import sys
+
+import grpc
+import psutil
+from common.service_registry import ServiceRegistry
+
+
+class GRPCClientManager:
+ def __init__(
+ self, service_name: str, service_stub,
+ max_client_reuse: int = 60,
+ ):
+ self._client = None
+ self._service_stub = service_stub
+ self._service_name = service_name
+ self._num_client_use = 0
+ self._max_client_reuse = max_client_reuse
+
+ def get_client(self):
+ """
+ get_client returns a gRPC client of the specified service in the cloud.
+ It returns a recycled client until the client fails or the number of
+ reuses reaches max_client_reuse.
+ """
+ if self._client is None or \
+ self._num_client_use > self._max_client_reuse:
+ chan = ServiceRegistry.get_rpc_channel(
+ self._service_name,
+ ServiceRegistry.CLOUD,
+ )
+ self._client = self._service_stub(chan)
+ self._num_client_use = 0
+
+ self._num_client_use += 1
+ return self._client
+
+ def on_grpc_fail(self, err_code):
+ """
+ Try to reuse the gRPC client if possible. There is an unresolved gRPC
+ behavior where, if a DNS request blackholes, the DNS request is retried
+ indefinitely even after the channel is deleted. To avoid running out of
+ fds, we reuse the channel during such failures as much as possible.
+ """
+ if err_code != grpc.StatusCode.DEADLINE_EXCEEDED:
+ # Not related to the DNS issue
+ self._reset_client()
+ if self._num_client_use >= self._max_client_reuse:
+ logging.info('Max client reuse reached. Cleaning up client')
+ self._reset_client()
+
+ # Sanity check if we are not leaking fds
+ proc = psutil.Process()
+ max_fds, _ = proc.rlimit(psutil.RLIMIT_NOFILE)
+ open_fds = proc.num_fds()
+ logging.info('Num open fds: %d', open_fds)
+ if open_fds >= (max_fds * 0.8):
+ logging.error("Reached 80% of allowed fds. Restarting process")
+ sys.exit(1)
+
+ def _reset_client(self):
+ self._client = None
+ self._num_client_use = 0
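
A minimal usage sketch for GRPCClientManager; the 'state' service, its stub import, and the ReportStates RPC are assumptions used only for illustration:

    import grpc

    from common.grpc_client_manager import GRPCClientManager
    from orc8r.protos.state_pb2_grpc import StateServiceStub  # assumed stub

    client_manager = GRPCClientManager(
        service_name='state',
        service_stub=StateServiceStub,
        max_client_reuse=60,
    )

    def report(request):
        client = client_manager.get_client()
        try:
            return client.ReportStates(request, timeout=10)
        except grpc.RpcError as err:
            # Let the manager decide whether to recycle the channel
            client_manager.on_grpc_fail(err.code())
            raise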
diff --git a/common/health/__init__.py b/common/health/__init__.py
new file mode 100644
index 0000000..5c6cb64
--- /dev/null
+++ b/common/health/__init__.py
@@ -0,0 +1,12 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
diff --git a/common/health/docker_health_service.py b/common/health/docker_health_service.py
new file mode 100644
index 0000000..70a728a
--- /dev/null
+++ b/common/health/docker_health_service.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python3
+
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from datetime import datetime
+
+import dateutil.parser
+import docker
+from common.health.entities import Errors, ServiceHealth, Version
+from common.health.health_service import GenericHealthChecker
+
+
+class DockerHealthChecker(GenericHealthChecker):
+
+ def get_error_summary(self, service_names):
+ res = {}
+ for service_name in service_names:
+ client = docker.from_env()
+ container = client.containers.get(service_name)
+
+ res[service_name] = Errors(log_level='-', error_count=0)
+ for line in container.logs().decode('utf-8').split('\n'):
+ if service_name not in line:
+ continue
+ # Reset the counter for restart/start
+ if 'Starting {}...'.format(service_name) in line:
+ res[service_name].error_count = 0
+ elif 'ERROR' in line:
+ res[service_name].error_count += 1
+ return res
+
+ def get_magma_services_summary(self):
+ services_health_summary = []
+ client = docker.from_env()
+
+ for container in client.containers.list():
+ service_start_time = dateutil.parser.parse(
+ container.attrs['State']['StartedAt'],
+ )
+ current_time = datetime.now(service_start_time.tzinfo)
+ time_running = current_time - service_start_time
+ services_health_summary.append(
+ ServiceHealth(
+ service_name=container.name,
+ active_state=container.status,
+ sub_state=container.status,
+ time_running=str(time_running).split('.', 1)[0],
+ errors=self.get_error_summary([container.name])[
+ container.name
+ ],
+ ),
+ )
+ return services_health_summary
+
+ def get_magma_version(self):
+ client = docker.from_env()
+ container = client.containers.get('magmad')
+
+ return Version(
+ version_code=container.attrs['Config']['Image'],
+ last_update_time='-',
+ )
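
A minimal sketch of the Docker-backed checker, assuming a reachable Docker daemon and a container named magmad:

    from common.health.docker_health_service import DockerHealthChecker

    checker = DockerHealthChecker()
    for service in checker.get_magma_services_summary():
        print(service)  # one formatted health line per running container
    print(checker.get_magma_version())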
diff --git a/common/health/entities.py b/common/health/entities.py
new file mode 100644
index 0000000..ab32496
--- /dev/null
+++ b/common/health/entities.py
@@ -0,0 +1,150 @@
+#!/usr/bin/env python3
+
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import textwrap
+
+
+class ActiveState:
+ ACTIVE = 'active'
+ RELOADING = 'reloading'
+ INACTIVE = 'inactive'
+ FAILED = 'failed'
+ ACTIVATING = 'activating'
+ DEACTIVATING = 'deactivating'
+
+ dbus2state = {
+ b'active': ACTIVE,
+ b'reloading': RELOADING,
+ b'inactive': INACTIVE,
+ b'failed': FAILED,
+ b'activating': ACTIVATING,
+ b'deactivating': DEACTIVATING,
+ }
+
+ state2status = {
+ ACTIVE: u'\u2714',
+ RELOADING: u'\u27A4',
+ INACTIVE: u'\u2717',
+ FAILED: u'\u2717',
+ ACTIVATING: u'\u27A4',
+ DEACTIVATING: u'\u27A4',
+ }
+
+
+class Errors:
+ def __init__(self, log_level, error_count):
+ self.log_level = log_level
+ self.error_count = error_count
+
+ def __str__(self):
+ return '{}: {}'.format(self.log_level, self.error_count)
+
+
+class RestartFrequency:
+ def __init__(self, count, time_interval):
+ self.count = count
+ self.time_interval = time_interval
+
+ def __str__(self):
+ return 'Restarted {} times {}'.format(
+ self.count,
+ self.time_interval,
+ )
+
+
+class HealthStatus:
+ DOWN = 'Down'
+ UP = 'Up'
+ UNKNOWN = 'Unknown'
+
+
+class Version:
+ def __init__(self, version_code, last_update_time):
+ self.version_code = version_code
+ self.last_update_time = last_update_time
+
+ def __str__(self):
+ return '{}, last updated: {}'.format(
+ self.version_code,
+ self.last_update_time,
+ )
+
+
+class ServiceHealth:
+ def __init__(
+ self, service_name, active_state, sub_state,
+ time_running, errors,
+ ):
+ self.service_name = service_name
+ self.active_state = active_state
+ self.sub_state = sub_state
+ self.time_running = time_running
+ self.errors = errors
+
+ def __str__(self):
+ return '{} {:20} {:10} {:15} {:10} {:>10} {:>10}'.format(
+ ActiveState.state2status.get(self.active_state, '-'),
+ self.service_name,
+ self.active_state,
+ self.sub_state,
+ self.time_running,
+ self.errors.log_level,
+ self.errors.error_count,
+ )
+
+
+class HealthSummary:
+ def __init__(
+ self, version, platform,
+ services_health,
+ internet_health, dns_health,
+ unexpected_restarts,
+ ):
+ self.version = version
+ self.platform = platform
+ self.services_health = services_health
+ self.internet_health = internet_health
+ self.dns_health = dns_health
+ self.unexpected_restarts = unexpected_restarts
+
+ def __str__(self):
+ any_restarts = any([
+ restarts.count
+ for restarts in self.unexpected_restarts.values()
+ ])
+ return textwrap.dedent("""
+ Running on {}
+ Version: {}:
+ {:20} {:10} {:15} {:10} {:>10} {:>10}
+ {}
+
+ Internet health: {}
+ DNS health: {}
+
+ Restart summary:
+ {}
+ """).format(
+ self.platform, self.version,
+ 'Service', 'Status', 'SubState', 'Running for', 'Log level',
+ 'Errors since last restart',
+ '\n'.join([str(h) for h in self.services_health]),
+ self.internet_health, self.dns_health,
+ '\n'.join([
+ '{:20} {}'.format(name, restarts)
+ for name, restarts
+ in self.unexpected_restarts.items()
+ ])
+ if any_restarts
+ else "No restarts since the gateway started",
+ )
diff --git a/common/health/health_service.py b/common/health/health_service.py
new file mode 100644
index 0000000..4228330
--- /dev/null
+++ b/common/health/health_service.py
@@ -0,0 +1,229 @@
+#!/usr/bin/env python3
+
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import asyncio
+import os
+import subprocess
+from datetime import datetime
+
+import apt
+from dateutil import tz
+from common.health.entities import (
+ ActiveState,
+ Errors,
+ HealthStatus,
+ HealthSummary,
+ RestartFrequency,
+ ServiceHealth,
+ Version,
+)
+from common.service import MagmaService
+from common.service_registry import ServiceRegistry
+from configuration.mconfig_managers import load_service_mconfig_as_json
+from magmad.metrics import UNEXPECTED_SERVICE_RESTARTS
+from magmad.service_poller import ServicePoller
+from orc8r.protos import common_pb2, magmad_pb2
+from orc8r.protos.magmad_pb2_grpc import MagmadStub
+from orc8r.protos.mconfig import mconfigs_pb2
+from pystemd.systemd1 import Unit
+
+
+class GenericHealthChecker:
+
+ def ping(self, host, num_packets=4):
+ chan = ServiceRegistry.get_rpc_channel('magmad', ServiceRegistry.LOCAL)
+ client = MagmadStub(chan)
+
+ response = client.RunNetworkTests(
+ magmad_pb2.NetworkTestRequest(
+ pings=[
+ magmad_pb2.PingParams(
+ host_or_ip=host,
+ num_packets=num_packets,
+ ),
+ ],
+ ),
+ )
+ return response.pings
+
+ def ping_status(self, host):
+ pings = self.ping(host=host, num_packets=4)[0]
+ if pings.error:
+ return HealthStatus.DOWN
+ if pings.avg_response_ms:
+ return HealthStatus.UP
+ return HealthStatus.UNKNOWN
+
+ def get_error_summary(self, service_names):
+ """Get the list of services with the error count.
+
+ Args:
+ service_names: List of service names.
+
+ Returns:
+ A dictionary with service name as a key and the Errors object
+ as a value.
+
+ Raises:
+ PermissionError: User has no permission to execute the command
+ """
+ configs = {
+ service_name: load_service_mconfig_as_json(service_name)
+ for service_name in service_names
+ }
+ res = {
+ service_name: Errors(
+ log_level=configs[service_name].get('logLevel', 'INFO'),
+ error_count=0,
+ )
+ for service_name in service_names
+ }
+
+ syslog_path = '/var/log/syslog'
+ if not os.access(syslog_path, os.R_OK):
+ raise PermissionError(
+ 'syslog is not readable. '
+ 'Try `sudo chmod a+r {}`. '
+ 'Or execute the command with sudo '
+ 'permissions: `venvsudo`'.format(syslog_path),
+ )
+ with open(syslog_path, 'r', encoding='utf-8') as f:
+ for line in f:
+ for service_name in service_names:
+ if service_name not in line:
+ continue
+ # Reset the counter for restart/start
+ if 'Starting {}...'.format(service_name) in line:
+ res[service_name].error_count = 0
+ elif 'ERROR' in line:
+ res[service_name].error_count += 1
+ return res
+
+ def get_magma_services_summary(self):
+ """ Get health for all the running services """
+ services_health_summary = []
+
+ # DBus objects: https://www.freedesktop.org/wiki/Software/systemd/dbus/
+ chan = ServiceRegistry.get_rpc_channel('magmad', ServiceRegistry.LOCAL)
+ client = MagmadStub(chan)
+
+ configs = client.GetConfigs(common_pb2.Void())
+
+ service_names = [str(name) for name in configs.configs_by_key]
+ services_errors = self.get_error_summary(service_names=service_names)
+
+ for service_name in service_names:
+ unit = Unit(
+ 'magma@{}.service'.format(service_name),
+ _autoload=True,
+ )
+ active_state = ActiveState.dbus2state[unit.Unit.ActiveState]
+ sub_state = str(unit.Unit.SubState, 'utf-8')
+ if active_state == ActiveState.ACTIVE:
+ pid = unit.Service.MainPID
+ process = subprocess.Popen(
+ 'ps -o etime= -p {}'.format(pid).split(),
+ stdout=subprocess.PIPE,
+ )
+
+ time_running, error = process.communicate()
+ if error:
+ raise ValueError(
+ 'Cannot get time running for the service '
+ '{} `ps -o etime= -p {}`'
+ .format(service_name, pid),
+ )
+ else:
+ time_running = b'00'
+
+ services_health_summary.append(
+ ServiceHealth(
+ service_name=service_name,
+ active_state=active_state, sub_state=sub_state,
+ time_running=str(time_running, 'utf-8').strip(),
+ errors=services_errors[service_name],
+ ),
+ )
+ return services_health_summary
+
+ def get_unexpected_restart_summary(self):
+ service = MagmaService('magmad', mconfigs_pb2.MagmaD())
+ service_poller = ServicePoller(service.loop, service.config)
+ service_poller.start()
+
+ asyncio.set_event_loop(service.loop)
+
+ # noinspection PyProtectedMember
+ # pylint: disable=protected-access
+ async def fetch_info():
+ restart_frequencies = {}
+ await service_poller._get_service_info()
+ for service_name in service_poller.service_info.keys():
+ restarts = int(
+ UNEXPECTED_SERVICE_RESTARTS
+ .labels(service_name=service_name)
+ ._value.get(),
+ )
+ restart_frequencies[service_name] = RestartFrequency(
+ count=restarts,
+ time_interval='',
+ )
+
+ return restart_frequencies
+
+ return service.loop.run_until_complete(fetch_info())
+
+ def get_kernel_version(self):
+ info, error = subprocess.Popen(
+ 'uname -a'.split(),
+ stdout=subprocess.PIPE,
+ ).communicate()
+
+ if error:
+ raise ValueError('Cannot get the kernel version')
+ return str(info, 'utf-8')
+
+ def get_magma_version(self):
+ cache = apt.Cache()
+
+ # Return the python version if magma is not there
+ if 'magma' not in cache:
+ return Version(
+ version_code=cache['python3'].versions[0],
+ last_update_time='-',
+ )
+
+ pkg = str(cache['magma'].versions[0])
+ version = pkg.split('-')[0].split('=')[-1]
+ timestamp = int(pkg.split('-')[1])
+
+ return Version(
+ version_code=version,
+ last_update_time=datetime.utcfromtimestamp(timestamp)
+ .replace(tzinfo=tz.tzutc())
+ .astimezone(tz=tz.tzlocal())
+ .strftime('%Y-%m-%d %H:%M:%S'),
+ )
+
+ def get_health_summary(self):
+
+ return HealthSummary(
+ version=self.get_magma_version(),
+ platform=self.get_kernel_version(),
+ services_health=self.get_magma_services_summary(),
+ internet_health=self.ping_status(host='8.8.8.8'),
+ dns_health=self.ping_status(host='google.com'),
+ unexpected_restarts=self.get_unexpected_restart_summary(),
+ )
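
A minimal sketch of pulling the full report, assuming a running magmad gateway with a readable /var/log/syslog:

    from common.health.health_service import GenericHealthChecker

    checker = GenericHealthChecker()
    # Aggregates version, per-service health, ping checks and restart counts
    print(checker.get_health_summary())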
diff --git a/common/health/service_state_wrapper.py b/common/health/service_state_wrapper.py
new file mode 100644
index 0000000..7c1f707
--- /dev/null
+++ b/common/health/service_state_wrapper.py
@@ -0,0 +1,88 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from typing import Dict
+
+from common.redis.client import get_default_client
+from common.redis.containers import RedisFlatDict
+from common.redis.serializers import (
+ RedisSerde,
+ get_proto_deserializer,
+ get_proto_serializer,
+)
+from orc8r.protos.service_status_pb2 import ServiceExitStatus
+
+
+class ServiceStateWrapper:
+ """
+ Class wraps ServiceState interactions with redis
+ """
+
+ # Unique typename for Redis key
+ REDIS_VALUE_TYPE = "systemd_status"
+
+ def __init__(self):
+ serde = RedisSerde(
+ self.REDIS_VALUE_TYPE,
+ get_proto_serializer(),
+ get_proto_deserializer(ServiceExitStatus),
+ )
+ self._flat_dict = RedisFlatDict(get_default_client(), serde)
+
+ def update_service_status(
+ self, service_name: str,
+ service_status: ServiceExitStatus,
+ ) -> None:
+ """
+ Update the service exit status for a given service
+ """
+
+ if service_name in self._flat_dict:
+ current_service_status = self._flat_dict[service_name]
+ else:
+ current_service_status = ServiceExitStatus()
+
+ if service_status.latest_service_result == \
+ ServiceExitStatus.ServiceResult.Value("SUCCESS"):
+ service_status.num_clean_exits = \
+ current_service_status.num_clean_exits + 1
+ service_status.num_fail_exits = \
+ current_service_status.num_fail_exits
+ else:
+ service_status.num_fail_exits = \
+ current_service_status.num_fail_exits + 1
+ service_status.num_clean_exits = \
+ current_service_status.num_clean_exits
+ self._flat_dict[service_name] = service_status
+
+ def get_service_status(self, service_name: str) -> ServiceExitStatus:
+ """
+ Get the service status protobuf for a given service
+ @returns ServiceStatus protobuf object
+ """
+ return self._flat_dict[service_name]
+
+ def get_all_services_status(self) -> Dict[str, ServiceExitStatus]:
+ """
+ Get a dict of service name to service status
+ @return dict of service_name to service map
+ """
+ service_status = {}
+ for k, v in self._flat_dict.items():
+ service_status[k] = v
+ return service_status
+
+ def cleanup_service_status(self) -> None:
+ """
+ Clean up service status for all services in redis; mostly used for
+ testing
+ """
+ self._flat_dict.clear()
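
A minimal usage sketch, assuming a local redis instance configured in redis.yml; the service name is hypothetical:

    from common.health.service_state_wrapper import ServiceStateWrapper
    from orc8r.protos.service_status_pb2 import ServiceExitStatus

    wrapper = ServiceStateWrapper()
    status = ServiceExitStatus()
    status.latest_service_result = \
        ServiceExitStatus.ServiceResult.Value("SUCCESS")
    # Increments num_clean_exits for 'enodebd' and stores the result in redis
    wrapper.update_service_status('enodebd', status)
    print(wrapper.get_all_services_status())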
diff --git a/common/job.py b/common/job.py
new file mode 100644
index 0000000..f3ba0d2
--- /dev/null
+++ b/common/job.py
@@ -0,0 +1,129 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import abc
+import asyncio
+import logging
+import time
+from contextlib import suppress
+from typing import Optional, cast
+
+
+class Job(abc.ABC):
+ """
+ This is a base class that provides functions for a specific task to
+ ensure regular completion of the loop.
+
+ A coroutine _run() must be implemented by a subclass.
+ _periodic() will call the coroutine at a regular interval set by
+ self._interval.
+ """
+
+ def __init__(
+ self,
+ interval: int,
+ loop: Optional[asyncio.AbstractEventLoop] = None,
+ ) -> None:
+ if loop is None:
+ self._loop = asyncio.get_event_loop()
+ else:
+ self._loop = loop
+ # Task in charge of periodically running the task
+ self._periodic_task = cast(Optional[asyncio.Task], None)
+ # Task in charge of deciding how long to wait until next run
+ self._interval_wait_task = cast(Optional[asyncio.Task], None)
+ self._interval = interval # in seconds
+ self._last_run = cast(Optional[float], None)
+ self._timeout = cast(Optional[float], None)
+ # Condition variable used to control how long the job waits until
+ # executing its task again.
+ self._cond = asyncio.Condition(loop=self._loop)
+
+ @abc.abstractmethod
+ async def _run(self):
+ """
+ Once implemented by a subclass, this function will contain the actual
+ work of this Job.
+ """
+ pass
+
+ def start(self) -> None:
+ """
+ kicks off the _periodic while loop
+ """
+ if self._periodic_task is None:
+ self._periodic_task = self._loop.create_task(self._periodic())
+
+ def stop(self) -> None:
+ """
+ cancels the _periodic while loop
+ """
+ if self._periodic_task is not None:
+ self._periodic_task.cancel()
+ with suppress(asyncio.CancelledError):
+ # Await task to execute its cancellation
+ self._loop.run_until_complete(self._periodic_task)
+ self._periodic_task = None
+
+ def set_timeout(self, timeout: float) -> None:
+ self._timeout = timeout
+
+ def set_interval(self, interval: int) -> None:
+ """
+ sets the interval used in _periodic to decide how long to sleep
+ """
+ self._interval = interval
+
+ def heartbeat(self) -> None:
+ # record time to keep track of iteration length
+ self._last_run = time.time()
+
+ def not_completed(self, current_time: float) -> bool:
+ last_time = self._last_run
+
+ if last_time is None:
+ return True
+ if last_time < current_time - (self._timeout or 120):
+ return True
+ return False
+
+ async def _sleep_for_interval(self):
+ await asyncio.sleep(self._interval)
+ async with self._cond:
+ self._cond.notify()
+
+ async def wake_up(self):
+ """
+ Cancels the _sleep_for_interval task if it exists, and notifies the
+ cond var so that the _periodic loop can continue.
+ """
+ if self._interval_wait_task is not None:
+ self._interval_wait_task.cancel()
+
+ async with self._cond:
+ self._cond.notify()
+
+ async def _periodic(self) -> None:
+ while True:
+ self.heartbeat()
+
+ try:
+ await self._run()
+ except Exception as exp: # pylint: disable=broad-except
+ logging.exception("Exception from _run: %s", exp)
+
+ # Wait for self._interval seconds or wake_up is explicitly called
+ self._interval_wait_task = \
+ self._loop.create_task(self._sleep_for_interval())
+ async with self._cond:
+ await self._cond.wait()
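
A minimal sketch of a Job subclass, assuming a hypothetical task that just logs a heartbeat every 30 seconds:

    import asyncio
    import logging

    from common.job import Job

    class HeartbeatJob(Job):
        def __init__(self, loop=None):
            super().__init__(interval=30, loop=loop)

        async def _run(self):
            # The actual work of the job goes here
            logging.info('heartbeat')

    loop = asyncio.get_event_loop()
    job = HeartbeatJob(loop=loop)
    job.start()           # schedules the _periodic() loop on the event loop
    # loop.run_forever()  # in a real service the event loop runs elsewhere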
diff --git a/common/log_count_handler.py b/common/log_count_handler.py
new file mode 100644
index 0000000..55b5ecf
--- /dev/null
+++ b/common/log_count_handler.py
@@ -0,0 +1,35 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import logging
+
+
+class MsgCounterHandler(logging.Handler):
+ """ Register this handler to logging to count the logs by level """
+
+ count_by_level = None
+
+ def __init__(self, *args, **kwargs):
+ super(MsgCounterHandler, self).__init__(*args, **kwargs)
+ self.count_by_level = {}
+
+ def emit(self, record: logging.LogRecord):
+ level = record.levelname
+ if (level not in self.count_by_level):
+ self.count_by_level[level] = 0
+ self.count_by_level[level] += 1
+
+ def pop_error_count(self) -> int:
+ error_count = self.count_by_level.get('ERROR', 0)
+ self.count_by_level['ERROR'] = 0
+ return error_count
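
A minimal sketch of the counting handler attached to the root logger:

    import logging

    from common.log_count_handler import MsgCounterHandler

    handler = MsgCounterHandler()
    logging.getLogger().addHandler(handler)

    logging.error('something went wrong')
    print(handler.pop_error_count())  # -> 1, and the ERROR count resets to 0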
diff --git a/common/log_counter.py b/common/log_counter.py
new file mode 100644
index 0000000..4f2aea3
--- /dev/null
+++ b/common/log_counter.py
@@ -0,0 +1,40 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import asyncio
+from typing import Any
+
+from common.job import Job
+from common.log_count_handler import MsgCounterHandler
+from common.metrics import SERVICE_ERRORS
+
+# How frequently to report the count of logged errors, in seconds
+POLL_INTERVAL = 10
+
+
+class ServiceLogErrorReporter(Job):
+ """ Reports the number of logged errors for the service """
+
+ def __init__(
+ self,
+ loop: asyncio.BaseEventLoop,
+ service_config: Any,
+ handler: MsgCounterHandler,
+ ) -> None:
+ super().__init__(interval=POLL_INTERVAL, loop=loop)
+ self._service_config = service_config
+ self._handler = handler
+
+ async def _run(self):
+ error_count = self._handler.pop_error_count()
+ SERVICE_ERRORS.inc(error_count)
diff --git a/common/metrics.py b/common/metrics.py
new file mode 100644
index 0000000..ebb3bae
--- /dev/null
+++ b/common/metrics.py
@@ -0,0 +1,24 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+from prometheus_client import Counter
+
+STREAMER_RESPONSES = Counter(
+ 'streamer_responses',
+ 'The number of responses by label',
+ ['result'],
+)
+
+SERVICE_ERRORS = Counter(
+ 'service_errors',
+ 'The number of errors logged',
+)
diff --git a/common/metrics_export.py b/common/metrics_export.py
new file mode 100644
index 0000000..964e3fb
--- /dev/null
+++ b/common/metrics_export.py
@@ -0,0 +1,212 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import logging
+import time
+
+import metrics_pb2
+from orc8r.protos import metricsd_pb2
+from prometheus_client import REGISTRY
+
+
+def get_metrics(registry=REGISTRY, verbose=False):
+ """
+ Collects timeseries samples from the prometheus metric collector registry,
+ adds a common timestamp, and encodes them to protobuf.
+
+ Arguments:
+ registry: a prometheus CollectorRegistry instance
+ verbose: when True, include the metric name and help text instead of
+ the bandwidth-saving enum encoding of the name
+
+ Returns:
+ a prometheus MetricFamily protobuf stream
+ """
+ timestamp_ms = int(time.time() * 1000)
+ for metric_family in registry.collect():
+ if metric_family.type in ('counter', 'gauge'):
+ family_proto = encode_counter_gauge(metric_family, timestamp_ms)
+ elif metric_family.type == 'summary':
+ family_proto = encode_summary(metric_family, timestamp_ms)
+ elif metric_family.type == 'histogram':
+ family_proto = encode_histogram(metric_family, timestamp_ms)
+ else:
+ continue # Skip unsupported metric types (e.g. 'untyped')
+
+ if verbose:
+ family_proto.help = metric_family.documentation
+ family_proto.name = metric_family.name
+ else:
+ try:
+ family_proto.name = \
+ str(metricsd_pb2.MetricName.Value(metric_family.name))
+ except ValueError as e:
+ logging.debug(e) # If enum is not defined
+ family_proto.name = metric_family.name
+ yield family_proto
+
+
+def encode_counter_gauge(family, timestamp_ms):
+ """
+ Takes a Counter/Gauge family which is a collection of timeseries
+ samples that share a name (uniquely identified by labels) and yields
+ equivalent protobufs.
+
+ Each timeseries corresponds to a single sample tuple of the format:
+ (NAME, LABELS, VALUE)
+
+ Arguments:
+ family: a prometheus gauge metric family
+ timestamp_ms: the timestamp to attach to the samples
+ Raises:
+ ValueError if metric name is not defined in MetricNames protobuf
+ Returns:
+ A Counter or Gauge prometheus MetricFamily protobuf
+ """
+ family_proto = metrics_pb2.MetricFamily()
+ family_proto.type = \
+ metrics_pb2.MetricType.Value(family.type.upper())
+ for sample in family.samples:
+ metric_proto = metrics_pb2.Metric()
+ if family_proto.type == metrics_pb2.COUNTER:
+ metric_proto.counter.value = sample[2]
+ elif family_proto.type == metrics_pb2.GAUGE:
+ metric_proto.gauge.value = sample[2]
+ # Add meta-data to the timeseries
+ metric_proto.timestamp_ms = timestamp_ms
+ metric_proto.label.extend(_convert_labels_to_enums(sample[1].items()))
+ # Append metric sample to family
+ family_proto.metric.extend([metric_proto])
+ return family_proto
+
+
+def encode_summary(family, timestamp_ms):
+ """
+ Takes a Summary Metric family which is a collection of timeseries
+ samples that share a name (uniquely identified by labels) and yields
+ equivalent protobufs.
+
+ Each summary timeseries consists of sample tuples for the count, sum,
+ and quantiles in the format (NAME,LABELS,VALUE). The NAME is suffixed
+ with either _count or _sum to indicate count and sum respectively.
+ Quantile samples use the same NAME with a 'quantile' label.
+
+ Arguments:
+ family: a prometheus summary metric family
+ timestamp_ms: the timestamp to attach to the samples
+ Raises:
+ ValueError if metric name is not defined in MetricNames protobuf
+ Returns:
+ a Summary prometheus MetricFamily protobuf
+ """
+ family_proto = metrics_pb2.MetricFamily()
+ family_proto.type = metrics_pb2.SUMMARY
+ metric_protos = {}
+ # Build a map of each of the summary timeseries from the samples
+ for sample in family.samples:
+ quantile = sample[1].pop('quantile', None) # Remove from label set
+ # Each time series identified by label set excluding the quantile
+ metric_proto = \
+ metric_protos.setdefault(
+ frozenset(sample[1].items()),
+ metrics_pb2.Metric(),
+ )
+ if sample[0].endswith('_count'):
+ metric_proto.summary.sample_count = int(sample[2])
+ elif sample[0].endswith('_sum'):
+ metric_proto.summary.sample_sum = sample[2]
+ elif quantile:
+ quantile_proto = metric_proto.summary.quantile.add()
+ quantile_proto.value = sample[2]
+ quantile_proto.quantile = _goStringToFloat(quantile)
+ # Go back and add meta-data to the timeseries
+ for labels, metric_proto in metric_protos.items():
+ metric_proto.timestamp_ms = timestamp_ms
+ metric_proto.label.extend(_convert_labels_to_enums(labels))
+ # Add it to the family
+ family_proto.metric.extend([metric_proto])
+ return family_proto
+
+
+def encode_histogram(family, timestamp_ms):
+ """
+ Takes a Histogram Metric family which is a collection of timeseries
+ samples that share a name (uniquely identified by labels) and yields
+ equivalent protobufs.
+
+ Each histogram timeseries consists of sample tuples for the count, sum,
+ and buckets in the format (NAME,LABELS,VALUE). The NAME is suffixed
+ with _count, _sum, or _bucket to indicate count, sum and buckets
+ respectively. Bucket samples also carry an 'le' label to indicate the
+ bucket's upper bound.
+
+ Arguments:
+ family: a prometheus histogram metric family
+ timestamp_ms: the timestamp to attach to the samples
+ Raises:
+ ValueError if metric name is not defined in MetricNames protobuf
+ Returns:
+ a Histogram prometheus MetricFamily protobuf
+ """
+ family_proto = metrics_pb2.MetricFamily()
+ family_proto.type = metrics_pb2.HISTOGRAM
+ metric_protos = {}
+ for sample in family.samples:
+ upper_bound = sample[1].pop('le', None) # Remove from label set
+ metric_proto = \
+ metric_protos.setdefault(
+ frozenset(sample[1].items()),
+ metrics_pb2.Metric(),
+ )
+ if sample[0].endswith('_count'):
+ metric_proto.histogram.sample_count = int(sample[2])
+ elif sample[0].endswith('_sum'):
+ metric_proto.histogram.sample_sum = sample[2]
+ elif sample[0].endswith('_bucket'):
+ quantile = metric_proto.histogram.bucket.add()
+ quantile.cumulative_count = int(sample[2])
+ quantile.upper_bound = _goStringToFloat(upper_bound)
+ # Go back and add meta-data to the timeseries
+ for labels, metric_proto in metric_protos.items():
+ metric_proto.timestamp_ms = timestamp_ms
+ metric_proto.label.extend(_convert_labels_to_enums(labels))
+ # Add it to the family
+ family_proto.metric.extend([metric_proto])
+ return family_proto
+
+
+def _goStringToFloat(s):
+ if s == '+Inf':
+ return float("inf")
+ elif s == '-Inf':
+ return float("-inf")
+ elif s == 'NaN':
+ return float('nan')
+ else:
+ return float(s)
+
+
+def _convert_labels_to_enums(labels):
+ """
+ Try to convert the label names to enum values.
+ Defaults to the given name if the conversion fails.
+ Arguments:
+ labels: an array of label pairs that may contain enum names
+ Returns:
+ an array of label pairs with enum names converted to enum values
+ """
+ new_labels = []
+ for name, value in labels:
+ try:
+ name = str(metricsd_pb2.MetricLabelName.Value(name))
+ except ValueError as e:
+ logging.debug(e)
+ new_labels.append(metrics_pb2.LabelPair(name=name, value=value))
+ return new_labels
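
A minimal sketch of the export path, assuming the default prometheus registry and the counters defined in common/metrics.py:

    from prometheus_client import REGISTRY

    from common.metrics import SERVICE_ERRORS
    from common.metrics_export import get_metrics

    SERVICE_ERRORS.inc()
    # Each yielded item is a MetricFamily protobuf with a shared timestamp
    for family in get_metrics(REGISTRY, verbose=True):
        print(family.name, len(family.metric))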
diff --git a/common/misc_utils.py b/common/misc_utils.py
new file mode 100644
index 0000000..4d8e5ec
--- /dev/null
+++ b/common/misc_utils.py
@@ -0,0 +1,269 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import asyncio
+import ipaddress
+import os
+from enum import Enum
+
+import netifaces
+import snowflake
+
+
+class IpPreference(Enum):
+ IPV4_ONLY = 1
+ IPV4_PREFERRED = 2
+ IPV6_PREFERRED = 3
+ IPV6_ONLY = 4
+
+
+def get_if_ip_with_netmask(interface, preference=IpPreference.IPV4_PREFERRED):
+ """
+ Get IP address and netmask (in form /255.255.255.0)
+ from interface name and return as tuple (ip, netmask).
+ Note: If multiple v4/v6 addresses exist, the first is chosen
+
+ Raise ValueError if unable to get requested IP address.
+ """
+ # Raises ValueError if interface is unavailable
+ ip_addresses = netifaces.ifaddresses(interface)
+
+ try:
+ ipv4_address = (
+ ip_addresses[netifaces.AF_INET][0]['addr'],
+ ip_addresses[netifaces.AF_INET][0]['netmask'],
+ )
+ except KeyError:
+ ipv4_address = None
+
+ try:
+ ipv6_address = (
+ ip_addresses[netifaces.AF_INET6][0]["addr"].split("%")[0],
+ ip_addresses[netifaces.AF_INET6][0]["netmask"],
+ )
+
+ except KeyError:
+ ipv6_address = None
+
+ if preference == IpPreference.IPV4_ONLY:
+ if ipv4_address is not None:
+ return ipv4_address
+ else:
+ raise ValueError('Error getting IPv4 address for %s' % interface)
+
+ elif preference == IpPreference.IPV4_PREFERRED:
+ if ipv4_address is not None:
+ return ipv4_address
+ elif ipv6_address is not None:
+ return ipv6_address
+ else:
+ raise ValueError('Error getting IPv4/6 address for %s' % interface)
+
+ elif preference == IpPreference.IPV6_PREFERRED:
+ if ipv6_address is not None:
+ return ipv6_address
+ elif ipv4_address is not None:
+ return ipv4_address
+ else:
+ raise ValueError('Error getting IPv6/4 address for %s' % interface)
+
+ elif preference == IpPreference.IPV6_ONLY:
+ if ipv6_address is not None:
+ return ipv6_address
+ else:
+ raise ValueError('Error getting IPv6 address for %s' % interface)
+
+ else:
+ raise ValueError('Unknown IP preference %s' % preference)
+
+
+def get_all_if_ips_with_netmask(
+ interface,
+ preference=IpPreference.IPV4_PREFERRED,
+):
+ """
+ Get all IP addresses and netmasks (in form /255.255.255.0)
+ from interface name and return as a list of tuple (ip, netmask).
+
+ Raise ValueError if unable to get requested IP addresses.
+ """
+ # Raises ValueError if interface is unavailable
+ ip_addresses = netifaces.ifaddresses(interface)
+
+ try:
+ ipv4_addresses = [(ip_address['addr'], ip_address['netmask']) for
+ ip_address in ip_addresses[netifaces.AF_INET]]
+ except KeyError:
+ ipv4_addresses = None
+
+ try:
+ ipv6_addresses = [(ip_address['addr'], ip_address['netmask']) for
+ ip_address in ip_addresses[netifaces.AF_INET6]]
+ except KeyError:
+ ipv6_addresses = None
+
+ if preference == IpPreference.IPV4_ONLY:
+ if ipv4_addresses is not None:
+ return ipv4_addresses
+ else:
+ raise ValueError('Error getting IPv4 addresses for %s' % interface)
+
+ elif preference == IpPreference.IPV4_PREFERRED:
+ if ipv4_addresses is not None:
+ return ipv4_addresses
+ elif ipv6_addresses is not None:
+ return ipv6_addresses
+ else:
+ raise ValueError(
+ 'Error getting IPv4/6 addresses for %s' % interface,
+ )
+
+ elif preference == IpPreference.IPV6_PREFERRED:
+ if ipv6_addresses is not None:
+ return ipv6_addresses
+ elif ipv4_addresses is not None:
+ return ipv4_addresses
+ else:
+ raise ValueError(
+ 'Error getting IPv6/4 addresses for %s' % interface,
+ )
+
+ elif preference == IpPreference.IPV6_ONLY:
+ if ipv6_addresses is not None:
+ return ipv6_addresses
+ else:
+ raise ValueError('Error getting IPv6 addresses for %s' % interface)
+
+ else:
+ raise ValueError('Unknown IP preference %s' % preference)
+
+
+def get_ip_from_if(iface_name, preference=IpPreference.IPV4_PREFERRED):
+ """
+ Get ip address from interface name and return as string.
+ Extract only ip address from (ip, netmask)
+ """
+ return get_if_ip_with_netmask(iface_name, preference)[0]
+
+
+def get_all_ips_from_if(iface_name, preference=IpPreference.IPV4_PREFERRED):
+ """
+ Get all ip addresses from interface name and return as a list of string.
+ Extract only ip address from (ip, netmask)
+ """
+ return [
+ ip[0] for ip in
+ get_all_if_ips_with_netmask(iface_name, preference)
+ ]
+
+
+def get_ip_from_if_cidr(iface_name, preference=IpPreference.IPV4_PREFERRED):
+ """
+ Get IP address with netmask from interface name, transform it into
+ CIDR notation (eth1 -> 192.168.60.142/24), and return it as a string.
+ """
+ ip, netmask = get_if_ip_with_netmask(iface_name, preference)
+ ip = '%s/%s' % (ip, netmask)
+ interface = ipaddress.ip_interface(ip).with_prefixlen # Set CIDR notation
+ return interface
+
+
+def get_all_ips_from_if_cidr(
+ iface_name,
+ preference=IpPreference.IPV4_PREFERRED,
+):
+ """
+ Get all IP addresses with netmask from interface name, transform them
+ into CIDR notation (eth1 -> 192.168.60.142/24), and return them as a
+ list of strings.
+ """
+
+ def ip_cidr_gen():
+ for ip, netmask in get_all_if_ips_with_netmask(iface_name, preference):
+ ip = '%s/%s' % (ip, netmask)
+ # Set CIDR notation
+ ip_cidr = ipaddress.ip_interface(ip).with_prefixlen
+ yield ip_cidr
+
+ return [ip_cidr for ip_cidr in ip_cidr_gen()]
+
+
+def cidr_to_ip_netmask_tuple(cidr_network):
+ """
+ Convert CIDR-format IP network string (e.g. 10.0.0.1/24) to a tuple
+ (ip, netmask) where netmask is in the form (n.n.n.n).
+
+ Args:
+ cidr_network (str): IPv4 network in CIDR notation
+
+ Returns:
+ (str, str): 2-tuple of IP address and netmask
+ """
+ network = ipaddress.ip_network(cidr_network)
+ return '{}'.format(network.network_address), '{}'.format(network.netmask)
+
+
+def get_if_mac_address(interface):
+ """
+ Returns the MAC address of an interface.
+ Note: If multiple MAC addresses exist, the first one is chosen.
+
+ Raise ValueError if unable to get the requested MAC address.
+ """
+ addr = netifaces.ifaddresses(interface)
+ try:
+ return addr[netifaces.AF_LINK][0]['addr']
+ except KeyError:
+ raise ValueError('Error getting MAC address for %s' % interface)
+
+
+def get_gateway_hwid() -> str:
+ """
+ Returns the HWID of the gateway
+ Note: Currently this uses the snowflake at /etc/snowflake
+ """
+ return snowflake.snowflake()
+
+
+def is_interface_up(interface):
+ """
+ Returns whether an interface is up.
+ """
+ try:
+ addr = netifaces.ifaddresses(interface)
+ except ValueError:
+ return False
+ return netifaces.AF_INET in addr
+
+
+def call_process(cmd, callback, loop):
+ loop = loop or asyncio.get_event_loop()
+ loop.create_task(
+ loop.subprocess_shell(
+ lambda: SubprocessProtocol(callback), "nohup " + cmd,
+ preexec_fn=os.setsid,
+ ),
+ )
+
+
+class SubprocessProtocol(asyncio.SubprocessProtocol):
+ def __init__(self, callback):
+ self._callback = callback
+ self._transport = None
+
+ def connection_made(self, transport):
+ self._transport = transport
+
+ def process_exited(self):
+ self._callback(self._transport.get_returncode())
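
A minimal sketch of the interface helpers; the interface name eth0 is hypothetical:

    from common.misc_utils import (
        IpPreference,
        get_gateway_hwid,
        get_ip_from_if_cidr,
        is_interface_up,
    )

    if is_interface_up('eth0'):
        # e.g. '192.168.60.142/24'
        print(get_ip_from_if_cidr('eth0', IpPreference.IPV4_PREFERRED))
    print(get_gateway_hwid())  # snowflake-based hardware id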
diff --git a/common/redis/__init__.py b/common/redis/__init__.py
new file mode 100644
index 0000000..5c6cb64
--- /dev/null
+++ b/common/redis/__init__.py
@@ -0,0 +1,12 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
diff --git a/common/redis/client.py b/common/redis/client.py
new file mode 100644
index 0000000..65acd87
--- /dev/null
+++ b/common/redis/client.py
@@ -0,0 +1,23 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import redis
+from configuration.service_configs import get_service_config_value
+
+
+def get_default_client():
+ """
+ Return a default redis client using the configured port in redis.yml
+ """
+ redis_port = get_service_config_value('redis', 'port', 6379)
+ redis_addr = get_service_config_value('redis', 'bind', 'localhost')
+ return redis.Redis(host=redis_addr, port=redis_port)
diff --git a/common/redis/containers.py b/common/redis/containers.py
new file mode 100644
index 0000000..c227e4d
--- /dev/null
+++ b/common/redis/containers.py
@@ -0,0 +1,444 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+from copy import deepcopy
+from typing import Any, Iterator, List, MutableMapping, Optional, TypeVar
+
+import redis
+import redis_collections
+import redis_lock
+from common.redis.serializers import RedisSerde
+from orc8r.protos.redis_pb2 import RedisState
+from redis.lock import Lock
+
+# NOTE: these containers replace the serialization methods exposed by
+# the redis-collection objects. Although the methods are hinted to be
+# privately scoped, the method replacement is encouraged in the library's
+# docs: http://redis-collections.readthedocs.io/en/stable/usage-notes.html
+
+T = TypeVar('T')
+
+
+class RedisList(redis_collections.List):
+ """
+ List-like interface serializing elements to a Redis datastore.
+
+ Notes:
+ - Provides persistence across sessions
+ - Mutable elements handled correctly
+ - Not expected to be thread safe, but could be extended
+ """
+
+ def __init__(self, client, key, serialize, deserialize):
+ """
+ Initialize instance.
+
+ Args:
+ client (redis.Redis): Redis client object
+ key (str): key where this container's elements are stored in Redis
+ serialize (function (any) -> bytes):
+ function called to serialize an element
+ deserialize (function (bytes) -> any):
+ function called to deserialize an element
+ Returns:
+ redis_list (redis_collections.List): persistent list-like interface
+ """
+ self._pickle = serialize
+ self._unpickle = deserialize
+ super().__init__(redis=client, key=key, writeback=True)
+
+ def __copy__(self):
+ return [elt for elt in self]
+
+ def __deepcopy__(self, memo):
+ return [deepcopy(elt, memo) for elt in self]
+
+
+class RedisSet(redis_collections.Set):
+ """
+ Set-like interface serializing elements to a Redis datastore.
+
+ Notes:
+ - Provides persistence across sessions
+ - Mutable elements _not_ handled correctly:
+ - Get/set mutable elements supported
+ - Don't update the contents of a mutable element and
+ expect things to go well
+ - Expected to be thread safe, but not tested
+ """
+
+ def __init__(self, client, key, serialize, deserialize):
+ """
+ Initialize instance.
+
+ Args:
+ client (redis.Redis): Redis client object
+ key (str): key where this container's elements are stored in Redis
+ serialize (function (any) -> bytes):
+ function called to serialize an element
+ deserialize (function (bytes) -> any):
+ function called to deserialize an element
+ Returns:
+ redis_set (redis_collections.Set): persistent set-like interface
+ """
+ # NOTE: redis_collections.Set doesn't have a writeback option, causing
+ # issue when mutable elements are updated in-place.
+ self._pickle = serialize
+ self._unpickle = deserialize
+ super().__init__(redis=client, key=key)
+
+ def __copy__(self):
+ return {elt for elt in self}
+
+ def __deepcopy__(self, memo):
+ return {deepcopy(elt, memo) for elt in self}
+
+
+class RedisHashDict(redis_collections.DefaultDict):
+ """
+ Dict-like interface serializing elements to a Redis datastore. This dict
+ utilizes Redis's hashmap functionality
+
+ Notes:
+ - Keys must be string-like and are serialized to plaintext (UTF-8)
+ - Provides persistence across sessions
+ - Mutable elements handled correctly
+ - Not expected to be thread safe, but could be extended
+ - Keys are serialized in plaintext
+ """
+
+ @staticmethod
+ def serialize_key(key):
+ """ Serialize key to plaintext. """
+ return key
+
+ @staticmethod
+ def deserialize_key(serialized):
+ """ Deserialize key from plaintext encoded as UTF-8 bytes. """
+ return serialized.decode('utf-8') # Redis returns bytes
+
+ def __init__(
+ self, client, key, serialize, deserialize,
+ default_factory=None, writeback=False,
+ ):
+ """
+ Initialize instance.
+
+ Args:
+ client (redis.Redis): Redis client object
+ key (str): key where this container's elements are stored in Redis
+ serialize (function (any) -> bytes):
+ function called to serialize a value
+ deserialize (function (bytes) -> any):
+ function called to deserialize a value
+ default_factory: function that provides default value for a
+ non-existent key
+ writeback (bool): if writeback is set to true, dict maintains a
+ local cache of values and the `sync` method can be called to
+ store these values. NOTE: only use this option if syncing
+ between services is not important.
+
+ Returns:
+ redis_dict (redis_collections.Dict): persistent dict-like interface
+ """
+ # Key serialization (to/from plaintext)
+ self._pickle_key = RedisHashDict.serialize_key
+ self._unpickle_key = RedisHashDict.deserialize_key
+ # Value serialization
+ self._pickle_value = serialize
+ self._unpickle = deserialize
+ super().__init__(
+ default_factory, redis=client, key=key, writeback=writeback,
+ )
+
+ def __setitem__(self, key, value):
+ """Set ``d[key]`` to *value*.
+
+ Override in order to increment version on each update
+ """
+ version = self.get_version(key)
+ pickled_key = self._pickle_key(key)
+ pickled_value = self._pickle_value(value, version + 1)
+ self.redis.hset(self.key, pickled_key, pickled_value)
+
+ if self.writeback:
+ self.cache[key] = value
+
+ def __copy__(self):
+ return {key: self[key] for key in self}
+
+ def __deepcopy__(self, memo):
+ return {key: deepcopy(self[key], memo) for key in self}
+
+ def get_version(self, key):
+ """Return the version of the value for key *key*. Returns 0 if
+ key is not in the map
+ """
+ try:
+ value = self.cache[key]
+ except KeyError:
+ pickled_key = self._pickle_key(key)
+ value = self.redis.hget(self.key, pickled_key)
+ if value is None:
+ return 0
+
+ proto_wrapper = RedisState()
+ proto_wrapper.ParseFromString(value)
+ return proto_wrapper.version
+
+
+class RedisFlatDict(MutableMapping[str, T]):
+ """
+ Dict-like interface serializing elements to a Redis datastore. This
+    dict stores keys directly (i.e. without a hashmap).
+ """
+
+ def __init__(
+ self, client: redis.Redis, serde: RedisSerde[T],
+ writethrough: bool = False,
+ ):
+ """
+ Args:
+ client (redis.Redis): Redis client object
+            serde (RedisSerde): serde used to de/serialize the stored objects
+ writethrough (bool): if writethrough is set to true,
+ RedisFlatDict maintains a local write-through cache of values.
+ """
+ super().__init__()
+ self._writethrough = writethrough
+ self.redis = client
+ self.serde = serde
+ self.redis_type = serde.redis_type
+ self.cache = {}
+ if self._writethrough:
+ self._sync_cache()
+
+ def __len__(self) -> int:
+ """Return the number of items in the dictionary."""
+ if self._writethrough:
+ return len(self.cache)
+
+ return len(self.keys())
+
+ def __iter__(self) -> Iterator[str]:
+ """Return an iterator over the keys of the dictionary."""
+ type_pattern = self._get_redis_type_pattern()
+
+ if self._writethrough:
+ for k in self.cache:
+ split_key, _ = k.split(":", 1)
+ yield split_key
+ else:
+ for k in self.redis.keys(pattern=type_pattern):
+ try:
+ deserialized_key = k.decode('utf-8')
+ split_key = deserialized_key.split(":", 1)
+ except AttributeError:
+ split_key = k.split(":", 1)
+ # There could be a delete key in between KEYS and GET, so ignore
+ # invalid values for now
+ try:
+ if self.is_garbage(split_key[0]):
+ continue
+ except KeyError:
+ continue
+ yield split_key[0]
+
+ def __contains__(self, key: str) -> bool:
+ """Return ``True`` if *key* is present and not garbage,
+ else ``False``.
+ """
+ composite_key = self._make_composite_key(key)
+
+ if self._writethrough:
+ return composite_key in self.cache
+
+ return bool(self.redis.exists(composite_key)) and \
+ not self.is_garbage(key)
+
+ def __getitem__(self, key: str) -> T:
+ """Return the item of dictionary with key *key:type*. Raises a
+ :exc:`KeyError` if *key:type* is not in the map or the object is
+ garbage
+ """
+ if ':' in key:
+ raise ValueError("Key %s cannot contain ':' char" % key)
+ composite_key = self._make_composite_key(key)
+
+ if self._writethrough:
+ cached_value = self.cache.get(composite_key)
+ if cached_value:
+ return cached_value
+
+ serialized_value = self.redis.get(composite_key)
+ if serialized_value is None:
+ raise KeyError(composite_key)
+
+ proto_wrapper = RedisState()
+ proto_wrapper.ParseFromString(serialized_value)
+ if proto_wrapper.is_garbage:
+ raise KeyError("Key %s is garbage" % key)
+
+ return self.serde.deserialize(serialized_value)
+
+ def __setitem__(self, key: str, value: T) -> Any:
+ """Set ``d[key:type]`` to *value*."""
+ if ':' in key:
+ raise ValueError("Key %s cannot contain ':' char" % key)
+ version = self.get_version(key)
+ serialized_value = self.serde.serialize(value, version + 1)
+ composite_key = self._make_composite_key(key)
+ if self._writethrough:
+ self.cache[composite_key] = value
+ return self.redis.set(composite_key, serialized_value)
+
+ def __delitem__(self, key: str) -> int:
+ """Remove ``d[key:type]`` from dictionary.
+ Raises a :func:`KeyError` if *key:type* is not in the map.
+ """
+ if ':' in key:
+ raise ValueError("Key %s cannot contain ':' char" % key)
+ composite_key = self._make_composite_key(key)
+ if self._writethrough:
+ del self.cache[composite_key]
+ deleted_count = self.redis.delete(composite_key)
+ if not deleted_count:
+ raise KeyError(composite_key)
+ return deleted_count
+
+ def get(self, key: str, default=None) -> Optional[T]:
+ """Get ``d[key:type]`` from dictionary.
+        Returns *default* if *key:type* is not in the map
+ """
+ try:
+ return self.__getitem__(key)
+ except (KeyError, ValueError):
+ return default
+
+ def clear(self) -> None:
+ """
+ Clear all keys in the dictionary. Objects are immediately deleted
+ (i.e. not garbage collected)
+ """
+ if self._writethrough:
+ self.cache.clear()
+ for key in self.keys():
+ composite_key = self._make_composite_key(key)
+ self.redis.delete(composite_key)
+
+ def get_version(self, key: str) -> int:
+ """Return the version of the value for key *key:type*. Returns 0 if
+ key is not in the map
+ """
+ composite_key = self._make_composite_key(key)
+ value = self.redis.get(composite_key)
+ if value is None:
+ return 0
+
+ proto_wrapper = RedisState()
+ proto_wrapper.ParseFromString(value)
+ return proto_wrapper.version
+
+ def keys(self) -> List[str]:
+ """Return a copy of the dictionary's list of keys
+ Note: for redis *key:type* key is returned
+ """
+ if self._writethrough:
+ return list(self.cache.keys())
+
+ return list(self.__iter__())
+
+ def mark_as_garbage(self, key: str) -> Any:
+ """Mark ``d[key:type]`` for garbage collection
+ Raises a KeyError if *key:type* is not in the map.
+ """
+ composite_key = self._make_composite_key(key)
+ value = self.redis.get(composite_key)
+ if value is None:
+ raise KeyError(composite_key)
+
+ proto_wrapper = RedisState()
+ proto_wrapper.ParseFromString(value)
+ proto_wrapper.is_garbage = True
+ garbage_serialized = proto_wrapper.SerializeToString()
+ return self.redis.set(composite_key, garbage_serialized)
+
+ def is_garbage(self, key: str) -> bool:
+ """Return if d[key:type] has been marked for garbage collection.
+ Raises a KeyError if *key:type* is not in the map.
+ """
+ composite_key = self._make_composite_key(key)
+ value = self.redis.get(composite_key)
+ if value is None:
+ raise KeyError(composite_key)
+
+ proto_wrapper = RedisState()
+ proto_wrapper.ParseFromString(value)
+ return proto_wrapper.is_garbage
+
+ def garbage_keys(self) -> List[str]:
+ """Return a copy of the dictionary's list of keys that are garbage
+ Note: for redis *key:type* key is returned
+ """
+ garbage_keys = []
+ type_pattern = self._get_redis_type_pattern()
+ for k in self.redis.keys(pattern=type_pattern):
+ try:
+ deserialized_key = k.decode('utf-8')
+ split_key = deserialized_key.split(":", 1)
+ except AttributeError:
+ split_key = k.split(":", 1)
+ # There could be a delete key in between KEYS and GET, so ignore
+ # invalid values for now
+ try:
+ if not self.is_garbage(split_key[0]):
+ continue
+ except KeyError:
+ continue
+ garbage_keys.append(split_key[0])
+ return garbage_keys
+
+ def delete_garbage(self, key) -> bool:
+ """Remove ``d[key:type]`` from dictionary iff the object is garbage
+ Returns False if *key:type* is not in the map
+ """
+ if not self.is_garbage(key):
+ return False
+ count = self.__delitem__(key)
+ return count > 0
+
+ def lock(self, key: str) -> Lock:
+ """Lock the dictionary for key *key*"""
+ return redis_lock.Lock(
+ self.redis,
+ name=self._make_composite_key(key) + ":lock",
+ expire=60,
+ auto_renewal=True,
+ strict=False,
+ )
+
+ def _sync_cache(self):
+ """
+ Syncs write-through cache with redis data on store.
+ """
+ type_pattern = self._get_redis_type_pattern()
+ for k in self.redis.keys(pattern=type_pattern):
+ composite_key = k.decode('utf-8')
+ serialized_value = self.redis.get(composite_key)
+ value = self.serde.deserialize(serialized_value)
+ self.cache[composite_key] = value
+
+ def _get_redis_type_pattern(self):
+ return "*:" + self.redis_type
+
+ def _make_composite_key(self, key):
+ return key + ":" + self.redis_type
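
A minimal usage sketch of the flat dict above, mirroring the unit tests added later in this change. It is not part of the patch itself; fakeredis stands in for a real Redis server and 'enb_config' is an illustrative type name:

    import fakeredis

    from common.redis.containers import RedisFlatDict
    from common.redis.serializers import (
        RedisSerde,
        get_json_deserializer,
        get_json_serializer,
    )

    client = fakeredis.FakeStrictRedis()
    serde = RedisSerde('enb_config', get_json_serializer(), get_json_deserializer())
    store = RedisFlatDict(client, serde)

    store['enb1'] = {'earfcn': 39150}         # stored under composite key "enb1:enb_config"
    assert store.get_version('enb1') == 1     # version is bumped on every __setitem__
    store.mark_as_garbage('enb1')             # hidden from get()/keys() but kept in Redis
    assert store.delete_garbage('enb1')       # physically removed once marked
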
diff --git a/common/redis/mocks/__init__.py b/common/redis/mocks/__init__.py
new file mode 100644
index 0000000..5c6cb64
--- /dev/null
+++ b/common/redis/mocks/__init__.py
@@ -0,0 +1,12 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
diff --git a/common/redis/mocks/mock_redis.py b/common/redis/mocks/mock_redis.py
new file mode 100644
index 0000000..0978932
--- /dev/null
+++ b/common/redis/mocks/mock_redis.py
@@ -0,0 +1,33 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+from redis.exceptions import RedisError
+
+# For non-failure cases, just use the fakeredis module
+
+
+class MockUnavailableRedis(object):
+ """
+ MockUnavailableRedis implements a mock Redis Server that always raises
+ a connection exception
+ """
+
+ def __init__(self, host, port):
+ self.host = host
+ self.port = port
+
+ def lock(self, key):
+ raise RedisError("mock redis error")
+
+ def keys(self, pattern=".*"):
+ """ Mock keys with regex pattern matching."""
+ raise RedisError("mock redis error")
diff --git a/common/redis/serializers.py b/common/redis/serializers.py
new file mode 100644
index 0000000..d8b01e1
--- /dev/null
+++ b/common/redis/serializers.py
@@ -0,0 +1,120 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+from typing import Callable, Generic, Type, TypeVar
+
+import jsonpickle
+from orc8r.protos.redis_pb2 import RedisState
+
+T = TypeVar('T')
+
+
+class RedisSerde(Generic[T]):
+ """
+    redis_type (str): str representing the type of object the serde can
+ de/serialize
+ serializer (function (T, int) -> str):
+ function called to serialize a value
+ deserializer (function (str) -> T):
+ function called to deserialize a value
+ """
+
+ def __init__(
+ self,
+ redis_type: str,
+ serializer: Callable[[T, int], str],
+ deserializer: Callable[[str], T],
+ ):
+ self.redis_type = redis_type
+ self.serializer = serializer
+ self.deserializer = deserializer
+
+ def serialize(self, msg: T, version: int = 1) -> str:
+ return self.serializer(msg, version)
+
+ def deserialize(self, serialized_obj: str) -> T:
+ return self.deserializer(serialized_obj)
+
+
+def get_proto_serializer() -> Callable[[T, int], str]:
+ """
+ Return a proto serializer that serializes the proto, adds the associated
+ version, and then serializes the RedisState proto to a string
+ """
+ def _serialize_proto(proto: T, version: int) -> str:
+ serialized_proto = proto.SerializeToString()
+ redis_state = RedisState(
+ serialized_msg=serialized_proto,
+ version=version,
+ is_garbage=False,
+ )
+ return redis_state.SerializeToString()
+ return _serialize_proto
+
+
+def get_proto_deserializer(proto_class: Type[T]) -> Callable[[str], T]:
+ """
+ Return a proto deserializer that takes in a proto type to deserialize
+ the serialized msg stored in the RedisState proto
+ """
+ def _deserialize_proto(serialized_rule: str) -> T:
+ proto_wrapper = RedisState()
+ proto_wrapper.ParseFromString(serialized_rule)
+ serialized_proto = proto_wrapper.serialized_msg
+ proto = proto_class()
+ proto.ParseFromString(serialized_proto)
+ return proto
+ return _deserialize_proto
+
+
+def get_json_serializer() -> Callable[[T, int], str]:
+ """
+ Return a json serializer that serializes the json msg, adds the
+ associated version, and then serializes the RedisState proto to a string
+ """
+ def _serialize_json(msg: T, version: int) -> str:
+ serialized_msg = jsonpickle.encode(msg)
+ redis_state = RedisState(
+ serialized_msg=serialized_msg.encode('utf-8'),
+ version=version,
+ is_garbage=False,
+ )
+ return redis_state.SerializeToString()
+
+ return _serialize_json
+
+
+def get_json_deserializer() -> Callable[[str], T]:
+ """
+ Returns a json deserializer that deserializes the RedisState proto and
+ then deserializes the json msg
+ """
+ def _deserialize_json(serialized_rule: str) -> T:
+ proto_wrapper = RedisState()
+ proto_wrapper.ParseFromString(serialized_rule)
+ serialized_msg = proto_wrapper.serialized_msg
+ msg = jsonpickle.decode(serialized_msg.decode('utf-8'))
+ return msg
+
+ return _deserialize_json
+
+
+def get_proto_version_deserializer() -> Callable[[str], T]:
+ """
+ Return a proto deserializer that takes in a proto type to deserialize
+ the version number stored in the RedisState proto
+ """
+ def _deserialize_version(serialized_rule: str) -> T:
+ proto_wrapper = RedisState()
+ proto_wrapper.ParseFromString(serialized_rule)
+ return proto_wrapper.version
+ return _deserialize_version
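
Every value the containers write is wrapped in a RedisState proto carrying the payload, a version counter, and a garbage flag. A short sketch of that round trip, assuming the generated orc8r protos are importable:

    from orc8r.protos.service303_pb2 import LogVerbosity

    from common.redis.serializers import (
        get_proto_deserializer,
        get_proto_serializer,
        get_proto_version_deserializer,
    )

    blob = get_proto_serializer()(LogVerbosity(verbosity=2), 7)   # RedisState bytes
    assert get_proto_version_deserializer()(blob) == 7            # version survives the wrap
    msg = get_proto_deserializer(LogVerbosity)(blob)              # original proto back
    assert msg.verbosity == 2
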
diff --git a/common/redis/tests/__init__.py b/common/redis/tests/__init__.py
new file mode 100644
index 0000000..5c6cb64
--- /dev/null
+++ b/common/redis/tests/__init__.py
@@ -0,0 +1,12 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
diff --git a/common/redis/tests/dict_tests.py b/common/redis/tests/dict_tests.py
new file mode 100644
index 0000000..e9508bc
--- /dev/null
+++ b/common/redis/tests/dict_tests.py
@@ -0,0 +1,179 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+from unittest import TestCase, main
+
+import fakeredis
+from common.redis.containers import RedisFlatDict, RedisHashDict
+from common.redis.serializers import (
+ RedisSerde,
+ get_proto_deserializer,
+ get_proto_serializer,
+)
+from orc8r.protos.service303_pb2 import LogVerbosity
+
+
+class RedisDictTests(TestCase):
+ """
+ Tests for the RedisHashDict and RedisFlatDict containers
+ """
+
+ def setUp(self):
+ client = fakeredis.FakeStrictRedis()
+ # Use arbitrary orc8r proto to test with
+ self._hash_dict = RedisHashDict(
+ client,
+ "unittest",
+ get_proto_serializer(),
+ get_proto_deserializer(LogVerbosity),
+ )
+
+ serde = RedisSerde(
+ 'log_verbosity',
+ get_proto_serializer(),
+ get_proto_deserializer(LogVerbosity),
+ )
+ self._flat_dict = RedisFlatDict(client, serde)
+
+ def test_hash_insert(self):
+ expected = LogVerbosity(verbosity=0)
+ expected2 = LogVerbosity(verbosity=1)
+
+ # insert proto
+ self._hash_dict['key1'] = expected
+ version = self._hash_dict.get_version("key1")
+ actual = self._hash_dict['key1']
+ self.assertEqual(1, version)
+ self.assertEqual(expected, actual)
+
+ # update proto
+ self._hash_dict['key1'] = expected2
+ version2 = self._hash_dict.get_version("key1")
+ actual2 = self._hash_dict['key1']
+ self.assertEqual(2, version2)
+ self.assertEqual(expected2, actual2)
+
+ def test_missing_version(self):
+ missing_version = self._hash_dict.get_version("key2")
+ self.assertEqual(0, missing_version)
+
+ def test_hash_delete(self):
+ expected = LogVerbosity(verbosity=2)
+ self._hash_dict['key3'] = expected
+
+ actual = self._hash_dict['key3']
+ self.assertEqual(expected, actual)
+
+ self._hash_dict.pop('key3')
+ self.assertRaises(KeyError, self._hash_dict.__getitem__, 'key3')
+
+ def test_flat_insert(self):
+ expected = LogVerbosity(verbosity=5)
+ expected2 = LogVerbosity(verbosity=1)
+
+ # insert proto
+ self._flat_dict['key1'] = expected
+ version = self._flat_dict.get_version("key1")
+ actual = self._flat_dict['key1']
+ self.assertEqual(1, version)
+ self.assertEqual(expected, actual)
+
+ # update proto
+ self._flat_dict["key1"] = expected2
+ version2 = self._flat_dict.get_version("key1")
+ actual2 = self._flat_dict["key1"]
+ actual3 = self._flat_dict.get("key1")
+ self.assertEqual(2, version2)
+ self.assertEqual(expected2, actual2)
+ self.assertEqual(expected2, actual3)
+
+ def test_flat_missing_version(self):
+ missing_version = self._flat_dict.get_version("key2")
+ self.assertEqual(0, missing_version)
+
+ def test_flat_bad_key(self):
+ expected = LogVerbosity(verbosity=2)
+ self.assertRaises(
+ ValueError, self._flat_dict.__setitem__,
+ 'bad:key', expected,
+ )
+ self.assertRaises(
+ ValueError, self._flat_dict.__getitem__,
+ 'bad:key',
+ )
+ self.assertRaises(
+ ValueError, self._flat_dict.__delitem__,
+ 'bad:key',
+ )
+
+ def test_flat_delete(self):
+ expected = LogVerbosity(verbosity=2)
+ self._flat_dict['key3'] = expected
+
+ actual = self._flat_dict['key3']
+ self.assertEqual(expected, actual)
+
+ del self._flat_dict['key3']
+ self.assertRaises(
+ KeyError, self._flat_dict.__getitem__,
+ 'key3',
+ )
+ self.assertEqual(None, self._flat_dict.get('key3'))
+
+ def test_flat_clear(self):
+ expected = LogVerbosity(verbosity=2)
+ self._flat_dict['key3'] = expected
+
+ actual = self._flat_dict['key3']
+ self.assertEqual(expected, actual)
+
+ self._flat_dict.clear()
+ self.assertEqual(0, len(self._flat_dict.keys()))
+
+ def test_flat_garbage_methods(self):
+ expected = LogVerbosity(verbosity=2)
+ expected2 = LogVerbosity(verbosity=3)
+
+ key = "k1"
+ key2 = "k2"
+ bad_key = "bad_key"
+ self._flat_dict[key] = expected
+ self._flat_dict[key2] = expected2
+
+ self._flat_dict.mark_as_garbage(key)
+ is_garbage = self._flat_dict.is_garbage(key)
+ self.assertTrue(is_garbage)
+ is_garbage2 = self._flat_dict.is_garbage(key2)
+ self.assertFalse(is_garbage2)
+
+ self.assertEqual([key], self._flat_dict.garbage_keys())
+ self.assertEqual([key2], self._flat_dict.keys())
+
+ self.assertIsNone(self._flat_dict.get(key))
+ self.assertEqual(expected2, self._flat_dict.get(key2))
+
+ deleted = self._flat_dict.delete_garbage(key)
+ not_deleted = self._flat_dict.delete_garbage(key2)
+ self.assertTrue(deleted)
+ self.assertFalse(not_deleted)
+
+ self.assertIsNone(self._flat_dict.get(key))
+ self.assertEqual(expected2, self._flat_dict.get(key2))
+
+ with self.assertRaises(KeyError):
+ self._flat_dict.is_garbage(bad_key)
+ with self.assertRaises(KeyError):
+ self._flat_dict.mark_as_garbage(bad_key)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/common/rpc_utils.py b/common/rpc_utils.py
new file mode 100644
index 0000000..5fcbbf9
--- /dev/null
+++ b/common/rpc_utils.py
@@ -0,0 +1,183 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+# pylint: disable=broad-except
+
+import asyncio
+import logging
+from enum import Enum
+
+import grpc
+from google.protobuf import message as proto_message
+from google.protobuf.json_format import MessageToJson
+from common.service_registry import ServiceRegistry
+from orc8r.protos import common_pb2
+
+
+class RetryableGrpcErrorDetails(Enum):
+ """
+ Enum for gRPC retryable error detail messages
+ """
+ SOCKET_CLOSED = "Socket closed"
+ CONNECT_FAILED = "Connect Failed"
+
+
+def return_void(func):
+ """
+ Reusable decorator for returning common_pb2.Void() message.
+ """
+
+ def wrapper(*args, **kwargs):
+ func(*args, **kwargs)
+ return common_pb2.Void()
+
+ return wrapper
+
+
+def grpc_wrapper(func):
+ """
+    Wraps a function with a gRPC wrapper which creates an RPC client to
+ the service and handles any RPC Exceptions.
+
+ Usage:
+ @grpc_wrapper
+ def func(client, args):
+ pass
+ func(args, ProtoStubClass, 'service')
+ """
+
+ def wrapper(*alist):
+ args = alist[0]
+ stub_cls = alist[1]
+ service = alist[2]
+ chan = ServiceRegistry.get_rpc_channel(service, ServiceRegistry.LOCAL)
+ client = stub_cls(chan)
+ try:
+ func(client, args)
+ except grpc.RpcError as err:
+ print("Error! [%s] %s" % (err.code(), err.details()))
+ exit(1)
+
+ return wrapper
+
+
+def cloud_grpc_wrapper(func):
+ """
+    Wraps a function with a gRPC wrapper which creates an RPC client to
+ the service and handles any RPC Exceptions.
+
+ Usage:
+ @cloud_grpc_wrapper
+ def func(client, args):
+ pass
+ func(args, ProtoStubClass, 'service')
+ """
+
+ def wrapper(*alist):
+ args = alist[0]
+ stub_cls = alist[1]
+ service = alist[2]
+ chan = ServiceRegistry.get_rpc_channel(service, ServiceRegistry.CLOUD)
+ client = stub_cls(chan)
+ try:
+ func(client, args)
+ except grpc.RpcError as err:
+ print("Error! [%s] %s" % (err.code(), err.details()))
+ exit(1)
+
+ return wrapper
+
+
+def set_grpc_err(
+ context: grpc.ServicerContext,
+ code: grpc.StatusCode,
+ details: str,
+):
+ """
+ Sets status code and details for a gRPC context. Removes commas from
+ the details message (see https://github.com/grpc/grpc-node/issues/769)
+ """
+ context.set_code(code)
+ context.set_details(details.replace(',', ''))
+
+
+def _grpc_async_wrapper(f, gf):
+ try:
+ f.set_result(gf.result())
+ except Exception as e:
+ f.set_exception(e)
+
+
+def grpc_async_wrapper(gf, loop=None):
+ """
+ Wraps a GRPC result in a future that can be yielded by asyncio
+
+ Usage:
+
+ async def my_fn(param):
+            result = await grpc_async_wrapper(
+                stub.function_name.future(param, timeout))
+
+ Code taken and modified from:
+ https://github.com/grpc/grpc/wiki/Integration-with-tornado-(python)
+ """
+ f = asyncio.Future()
+ if loop is None:
+ loop = asyncio.get_event_loop()
+ gf.add_done_callback(
+ lambda _: loop.call_soon_threadsafe(_grpc_async_wrapper, f, gf),
+ )
+ return f
+
+
+def is_grpc_error_retryable(error: grpc.RpcError) -> bool:
+ status_code = error.code()
+ error_details = error.details()
+ if status_code == grpc.StatusCode.UNAVAILABLE and \
+ any(
+ err_msg.value in error_details for err_msg in
+ RetryableGrpcErrorDetails
+ ):
+ # server end closed connection.
+ return True
+ return False
+
+
+def print_grpc(
+ message: proto_message.Message, print_grpc_payload: bool,
+ message_header: str = "",
+):
+ """
+ Prints content of grpc message
+
+ Args:
+ message: grpc message to print
+ print_grpc_payload: flag to enable/disable printing of the message
+ message_header: header to print before printing grpc content
+ """
+
+ if print_grpc_payload:
+ log_msg = "{} {}".format(
+ message.DESCRIPTOR.full_name,
+ MessageToJson(message),
+ )
+ # add indentation
+ padding = 2 * ' '
+ log_msg = ''.join(
+ "{}{}".format(padding, line)
+ for line in log_msg.splitlines(True)
+ )
+ log_msg = "GRPC message:\n{}".format(log_msg)
+
+ if message_header:
+ logging.info(message_header)
+ logging.info(log_msg)
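
A hedged sketch of combining grpc_async_wrapper with is_grpc_error_retryable inside an asyncio service; the stub, its GetState method, and the 10-second timeout are placeholders rather than APIs introduced by this change:

    import grpc

    from common.rpc_utils import grpc_async_wrapper, is_grpc_error_retryable


    async def fetch_state(stub, request, retries=3):
        for attempt in range(retries):
            try:
                # .future() returns a grpc.Future; the wrapper bridges it into asyncio
                return await grpc_async_wrapper(stub.GetState.future(request, 10))
            except grpc.RpcError as err:
                if not is_grpc_error_retryable(err) or attempt == retries - 1:
                    raise
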
diff --git a/common/sdwatchdog.py b/common/sdwatchdog.py
new file mode 100644
index 0000000..2eb4e52
--- /dev/null
+++ b/common/sdwatchdog.py
@@ -0,0 +1,101 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+# pylint: disable=W0223
+
+import asyncio
+import logging
+import os
+import time
+from typing import List, Optional, Set, cast
+
+import systemd.daemon
+from common.job import Job
+
+
+class SDWatchdogTask(Job):
+ pass
+
+
+class SDWatchdog(object):
+ """
+ This is a task that utilizes systemd watchdog functionality.
+
+    The SDWatchdog() task is started automatically in common/service.run(),
+    where it will look at every task in the loop to see if it is a subclass
+    of SDWatchdogTask.
+
+ To enable systemd watchdog, add "WatchdogSec=60" in the [Service] section
+ of the systemd service file.
+ """
+
+ def __init__(
+ self,
+ tasks: Optional[List[SDWatchdogTask]],
+ update_status: bool = False, # update systemd status field
+ period: float = 30,
+ ) -> None:
+ """
+        Coroutine that will check each task's time_last_completed_loop to
+        ensure that it was updated within the last timeout_s seconds.
+
+        Perform a check of each task every self.period seconds.
+ """
+
+ self.tasks = cast(Set[SDWatchdogTask], set())
+ self.update_status = update_status
+ self.period = period
+
+ if tasks:
+ for t in tasks:
+ if not issubclass(type(t), SDWatchdogTask):
+ logging.warning(
+ "'%s' is not a 'SDWatchdogTask', skipping", repr(t),
+ )
+ else:
+ self.tasks.add(t)
+
+ @staticmethod
+ def has_notify() -> bool:
+ return os.getenv("NOTIFY_SOCKET") is not None
+
+ async def run(self) -> None:
+ """
+ check tasks every self.period seconds to see if they have completed
+        a loop within the last 'timeout' seconds. If all have, send WATCHDOG=1.
+ """
+ if not self.has_notify():
+ logging.warning("Missing 'NOTIFY_SOCKET' for SDWatchdog, skipping")
+ return
+ logging.info("Starting SDWatchdog...")
+ while True:
+ current_time = time.time()
+ anyStuck = False
+ for task in self.tasks:
+ if task.not_completed(current_time):
+ errmsg = "SDWatchdog service '%s' has not completed %s" % (
+ repr(task), time.asctime(time.gmtime(current_time)),
+ )
+ if self.update_status:
+ systemd.daemon.notify("STATUS=%s\n" % errmsg)
+ logging.info(errmsg)
+ anyStuck = True
+
+ if not anyStuck:
+ systemd.daemon.notify(
+ 'STATUS=SDWatchdog success %s\n' %
+ time.asctime(time.gmtime(current_time)),
+ )
+ systemd.daemon.notify("WATCHDOG=1")
+ systemd.daemon.notify("READY=1") # only active if Type=notify
+
+ await asyncio.sleep(self.period)
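
For reference, a minimal sketch of wiring the watchdog up manually (common/service.run() is described above as doing this automatically). The empty task list is only there to keep the sketch self-contained; a real service would pass its SDWatchdogTask subclasses:

    import asyncio

    from common.sdwatchdog import SDWatchdog

    loop = asyncio.get_event_loop()
    watchdog = SDWatchdog(tasks=[], update_status=True, period=30)
    loop.create_task(watchdog.run())
    loop.run_forever()
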
diff --git a/common/sentry.py b/common/sentry.py
new file mode 100644
index 0000000..eb433ed
--- /dev/null
+++ b/common/sentry.py
@@ -0,0 +1,49 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import os
+
+import sentry_sdk
+import snowflake
+from configuration.service_configs import get_service_config_value
+
+CONTROL_PROXY = 'control_proxy'
+SENTRY_URL = 'sentry_url_python'
+SENTRY_SAMPLE_RATE = 'sentry_sample_rate'
+COMMIT_HASH = 'COMMIT_HASH'
+HWID = 'hwid'
+SERVICE_NAME = 'service_name'
+
+
+def sentry_init(service_name: str):
+ """Initialize connection and start piping errors to sentry.io."""
+ sentry_url = get_service_config_value(
+ CONTROL_PROXY,
+ SENTRY_URL,
+ default='',
+ )
+ if not sentry_url:
+ return
+
+ sentry_sample_rate = get_service_config_value(
+ CONTROL_PROXY,
+ SENTRY_SAMPLE_RATE,
+ default=1.0,
+ )
+ sentry_sdk.init(
+ dsn=sentry_url,
+ release=os.getenv(COMMIT_HASH),
+ traces_sample_rate=sentry_sample_rate,
+ )
+ sentry_sdk.set_tag(HWID, snowflake.snowflake())
+ sentry_sdk.set_tag(SERVICE_NAME, service_name)
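
Usage is a single call early in service startup; a hedged sketch, with the service name chosen purely for illustration:

    from common.sentry import sentry_init

    # Reads sentry_url_python / sentry_sample_rate from the control_proxy config;
    # returns without side effects when no DSN is configured.
    sentry_init(service_name='enodebd')
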
diff --git a/common/serialization_utils.py b/common/serialization_utils.py
new file mode 100644
index 0000000..dc13a5c
--- /dev/null
+++ b/common/serialization_utils.py
@@ -0,0 +1,37 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import codecs
+import os
+
+
+def write_to_file_atomically(filename, value, temp_filename=None):
+ """
+ Atomically write to a file by first writing the value to a temp file, then
+ moving that temp file to the specified file location.
+
+ This function will create all directories necessary for the file as well.
+
+ Args:
+ filename: full path to the file to write to
+ value: value to write to the file
+ temp_filename: requested path of the intermediate temp file
+ """
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ temp_filename = temp_filename or '{}.tmp'.format(filename)
+ with codecs.open(temp_filename, 'w', encoding='utf8') as f:
+ f.write(value)
+ f.flush()
+ os.fsync(f.fileno())
+ os.replace(temp_filename, filename)
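
A short sketch of the atomic-write helper; the target path below is purely illustrative:

    from common.serialization_utils import write_to_file_atomically

    # Parent directories are created as needed; the payload is written to
    # '<target>.tmp', flushed and fsync'd, then renamed over the target.
    write_to_file_atomically('/tmp/magma_example/streams/demo', 'payload\n')
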
diff --git a/common/service.py b/common/service.py
new file mode 100644
index 0000000..2bcce61
--- /dev/null
+++ b/common/service.py
@@ -0,0 +1,450 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import asyncio
+import faulthandler
+import functools
+import logging
+import os
+import signal
+import time
+from concurrent import futures
+from typing import Any, Dict, List, Optional
+
+import grpc
+import pkg_resources
+from common.log_count_handler import MsgCounterHandler
+from common.log_counter import ServiceLogErrorReporter
+from common.metrics_export import get_metrics
+from common.service_registry import ServiceRegistry
+from configuration.exceptions import LoadConfigError
+from configuration.mconfig_managers import get_mconfig_manager
+from configuration.service_configs import load_service_config
+from orc8r.protos.common_pb2 import LogLevel, Void
+from orc8r.protos.metricsd_pb2 import MetricsContainer
+from orc8r.protos.service303_pb2 import (
+ GetOperationalStatesResponse,
+ ReloadConfigResponse,
+ ServiceInfo,
+ State,
+)
+from orc8r.protos.service303_pb2_grpc import (
+ Service303Servicer,
+ Service303Stub,
+ add_Service303Servicer_to_server,
+)
+
+MAX_DEFAULT_WORKER = 10
+
+
+async def loop_exit():
+ """
+ Stop the loop in an async context
+ """
+ loop = asyncio.get_event_loop()
+ loop.stop()
+
+
+class MagmaService(Service303Servicer):
+ """
+ MagmaService provides the framework for all Magma services.
+ This class also implements the Service303 interface for external
+ entities to interact with the service.
+ """
+
+ def __init__(self, name, empty_mconfig, loop=None):
+ self._name = name
+ self._port = 0
+ self._get_status_callback = None
+ self._get_operational_states_cb = None
+ self._log_count_handler = MsgCounterHandler()
+
+ # Init logging before doing anything
+ logging.basicConfig(
+ level=logging.INFO,
+ format='[%(asctime)s %(levelname)s %(name)s] %(message)s',
+ )
+ # Add a handler to count errors
+ logging.root.addHandler(self._log_count_handler)
+
+ # Set gRPC polling strategy
+ self._set_grpc_poll_strategy()
+
+ # Load the managed config if present
+ self._mconfig = empty_mconfig
+ self._mconfig_metadata = None
+ self._mconfig_manager = get_mconfig_manager()
+ self.reload_mconfig()
+
+ self._state = ServiceInfo.STARTING
+ self._health = ServiceInfo.APP_UNHEALTHY
+ if loop is None:
+ loop = asyncio.get_event_loop()
+ self._loop = loop
+ self._start_time = int(time.time())
+ self._register_signal_handlers()
+
+ # Load the service config if present
+ self._config = None
+ self.reload_config()
+
+ # Count errors
+ self.log_counter = ServiceLogErrorReporter(
+ loop=self._loop,
+ service_config=self._config,
+ handler=self._log_count_handler,
+ )
+ self.log_counter.start()
+
+ # Operational States
+ self._operational_states = []
+
+ self._version = '0.0.0'
+ # Load the service version if available
+ try:
+ # Check if service on docker
+ if self._config and 'init_system' in self._config \
+ and self._config['init_system'] == 'docker':
+ # image comes in form of "feg_gateway_python:<IMAGE_TAG>\n"
+ # Skip the "feg_gateway_python:" part
+ image = os.popen(
+ 'docker ps --filter name=magmad --format "{{.Image}}" | '
+ 'cut -d ":" -f 2',
+ )
+ image_tag = image.read().strip('\n')
+ self._version = image_tag
+ else:
+ self._version = pkg_resources.get_distribution('orc8r').version
+ except pkg_resources.ResolutionError as e:
+ logging.info(e)
+
+ if self._config and 'grpc_workers' in self._config:
+ self._server = grpc.server(
+ futures.ThreadPoolExecutor(
+ max_workers=self._config['grpc_workers'],
+ ),
+ )
+ else:
+ self._server = grpc.server(
+ futures.ThreadPoolExecutor(max_workers=MAX_DEFAULT_WORKER),
+ )
+ add_Service303Servicer_to_server(self, self._server)
+
+ @property
+ def version(self):
+ """Return the current running version of the Magma service"""
+ return self._version
+
+ @property
+ def name(self):
+ """Return the name of service
+
+ Returns:
+ tr: name of service
+ """
+ return self._name
+
+ @property
+ def rpc_server(self):
+ """Return the RPC server used by the service"""
+ return self._server
+
+ @property
+ def port(self):
+ """Return the listening port of the service"""
+ return self._port
+
+ @property
+ def loop(self):
+ """Return the asyncio event loop used by the service"""
+ return self._loop
+
+ @property
+ def state(self):
+ """Return the state of the service"""
+ return self._state
+
+ @property
+ def config(self) -> Dict[str, Any]:
+ """Return the service config"""
+ return self._config
+
+ @property
+ def mconfig(self):
+ """Return the managed config"""
+ return self._mconfig
+
+ @property
+ def mconfig_metadata(self):
+ """Return the metadata of the managed config"""
+ return self._mconfig_metadata
+
+ @property
+ def mconfig_manager(self):
+ """Return the mconfig manager for this service"""
+ return self._mconfig_manager
+
+ def reload_config(self):
+ """Reload the local config for the service"""
+ try:
+ self._config = load_service_config(self._name)
+ self._setup_logging()
+ except LoadConfigError as e:
+ logging.warning(e)
+
+ def reload_mconfig(self):
+ """Reload the managed config for the service"""
+ try:
+ # reload mconfig manager in case feature flag for streaming changed
+ self._mconfig_manager = get_mconfig_manager()
+ self._mconfig = self._mconfig_manager.load_service_mconfig(
+ self._name,
+ self._mconfig,
+ )
+ self._mconfig_metadata = \
+ self._mconfig_manager.load_mconfig_metadata()
+ except LoadConfigError as e:
+ logging.warning(e)
+
+ def add_operational_states(self, states: List[State]):
+ """Add a list of states into the service
+
+ Args:
+ states (List[State]): [description]
+ """
+ self._operational_states.extend(states)
+
+ def run(self):
+ """
+        Start the service and run the event loop until a term signal
+ is received or a StopService rpc call is made on the Service303
+ interface.
+ """
+ logging.info("Starting %s...", self._name)
+ (host, port) = ServiceRegistry.get_service_address(self._name)
+ self._port = self._server.add_insecure_port('{}:{}'.format(host, port))
+ logging.info("Listening on address %s:%d", host, self._port)
+ self._state = ServiceInfo.ALIVE
+ # Python services are healthy immediately when run
+ self._health = ServiceInfo.APP_HEALTHY
+ self._server.start()
+ self._loop.run_forever()
+ # Waiting for the term signal or StopService rpc call
+
+ def close(self):
+ """
+ Clean up the service before termination. This needs to be
+        called at least once after the service has been initialized.
+ """
+ self._loop.close()
+ self._server.stop(0).wait()
+
+ def register_get_status_callback(self, get_status_callback):
+ """Register function for getting status
+
+ Must return a map(string, string)
+ """
+ self._get_status_callback = get_status_callback
+
+ def register_operational_states_callback(self, get_operational_states_cb):
+ """Register the callback function that gets called on GetOperationalStates rpc
+
+ Args:
+ get_operational_states_cb ([type]): callback function
+ """
+ self._get_operational_states_cb = get_operational_states_cb
+
+ def _stop(self, reason):
+ """Stop the service gracefully"""
+ logging.info("Stopping %s with reason %s...", self._name, reason)
+ self._state = ServiceInfo.STOPPING
+ self._server.stop(0)
+
+ for pending_task in asyncio.Task.all_tasks(self._loop):
+ pending_task.cancel()
+
+ self._state = ServiceInfo.STOPPED
+ self._health = ServiceInfo.APP_UNHEALTHY
+ asyncio.ensure_future(loop_exit())
+
+ def _set_grpc_poll_strategy(self):
+ """
+ The new default 'epollex' poll strategy is causing fd leaks, leading to
+ service crashes after 1024 open fds.
+ See https://github.com/grpc/grpc/issues/15759
+ """
+ os.environ['GRPC_POLL_STRATEGY'] = 'epoll1,poll'
+
+ def _get_log_level_from_config(self) -> Optional[int]:
+ if self._config is None:
+ return None
+ log_level = self._config.get('log_level', None)
+ if log_level is None:
+ return None
+ # convert from log level string to LogLevel enum value
+ try:
+ proto_level = LogLevel.Value(log_level)
+ except ValueError:
+ logging.error(
+ 'Unknown logging level in config: %s, ignoring',
+ log_level,
+ )
+ return None
+ return proto_level
+
+ def _get_log_level_from_mconfig(self) -> Optional[int]:
+ if self._mconfig is None:
+ return None
+ return self._mconfig.log_level
+
+ def _setup_logging(self):
+ """Set up log level from config values
+
+ The config file on the AGW takes precedence over the mconfig
+ If neither config file nor mconfig has the log level config, default to INFO
+ """
+ log_level_from_config = self._get_log_level_from_config()
+ log_level_from_mconfig = self._get_log_level_from_mconfig()
+
+ if log_level_from_config is not None:
+ log_level = log_level_from_config
+ elif log_level_from_mconfig is not None:
+ log_level = log_level_from_mconfig
+ else:
+ logging.warning(
+ 'logging level is not specified in either yml config '
+ 'or mconfig, defaulting to INFO',
+ )
+ log_level = LogLevel.Value('INFO')
+ self._set_log_level(log_level)
+
+ @staticmethod
+ def _set_log_level(proto_level: int):
+ """Set log level based on proto-enum level
+
+ Args:
+ proto_level (int): proto enum defined in common.proto
+ """
+ if proto_level == LogLevel.Value('DEBUG'):
+ level = logging.DEBUG
+ elif proto_level == LogLevel.Value('INFO'):
+ level = logging.INFO
+ elif proto_level == LogLevel.Value('WARNING'):
+ level = logging.WARNING
+ elif proto_level == LogLevel.Value('ERROR'):
+ level = logging.ERROR
+ elif proto_level == LogLevel.Value('FATAL'):
+ level = logging.FATAL
+ else:
+ logging.error(
+ 'Unknown logging level: %d, defaulting to INFO',
+ proto_level,
+ )
+ level = logging.INFO
+
+ logging.info(
+ "Setting logging level to %s",
+ logging.getLevelName(level),
+ )
+ logger = logging.getLogger('')
+ logger.setLevel(level)
+
+ def _register_signal_handlers(self):
+ """Register signal handlers
+
+ Right now we just exit on SIGINT/SIGTERM/SIGQUIT
+ """
+ for signame in ['SIGINT', 'SIGTERM', 'SIGQUIT']:
+ self._loop.add_signal_handler(
+ getattr(signal, signame),
+ functools.partial(self._stop, signame),
+ )
+
+ def _signal_handler():
+ logging.info('Handling SIGHUP...')
+ faulthandler.dump_traceback()
+ self._loop.add_signal_handler(
+ signal.SIGHUP, functools.partial(_signal_handler),
+ )
+
+ def GetServiceInfo(self, request, context):
+ """Return the service info (name, version, state, meta, etc.)"""
+ service_info = ServiceInfo(
+ name=self._name,
+ version=self._version,
+ state=self._state,
+ health=self._health,
+ start_time_secs=self._start_time,
+ )
+ if self._get_status_callback is not None:
+ status = self._get_status_callback()
+ try:
+ service_info.status.meta.update(status)
+ except (TypeError, ValueError) as exp:
+ logging.error("Error getting service status: %s", exp)
+ return service_info
+
+ def StopService(self, request, context):
+ """Handle request to stop the service"""
+ logging.info("Request to stop service.")
+ self._loop.call_soon_threadsafe(self._stop, 'RPC')
+ return Void()
+
+ def GetMetrics(self, request, context):
+ """
+ Collects timeseries samples from prometheus python client on this
+ process
+ """
+ metrics = MetricsContainer()
+ metrics.family.extend(get_metrics())
+ return metrics
+
+ def SetLogLevel(self, request, context):
+ """Handle request to set the log level"""
+ self._set_log_level(request.level)
+ return Void()
+
+ def SetLogVerbosity(self, request, context):
+ pass # Not Implemented
+
+ def ReloadServiceConfig(self, request, context):
+ """Handle request to reload the service config file"""
+ self.reload_config()
+ return ReloadConfigResponse(result=ReloadConfigResponse.RELOAD_SUCCESS)
+
+ def GetOperationalStates(self, request, context):
+ """Return the operational states of devices managed by this service."""
+ res = GetOperationalStatesResponse()
+ if self._get_operational_states_cb is not None:
+ states = self._get_operational_states_cb()
+ res.states.extend(states)
+ return res
+
+
+def get_service303_client(service_name: str, location: str) \
+ -> Optional[Service303Stub]:
+ """
+ Return a grpc client attached to the given service
+ name and location.
+ Example Use: client = get_service303_client("state", ServiceRegistry.LOCAL)
+ """
+ try:
+ chan = ServiceRegistry.get_rpc_channel(
+ service_name,
+ location,
+ )
+ return Service303Stub(chan)
+ except ValueError:
+ # Service can't be contacted
+ logging.error('Failed to get RPC channel to %s', service_name)
+ return None
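
A hedged sketch of the service entry point this framework expects. The EnodebD mconfig proto import is assumed to be the generated one for this service and is not part of this change:

    from common.service import MagmaService
    from lte.protos.mconfig.mconfigs_pb2 import EnodebD


    def main():
        service = MagmaService('enodebd', EnodebD())
        # Optional Service303 hook: status meta must map str -> str
        service.register_get_status_callback(lambda: {'enodeb_connected': 'false'})
        service.run()    # blocks until SIGTERM/SIGINT or a StopService RPC
        service.close()


    if __name__ == "__main__":
        main()
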
diff --git a/common/service_registry.py b/common/service_registry.py
new file mode 100644
index 0000000..044852f
--- /dev/null
+++ b/common/service_registry.py
@@ -0,0 +1,301 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import logging
+import os
+
+import grpc
+from configuration.exceptions import LoadConfigError
+from configuration.service_configs import load_service_config
+
+GRPC_KEEPALIVE_MS = 30 * 1000
+
+
+class ServiceRegistry:
+ """
+ ServiceRegistry provides the framework to discover services.
+
+ ServiceRegistry takes care of service naming, and sets the connection
+ params like ip/port, TLS, certs, etc based on service level configuration.
+ """
+
+ _REGISTRY = {}
+ _PROXY_CONFIG = {}
+ _CHANNELS_CACHE = {}
+
+ LOCAL = 'local'
+ CLOUD = 'cloud'
+
+ @staticmethod
+ def get_service_address(service):
+ """
+ Returns the (host, port) tuple for the service.
+
+ Args:
+ service (string): Name of the service
+ Returns:
+ (host, port) tuple
+ Raises:
+ ValueError if the service is unknown
+ """
+ registry = ServiceRegistry.get_registry()
+ if service not in registry["services"]:
+ raise ValueError("Invalid service name: %s" % service)
+ service_conf = registry["services"][service]
+ return service_conf["ip_address"], service_conf["port"]
+
+ @staticmethod
+ def add_service(name, ip_address, port):
+ """
+ Adds a service to the registry.
+
+ Args:
+ name (string): Service name
+ ip_address (string): ip address string
+ port (int): service port
+ """
+ registry = ServiceRegistry.get_registry()
+ service = {"ip_address": ip_address, "port": port}
+ registry["services"][name] = service
+
+ @staticmethod
+ def list_services():
+ """
+ Returns the list of services in the registry.
+
+ Returns:
+ list of services
+ """
+ return ServiceRegistry.get_registry()["services"]
+
+ @staticmethod
+ def reset():
+ """
+ Removes all the entries in the registry
+ """
+ ServiceRegistry.get_registry()["services"] = {}
+
+ @staticmethod
+ def get_bootstrap_rpc_channel():
+ """
+        Returns an RPC channel to the bootstrap service in CLOUD.
+ Returns:
+ grpc channel
+ """
+ proxy_config = ServiceRegistry.get_proxy_config()
+ (ip, port) = (
+ proxy_config['bootstrap_address'],
+ proxy_config['bootstrap_port'],
+ )
+ authority = proxy_config['bootstrap_address']
+
+ try:
+ rootca = open(proxy_config['rootca_cert'], 'rb').read()
+ except FileNotFoundError as exp:
+ raise ValueError("SSL cert not found: %s" % exp)
+
+ ssl_creds = grpc.ssl_channel_credentials(rootca)
+ return create_grpc_channel(ip, port, authority, ssl_creds)
+
+ @staticmethod
+ def get_rpc_channel(
+ service, destination, proxy_cloud_connections=True,
+ grpc_options=None,
+ ):
+ """
+        Returns an RPC channel to the service. The connection params
+ are obtained from the service registry and used.
+ TBD: pool connections to a service and reuse them. Right
+ now each call creates a new TCP/SSL/HTTP2 connection.
+
+ Args:
+ service (string): Name of the service
+ destination (string): ServiceRegistry.LOCAL or ServiceRegistry.CLOUD
+ proxy_cloud_connections (bool): Override to connect direct to cloud
+ grpc_options (list): list of gRPC options params for the channel
+ Returns:
+ grpc channel
+ Raises:
+ ValueError if the service is unknown
+ """
+ proxy_config = ServiceRegistry.get_proxy_config()
+
+ # Control proxy uses the :authority: HTTP header to route to services.
+ if destination == ServiceRegistry.LOCAL:
+ authority = '%s.local' % (service)
+ else:
+ authority = '%s-%s' % (service, proxy_config['cloud_address'])
+
+ should_use_proxy = proxy_config['proxy_cloud_connections'] and \
+ proxy_cloud_connections
+
+ # If speaking to a local service or to the proxy, the grpc channel
+ # can be reused. If speaking to the cloud directly, the client cert
+ # could become stale after the next bootstrapper run.
+ should_reuse_channel = should_use_proxy or \
+ (destination == ServiceRegistry.LOCAL)
+ if should_reuse_channel:
+ channel = ServiceRegistry._CHANNELS_CACHE.get(authority, None)
+ if channel is not None:
+ return channel
+
+ if grpc_options is None:
+ grpc_options = [
+ ("grpc.keepalive_time_ms", GRPC_KEEPALIVE_MS),
+ ]
+        # We need to figure out the ip and port to connect to, whether to use
+        # SSL, and which authority to use.
+ if destination == ServiceRegistry.LOCAL:
+ # Connect to the local service directly
+ (ip, port) = ServiceRegistry.get_service_address(service)
+ channel = create_grpc_channel(
+ ip, port, authority,
+ options=grpc_options,
+ )
+ elif should_use_proxy:
+ # Connect to the cloud via local control proxy
+ try:
+ (ip, unused_port) = ServiceRegistry.get_service_address(
+ "control_proxy",
+ )
+ port = proxy_config['local_port']
+ except ValueError as err:
+ logging.error(err)
+ (ip, port) = ('127.0.0.1', proxy_config['local_port'])
+ channel = create_grpc_channel(
+ ip, port, authority,
+ options=grpc_options,
+ )
+ else:
+ # Connect to the cloud directly
+ ip = proxy_config['cloud_address']
+ port = proxy_config['cloud_port']
+ ssl_creds = get_ssl_creds()
+ channel = create_grpc_channel(
+ ip, port, authority, ssl_creds,
+ options=grpc_options,
+ )
+ if should_reuse_channel:
+ ServiceRegistry._CHANNELS_CACHE[authority] = channel
+ return channel
+
+ @staticmethod
+ def get_registry():
+ """
+ Returns _REGISTRY which holds the contents from the
+        config/service/service_registry.yml file. It's a static member and the
+ .yml file is loaded only once.
+ """
+ if not ServiceRegistry._REGISTRY:
+ try:
+ ServiceRegistry._REGISTRY = load_service_config(
+ "service_registry",
+ )
+ except LoadConfigError as err:
+ logging.error(err)
+ ServiceRegistry._REGISTRY = {"services": {}}
+ return ServiceRegistry._REGISTRY
+
+ @staticmethod
+ def get_proxy_config():
+ """
+ Returns the control proxy config. The config file is loaded only
+ once and cached.
+ """
+ if not ServiceRegistry._PROXY_CONFIG:
+ try:
+ ServiceRegistry._PROXY_CONFIG = load_service_config(
+ 'control_proxy',
+ )
+ except LoadConfigError as err:
+ logging.error(err)
+ ServiceRegistry._PROXY_CONFIG = {
+ 'proxy_cloud_connections': True,
+ }
+ return ServiceRegistry._PROXY_CONFIG
+
+
+def set_grpc_cipher_suites():
+ """
+ Set the cipher suites to be used for the gRPC TLS connection.
+ TODO (praveenr) t19265877: Update nghttpx in the cloud to recent version
+ and delete this. The current nghttpx version doesn't support the
+ ciphers needed by default for gRPC.
+ """
+ os.environ["GRPC_SSL_CIPHER_SUITES"] = "ECDHE-ECDSA-AES256-GCM-SHA384:"\
+ "ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:"\
+ "ECDHE-RSA-CHACHA20-POLY1305:ECDHE-ECDSA-AES128-GCM-SHA256:"\
+ "ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-SHA384:"\
+ "ECDHE-RSA-AES256-SHA384:ECDHE-ECDSA-AES128-SHA256:"\
+ "ECDHE-RSA-AES128-SHA256"
+
+
+def get_ssl_creds():
+ """
+ Get the SSL credentials to use to communicate securely.
+ We use client side TLS auth, with the cert and keys
+ obtained during bootstrapping of the gateway.
+
+ Returns:
+ gRPC ssl creds
+ Raises:
+ ValueError if the cert or key filename in the
+ control proxy config is incorrect.
+ """
+ proxy_config = ServiceRegistry.get_proxy_config()
+ try:
+ with open(proxy_config['rootca_cert'], 'rb') as rootca_f:
+ with open(proxy_config['gateway_cert'], encoding="utf-8") as cert_f:
+ with open(proxy_config['gateway_key'], encoding="utf-8") as key_f:
+ rootca = rootca_f.read()
+ cert = cert_f.read().encode()
+ key = key_f.read().encode()
+ ssl_creds = grpc.ssl_channel_credentials(
+ root_certificates=rootca,
+ certificate_chain=cert,
+ private_key=key,
+ )
+ except FileNotFoundError as exp:
+ raise ValueError("SSL cert not found: %s" % exp)
+ return ssl_creds
+
+
+def create_grpc_channel(ip, port, authority, ssl_creds=None, options=None):
+ """
+ Helper function to create a grpc channel.
+
+ Args:
+ ip: IP address of the remote endpoint
+ port: port of the remote endpoint
+ authority: HTTP header that control proxy uses for routing
+ ssl_creds: Enables SSL
+ options: configuration options for gRPC channel
+ Returns:
+ grpc channel
+ """
+ grpc_options = [('grpc.default_authority', authority)]
+ if options is not None:
+ grpc_options.extend(options)
+ if ssl_creds is not None:
+ set_grpc_cipher_suites()
+ channel = grpc.secure_channel(
+ target='%s:%s' % (ip, port),
+ credentials=ssl_creds,
+ options=grpc_options,
+ )
+ else:
+ channel = grpc.insecure_channel(
+ target='%s:%s' % (ip, port),
+ options=grpc_options,
+ )
+ return channel
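
A small usage sketch of the registry, mirroring get_service303_client() in common/service.py; 'magmad' is used only as an example of a locally registered service:

    from common.service_registry import ServiceRegistry
    from orc8r.protos.service303_pb2_grpc import Service303Stub

    # Channels to local services are cached per :authority:, so repeated
    # lookups reuse the same connection.
    channel = ServiceRegistry.get_rpc_channel('magmad', ServiceRegistry.LOCAL)
    client = Service303Stub(channel)
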
diff --git a/common/stateless_agw.py b/common/stateless_agw.py
new file mode 100644
index 0000000..fb24924
--- /dev/null
+++ b/common/stateless_agw.py
@@ -0,0 +1,93 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import logging
+import subprocess
+
+from configuration.service_configs import (
+ load_override_config,
+ load_service_config,
+ save_override_config,
+)
+from orc8r.protos import magmad_pb2
+
+STATELESS_SERVICE_CONFIGS = [
+ ("mme", "use_stateless", True),
+ ("mobilityd", "persist_to_redis", True),
+ ("pipelined", "clean_restart", False),
+ ("pipelined", "redis_enabled", True),
+ ("sessiond", "support_stateless", True),
+]
+
+
+def check_stateless_agw():
+ num_stateful = 0
+ for service, config, value in STATELESS_SERVICE_CONFIGS:
+ if (
+ _check_stateless_service_config(service, config, value)
+ == magmad_pb2.CheckStatelessResponse.STATEFUL
+ ):
+ num_stateful += 1
+
+ if num_stateful == 0:
+ res = magmad_pb2.CheckStatelessResponse.STATELESS
+ elif num_stateful == len(STATELESS_SERVICE_CONFIGS):
+ res = magmad_pb2.CheckStatelessResponse.STATEFUL
+ else:
+ res = magmad_pb2.CheckStatelessResponse.CORRUPT
+
+ logging.debug(
+ "Check returning %s", magmad_pb2.CheckStatelessResponse.AGWMode.Name(
+ res,
+ ),
+ )
+ return res
+
+
+def enable_stateless_agw():
+ if check_stateless_agw() == magmad_pb2.CheckStatelessResponse.STATELESS:
+ logging.info("Nothing to enable, AGW is stateless")
+ for service, config, value in STATELESS_SERVICE_CONFIGS:
+ cfg = load_override_config(service) or {}
+ cfg[config] = value
+ save_override_config(service, cfg)
+
+ # restart Sctpd so that eNB connections are reset and local state cleared
+ _restart_sctpd()
+
+
+def disable_stateless_agw():
+ if check_stateless_agw() == magmad_pb2.CheckStatelessResponse.STATEFUL:
+ logging.info("Nothing to disable, AGW is stateful")
+ for service, config, value in STATELESS_SERVICE_CONFIGS:
+ cfg = load_override_config(service) or {}
+ cfg[config] = not value
+ save_override_config(service, cfg)
+
+ # restart Sctpd so that eNB connections are reset and local state cleared
+ _restart_sctpd()
+
+
+def _check_stateless_service_config(service, config_name, config_value):
+ service_config = load_service_config(service)
+ if service_config.get(config_name) == config_value:
+ logging.info("STATELESS\t%s -> %s", service, config_name)
+ return magmad_pb2.CheckStatelessResponse.STATELESS
+
+ logging.info("STATEFUL\t%s -> %s", service, config_name)
+ return magmad_pb2.CheckStatelessResponse.STATEFUL
+
+
+def _restart_sctpd():
+ logging.info("Restarting sctpd")
+ subprocess.call("service sctpd restart".split())
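
A brief sketch of how a caller might use these helpers; the conditional simply avoids rewriting overrides when the gateway is already stateless:

    from common import stateless_agw
    from orc8r.protos import magmad_pb2

    mode = stateless_agw.check_stateless_agw()
    if mode != magmad_pb2.CheckStatelessResponse.STATELESS:
        # Rewrites the service override configs and restarts sctpd
        stateless_agw.enable_stateless_agw()
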
diff --git a/common/streamer.py b/common/streamer.py
new file mode 100644
index 0000000..01632d3
--- /dev/null
+++ b/common/streamer.py
@@ -0,0 +1,192 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import abc
+import logging
+import threading
+import time
+from typing import Any, List
+
+import grpc
+import snowflake
+from google.protobuf import any_pb2
+from common import serialization_utils
+from common.metrics import STREAMER_RESPONSES
+from common.service_registry import ServiceRegistry
+from configuration.service_configs import get_service_config_value
+from orc8r.protos.streamer_pb2 import DataUpdate, StreamRequest
+from orc8r.protos.streamer_pb2_grpc import StreamerStub
+
+
+class StreamerClient(threading.Thread):
+ """
+ StreamerClient provides an interface to communicate with the Streamer
+ service in the cloud to get updates for a stream.
+
+ The StreamerClient spawns a thread which listens to updates and
+ schedules a callback in the asyncio event loop when an update
+ is received from the cloud.
+
+ If the connection to the cloud gets terminated, the StreamerClient
+ would retry (TBD: with exponential backoff) to connect back to the cloud.
+ """
+
+ class Callback:
+
+ @abc.abstractmethod
+ def get_request_args(self, stream_name: str) -> Any:
+ """
+ This is called before every stream request to collect any extra
+ arguments to send up to the cloud streamer service.
+
+ Args:
+ stream_name:
+ Name of the stream that the request arg will be sent to
+
+ Returns: A protobuf message
+ """
+ pass
+
+ @abc.abstractmethod
+ def process_update(
+ self, stream_name: str, updates: List[DataUpdate],
+ resync: bool,
+ ):
+ """
+ Called when we get an update from the cloud. This method will
+ be called in the event loop provided to the StreamerClient.
+
+ Args:
+ stream_name: Name of the stream
+ updates: Array of updates
+ resync: if true, the application can clear the
+ contents before applying the updates
+ """
+ raise NotImplementedError()
+
+ def __init__(self, stream_callbacks, loop):
+ """
+ Args:
+ stream_callbacks ({string: Callback}): Mapping of stream names to
+ callbacks to subscribe to.
+ loop: asyncio event loop to schedule the callback
+ """
+ threading.Thread.__init__(self)
+ self._stream_callbacks = stream_callbacks
+ self._loop = loop
+ # Set this thread as daemon thread. We can kill this background
+ # thread abruptly since we handle all updates (and database
+ # transactions) in the asyncio event loop.
+ self.daemon = True
+
+ # Don't allow stream update rate faster than every 5 seconds
+ self._reconnect_pause = get_service_config_value(
+ 'streamer', 'reconnect_sec', 60,
+ )
+ self._reconnect_pause = max(5, self._reconnect_pause)
+ logging.info("Streamer reconnect pause: %d", self._reconnect_pause)
+ self._stream_timeout = get_service_config_value(
+ 'streamer', 'stream_timeout', 150,
+ )
+ logging.info("Streamer timeout: %d", self._stream_timeout)
+
+ def run(self):
+ while True:
+ try:
+ channel = ServiceRegistry.get_rpc_channel(
+ 'streamer', ServiceRegistry.CLOUD,
+ )
+ client = StreamerStub(channel)
+ self.process_all_streams(client)
+ except Exception as exp: # pylint: disable=broad-except
+ logging.error("Error with streamer: %s", exp)
+
+ # If the connection is terminated, wait for a period of time
+ # before connecting back to the cloud.
+ # TODO: make this more intelligent (exponential backoffs, etc.)
+ time.sleep(self._reconnect_pause)
+
+ def process_all_streams(self, client):
+ for stream_name, callback in self._stream_callbacks.items():
+ try:
+ self.process_stream_updates(client, stream_name, callback)
+
+ STREAMER_RESPONSES.labels(result='Success').inc()
+ except grpc.RpcError as err:
+ logging.error(
+ "Error! Streaming from the cloud failed! [%s] %s",
+ err.code(), err.details(),
+ )
+ STREAMER_RESPONSES.labels(result='RpcError').inc()
+ except ValueError as err:
+ logging.error("Error! Streaming from cloud failed! %s", err)
+ STREAMER_RESPONSES.labels(result='ValueError').inc()
+
+ def process_stream_updates(self, client, stream_name, callback):
+ extra_args = self._get_extra_args_any(callback, stream_name)
+ request = StreamRequest(
+ gatewayId=snowflake.snowflake(),
+ stream_name=stream_name,
+ extra_args=extra_args,
+ )
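+        # GetUpdates is a streaming RPC: hand each batch of updates to the
+        # callback on the asyncio loop, since this method runs on the
+        # streamer thread rather than on the loop itself.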
+ for update_batch in client.GetUpdates(
+ request, timeout=self._stream_timeout,
+ ):
+ self._loop.call_soon_threadsafe(
+ callback.process_update,
+ stream_name,
+ update_batch.updates,
+ update_batch.resync,
+ )
+
+ @staticmethod
+ def _get_extra_args_any(callback, stream_name):
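+        """Pack the callback's optional request args into a protobuf Any.
+
+        Returns None when the callback supplies no extra args, so
+        StreamRequest.extra_args is simply left unset.
+        """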
+ extra_args = callback.get_request_args(stream_name)
+ if extra_args is None:
+ return None
+ else:
+ extra_any = any_pb2.Any()
+ extra_any.Pack(extra_args)
+ return extra_any
+
+
+def get_stream_serialize_filename(stream_name):
+ return '/var/opt/magma/streams/{}'.format(stream_name)
+
+
+class SerializingStreamCallback(StreamerClient.Callback):
+ """
+ Streamer client callback which decodes stream update as a string and writes
+ it to a file, overwriting the previous contents of that file. The file
+ location is defined by get_stream_serialize_filename.
+
+ This callback will only save the newest update, with each successive update
+ overwriting the previous.
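+
+    A minimal sketch of wiring it up (the 'subscriberdb' stream name here is
+    only an example, not fixed by this module):
+
+        client = StreamerClient(
+            {'subscriberdb': SerializingStreamCallback()}, loop,
+        )
+        # each update batch overwrites /var/opt/magma/streams/subscriberdb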
+ """
+
+ def get_request_args(self, stream_name: str) -> Any:
+ return None
+
+ def process_update(self, stream_name, updates, resync):
+ if not updates:
+ return
+ # For now, we only care about the last (newest) update
+ for update in updates[:-1]:
+ logging.info('Ignoring update %s', update.key)
+
+ logging.info('Serializing stream update %s', updates[-1].key)
+ filename = get_stream_serialize_filename(stream_name)
+ serialization_utils.write_to_file_atomically(
+ filename,
+ updates[-1].value.decode(),
+ )
diff --git a/common/tests/cert_utils_tests.py b/common/tests/cert_utils_tests.py
new file mode 100644
index 0000000..6563ff9
--- /dev/null
+++ b/common/tests/cert_utils_tests.py
@@ -0,0 +1,109 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import base64
+import datetime
+import os
+from tempfile import TemporaryDirectory
+from unittest import TestCase
+
+import common.cert_utils as cu
+from cryptography import x509
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import hashes, serialization
+from cryptography.hazmat.primitives.asymmetric import ec
+
+
+class CertUtilsTest(TestCase):
+ def test_key(self):
+ with TemporaryDirectory(prefix='/tmp/test_cert_utils') as temp_dir:
+ key = ec.generate_private_key(ec.SECP384R1(), default_backend())
+ cu.write_key(key, os.path.join(temp_dir, 'test.key'))
+ key_load = cu.load_key(os.path.join(temp_dir, 'test.key'))
+
+ key_bytes = key.private_bytes(
+ serialization.Encoding.PEM,
+ serialization.PrivateFormat.TraditionalOpenSSL,
+ serialization.NoEncryption(),
+ )
+ key_load_bytes = key_load.private_bytes(
+ serialization.Encoding.PEM,
+ serialization.PrivateFormat.TraditionalOpenSSL,
+ serialization.NoEncryption(),
+ )
+ self.assertEqual(key_bytes, key_load_bytes)
+
+    def test_load_public_key_to_base64der(self):
+        with TemporaryDirectory(prefix='/tmp/test_cert_utils') as temp_dir:
+            key = ec.generate_private_key(ec.SECP384R1(), default_backend())
+            cu.write_key(key, os.path.join(temp_dir, 'test.key'))
+            base64der = cu.load_public_key_to_base64der(
+                os.path.join(temp_dir, 'test.key'),
+            )
+            der = base64.b64decode(base64der)
+            pub_key = serialization.load_der_public_key(der, default_backend())
+            # Compare serialized key bytes, matching the style of test_csr.
+            self.assertEqual(
+                pub_key.public_bytes(
+                    serialization.Encoding.OpenSSH,
+                    serialization.PublicFormat.OpenSSH,
+                ),
+                key.public_key().public_bytes(
+                    serialization.Encoding.OpenSSH,
+                    serialization.PublicFormat.OpenSSH,
+                ),
+            )
+
+ def test_csr(self):
+ key = ec.generate_private_key(ec.SECP384R1(), default_backend())
+ csr = cu.create_csr(
+ key, 'i am dummy test',
+ 'US', 'CA', 'MPK', 'FB', 'magma', 'magma@fb.com',
+ )
+ self.assertTrue(csr.is_signature_valid)
+ public_key_bytes = key.public_key().public_bytes(
+ serialization.Encoding.OpenSSH,
+ serialization.PublicFormat.OpenSSH,
+ )
+ csr_public_key_bytes = csr.public_key().public_bytes(
+ serialization.Encoding.OpenSSH,
+ serialization.PublicFormat.OpenSSH,
+ )
+ self.assertEqual(public_key_bytes, csr_public_key_bytes)
+
+ def test_cert(self):
+ with TemporaryDirectory(prefix='/tmp/test_cert_utils') as temp_dir:
+ cert = _create_dummy_cert()
+ cert_file = os.path.join(temp_dir, 'test.cert')
+ cu.write_cert(
+ cert.public_bytes(
+ serialization.Encoding.DER,
+ ), cert_file,
+ )
+ cert_load = cu.load_cert(cert_file)
+ self.assertEqual(cert, cert_load)
+
+
+def _create_dummy_cert():
+ key = ec.generate_private_key(ec.SECP384R1(), default_backend())
+ subject = issuer = x509.Name([
+ x509.NameAttribute(x509.oid.NameOID.COUNTRY_NAME, u"US"),
+ x509.NameAttribute(x509.oid.NameOID.STATE_OR_PROVINCE_NAME, u"CA"),
+ x509.NameAttribute(x509.oid.NameOID.LOCALITY_NAME, u"San Francisco"),
+ x509.NameAttribute(x509.oid.NameOID.ORGANIZATION_NAME, u"My Company"),
+ x509.NameAttribute(x509.oid.NameOID.COMMON_NAME, u"mysite.com"),
+ ])
+ cert = x509.CertificateBuilder().subject_name(
+ subject,
+ ).issuer_name(
+ issuer,
+ ).public_key(
+ key.public_key(),
+ ).serial_number(
+ x509.random_serial_number(),
+ ).not_valid_before(
+ datetime.datetime.utcnow(),
+ ).not_valid_after(
+ datetime.datetime.utcnow() + datetime.timedelta(days=10),
+ ).sign(key, hashes.SHA256(), default_backend())
+ return cert
diff --git a/common/tests/cert_validity_tests.py b/common/tests/cert_validity_tests.py
new file mode 100644
index 0000000..8ce67dc
--- /dev/null
+++ b/common/tests/cert_validity_tests.py
@@ -0,0 +1,285 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import asyncio
+import errno
+import ssl
+from unittest import TestCase
+from unittest.mock import MagicMock, patch
+
+import common.cert_validity as cv
+
+
+# https://stackoverflow.com/questions/32480108/mocking-async-call-in-python-3-5
+def AsyncMock():
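+    """Build a mock coroutine function for patching async calls.
+
+    The wrapped MagicMock is exposed as `.coro` so tests can attach
+    side effects to the underlying (synchronous) result.
+    """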
+ coro = MagicMock(name="CoroutineResult")
+ corofunc = MagicMock(
+ name="CoroutineFunction",
+ side_effect=asyncio.coroutine(coro),
+ )
+ corofunc.coro = coro
+ return corofunc
+
+
+class CertValidityTests(TestCase):
+ def setUp(self):
+ self.host = 'localhost'
+ self.port = 8080
+ self.certfile = 'certfile'
+ self.keyfile = 'keyfile'
+
+ asyncio.set_event_loop(None)
+ self.loop = asyncio.new_event_loop()
+
+ def test_tcp_connection(self):
+ """
+        Test that loop.create_connection is called with the correct TCP args.
+ """
+ self.loop.create_connection = MagicMock()
+
+ @asyncio.coroutine
+ def go():
+ yield from cv.create_tcp_connection(
+ self.host,
+ self.port,
+ self.loop,
+ )
+ self.loop.run_until_complete(go())
+
+ self.loop.create_connection.assert_called_once_with(
+ cv.TCPClientProtocol,
+ self.host,
+ self.port,
+ )
+
+    @patch('common.cert_validity.ssl.SSLContext')
+ def test_ssl_connection(self, mock_ssl):
+ """
+ Test that ssl.SSLContext and loop.create_connection are called with the
+ correct SSL args.
+ """
+ self.loop.create_connection = MagicMock()
+
+ @asyncio.coroutine
+ def go():
+ yield from cv.create_ssl_connection(
+ self.host,
+ self.port,
+ self.certfile,
+ self.keyfile,
+ self.loop,
+ )
+ self.loop.run_until_complete(go())
+
+ mock_context = mock_ssl.return_value
+
+ mock_ssl.assert_called_once_with(ssl.PROTOCOL_SSLv23)
+ mock_context.load_cert_chain.assert_called_once_with(
+ self.certfile,
+ keyfile=self.keyfile,
+ )
+
+ self.loop.create_connection.assert_called_once_with(
+ cv.TCPClientProtocol,
+ self.host,
+ self.port,
+ ssl=mock_context,
+ )
+
+ @patch(
+        'common.cert_validity.create_ssl_connection',
+ new_callable=AsyncMock,
+ )
+ @patch(
+        'common.cert_validity.create_tcp_connection',
+ new_callable=AsyncMock,
+ )
+ def test_cert_is_invalid_both_ok(self, mock_create_tcp, mock_create_ssl):
+ """
+        Test the appropriate calls and return value for cert_is_invalid():
+        cert_is_invalid() == False when TCP and SSL succeed.
+ """
+
+ @asyncio.coroutine
+ def go():
+ return (
+ yield from cv.cert_is_invalid(
+ self.host,
+ self.port,
+ self.certfile,
+ self.keyfile,
+ self.loop,
+ )
+ )
+ ret_val = self.loop.run_until_complete(go())
+
+ mock_create_tcp.assert_called_once_with(
+ self.host,
+ self.port,
+ self.loop,
+ )
+ mock_create_ssl.assert_called_once_with(
+ self.host,
+ self.port,
+ self.certfile,
+ self.keyfile,
+ self.loop,
+ )
+ self.assertEqual(ret_val, False)
+
+ @patch(
+        'common.cert_validity.create_ssl_connection',
+ new_callable=AsyncMock,
+ )
+    @patch('common.cert_validity.create_tcp_connection', AsyncMock())
+ def test_cert_is_invalid_ssl_fail(self, mock_create_ssl):
+ """
+ Test cert_is_invalid() == True when TCP succeeds and SSL fails
+ """
+
+ mock_err = TimeoutError()
+ mock_err.errno = errno.ETIMEDOUT
+ mock_create_ssl.coro.side_effect = mock_err
+
+ @asyncio.coroutine
+ def go():
+ return (
+ yield from cv.cert_is_invalid(
+ self.host,
+ self.port,
+ self.certfile,
+ self.keyfile,
+ self.loop,
+ )
+ )
+ ret_val = self.loop.run_until_complete(go())
+ self.assertEqual(ret_val, True)
+
+ @patch(
+        'common.cert_validity.create_ssl_connection',
+ new_callable=AsyncMock,
+ )
+    @patch('common.cert_validity.create_tcp_connection', AsyncMock())
+ def test_cert_is_invalid_ssl_fail_none_errno(self, mock_create_ssl):
+ """
+ Test cert_is_invalid() == True when TCP succeeds and SSL fails w/o error number
+ """
+
+ mock_err = TimeoutError()
+ mock_err.errno = None
+ mock_create_ssl.coro.side_effect = mock_err
+
+ @asyncio.coroutine
+ def go():
+ return (
+ yield from cv.cert_is_invalid(
+ self.host,
+ self.port,
+ self.certfile,
+ self.keyfile,
+ self.loop,
+ )
+ )
+ ret_val = self.loop.run_until_complete(go())
+ self.assertEqual(ret_val, True)
+
+    @patch('common.cert_validity.create_ssl_connection', AsyncMock())
+ @patch(
+        'common.cert_validity.create_tcp_connection',
+ new_callable=AsyncMock,
+ )
+ def test_cert_is_invalid_tcp_fail_none_errno(self, mock_create_tcp):
+ """
+ Test cert_is_invalid() == False when TCP fails w/o errno and SSL succeeds
+ """
+
+ mock_err = TimeoutError()
+ mock_err.errno = None
+ mock_create_tcp.coro.side_effect = mock_err
+
+ @asyncio.coroutine
+ def go():
+ return (
+ yield from cv.cert_is_invalid(
+ self.host,
+ self.port,
+ self.certfile,
+ self.keyfile,
+ self.loop,
+ )
+ )
+ ret_val = self.loop.run_until_complete(go())
+ self.assertEqual(ret_val, False)
+
+    @patch('common.cert_validity.create_ssl_connection', AsyncMock())
+ @patch(
+        'common.cert_validity.create_tcp_connection',
+ new_callable=AsyncMock,
+ )
+ def test_cert_is_invalid_tcp_fail(self, mock_create_tcp):
+ """
+ Test cert_is_invalid() == False when TCP fails and SSL succeeds
+ """
+
+ mock_err = TimeoutError()
+ mock_err.errno = errno.ETIMEDOUT
+ mock_create_tcp.coro.side_effect = mock_err
+
+ @asyncio.coroutine
+ def go():
+ return (
+ yield from cv.cert_is_invalid(
+ self.host,
+ self.port,
+ self.certfile,
+ self.keyfile,
+ self.loop,
+ )
+ )
+ ret_val = self.loop.run_until_complete(go())
+ self.assertEqual(ret_val, False)
+
+ @patch(
+        'common.cert_validity.create_ssl_connection',
+ new_callable=AsyncMock,
+ )
+ @patch(
+        'common.cert_validity.create_tcp_connection',
+ new_callable=AsyncMock,
+ )
+ def test_cert_is_invalid_both_fail(self, mock_create_tcp, mock_create_ssl):
+ """
+ Test cert_is_invalid() == False when TCP and SSL fail
+ """
+
+ mock_tcp_err = TimeoutError()
+ mock_tcp_err.errno = errno.ETIMEDOUT
+ mock_create_tcp.coro.side_effect = mock_tcp_err
+
+ mock_ssl_err = TimeoutError()
+ mock_ssl_err.errno = errno.ETIMEDOUT
+ mock_create_ssl.coro.side_effect = mock_ssl_err
+
+ @asyncio.coroutine
+ def go():
+ return (
+ yield from cv.cert_is_invalid(
+ self.host,
+ self.port,
+ self.certfile,
+ self.keyfile,
+ self.loop,
+ )
+ )
+ ret_val = self.loop.run_until_complete(go())
+ self.assertEqual(ret_val, False)
diff --git a/common/tests/metrics_tests.py b/common/tests/metrics_tests.py
new file mode 100644
index 0000000..f48f2f8
--- /dev/null
+++ b/common/tests/metrics_tests.py
@@ -0,0 +1,241 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import unittest
+import unittest.mock
+
+import metrics_pb2
+from common import metrics_export
+from orc8r.protos import metricsd_pb2
+from prometheus_client import (
+ CollectorRegistry,
+ Counter,
+ Gauge,
+ Histogram,
+ Summary,
+)
+
+
+class Service303MetricTests(unittest.TestCase):
+ """
+ Tests for the Service303 metrics interface
+ """
+
+ def setUp(self):
+ self.registry = CollectorRegistry()
+ self.maxDiff = None
+
+ def test_counter(self):
+ """Test that we can track counters in Service303"""
+        # Add a counter with a label to the registry
+ c = Counter(
+ 'process_max_fds', 'A counter', ['result'],
+ registry=self.registry,
+ )
+
+        # Create two series, one per label value ('success' and 'failure')
+ c.labels('success').inc(1.23)
+ c.labels('failure').inc(2.34)
+
+ # Build proto outputs
+ counter1 = metrics_pb2.Counter(value=1.23)
+ counter2 = metrics_pb2.Counter(value=2.34)
+ metric1 = metrics_pb2.Metric(
+ counter=counter1,
+ timestamp_ms=1234000,
+ )
+ metric2 = metrics_pb2.Metric(
+ counter=counter2,
+ timestamp_ms=1234000,
+ )
+ family = metrics_pb2.MetricFamily(
+ name=str(metricsd_pb2.process_max_fds),
+ type=metrics_pb2.COUNTER,
+ )
+ metric1.label.add(
+ name=str(metricsd_pb2.result),
+ value='success',
+ )
+ metric2.label.add(
+ name=str(metricsd_pb2.result),
+ value='failure',
+ )
+ family.metric.extend([metric1, metric2])
+
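+        # Freeze time.time() so the exported timestamp_ms matches the
+        # 1234000 ms used in the expected protos above.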
+ with unittest.mock.patch('time.time') as mock_time:
+ mock_time.side_effect = lambda: 1234
+ self.assertCountEqual(
+ list(metrics_export.get_metrics(self.registry))[0].metric,
+ family.metric,
+ )
+
+ def test_gauge(self):
+ """Test that we can track gauges in Service303"""
+        # Add a gauge with a label to the registry
+ c = Gauge(
+ 'process_max_fds', 'A gauge', ['result'],
+ registry=self.registry,
+ )
+
+        # Create two series, one per label value ('success' and 'failure')
+ c.labels('success').inc(1.23)
+ c.labels('failure').inc(2.34)
+
+ # Build proto outputs
+ gauge1 = metrics_pb2.Gauge(value=1.23)
+ gauge2 = metrics_pb2.Gauge(value=2.34)
+ metric1 = metrics_pb2.Metric(
+ gauge=gauge1,
+ timestamp_ms=1234000,
+ )
+ metric2 = metrics_pb2.Metric(
+ gauge=gauge2,
+ timestamp_ms=1234000,
+ )
+ family = metrics_pb2.MetricFamily(
+ name=str(metricsd_pb2.process_max_fds),
+ type=metrics_pb2.GAUGE,
+ )
+ metric1.label.add(
+ name=str(metricsd_pb2.result),
+ value='success',
+ )
+ metric2.label.add(
+ name=str(metricsd_pb2.result),
+ value='failure',
+ )
+ family.metric.extend([metric1, metric2])
+
+ with unittest.mock.patch('time.time') as mock_time:
+ mock_time.side_effect = lambda: 1234
+ self.assertCountEqual(
+ list(metrics_export.get_metrics(self.registry))[0].metric,
+ family.metric,
+ )
+
+ def test_summary(self):
+ """Test that we can track summaries in Service303"""
+        # Add a summary with a label to the registry
+ c = Summary(
+ 'process_max_fds', 'A summary', [
+ 'result',
+ ], registry=self.registry,
+ )
+ c.labels('success').observe(1.23)
+ c.labels('failure').observe(2.34)
+
+ # Build proto outputs
+ summary1 = metrics_pb2.Summary(sample_count=1, sample_sum=1.23)
+ summary2 = metrics_pb2.Summary(sample_count=1, sample_sum=2.34)
+ metric1 = metrics_pb2.Metric(
+ summary=summary1,
+ timestamp_ms=1234000,
+ )
+ metric2 = metrics_pb2.Metric(
+ summary=summary2,
+ timestamp_ms=1234000,
+ )
+ family = metrics_pb2.MetricFamily(
+ name=str(metricsd_pb2.process_max_fds),
+ type=metrics_pb2.SUMMARY,
+ )
+ metric1.label.add(
+ name=str(metricsd_pb2.result),
+ value='success',
+ )
+ metric2.label.add(
+ name=str(metricsd_pb2.result),
+ value='failure',
+ )
+ family.metric.extend([metric1, metric2])
+
+ with unittest.mock.patch('time.time') as mock_time:
+ mock_time.side_effect = lambda: 1234
+ self.assertCountEqual(
+ list(metrics_export.get_metrics(self.registry))[0].metric,
+ family.metric,
+ )
+
+ def test_histogram(self):
+ """Test that we can track histogram in Service303"""
+        # Add a histogram with a label to the registry
+        c = Histogram(
+            'process_max_fds', 'A histogram', ['result'],
+ registry=self.registry, buckets=[0, 2, float('inf')],
+ )
+ c.labels('success').observe(1.23)
+ c.labels('failure').observe(2.34)
+
+ # Build proto outputs
+ histogram1 = metrics_pb2.Histogram(sample_count=1, sample_sum=1.23)
+ histogram1.bucket.add(upper_bound=0, cumulative_count=0)
+ histogram1.bucket.add(upper_bound=2, cumulative_count=1)
+ histogram1.bucket.add(upper_bound=float('inf'), cumulative_count=1)
+ histogram2 = metrics_pb2.Histogram(sample_count=1, sample_sum=2.34)
+ histogram2.bucket.add(upper_bound=0, cumulative_count=0)
+ histogram2.bucket.add(upper_bound=2, cumulative_count=0)
+ histogram2.bucket.add(upper_bound=float('inf'), cumulative_count=1)
+ metric1 = metrics_pb2.Metric(
+ histogram=histogram1,
+ timestamp_ms=1234000,
+ )
+ metric2 = metrics_pb2.Metric(
+ histogram=histogram2,
+ timestamp_ms=1234000,
+ )
+ family = metrics_pb2.MetricFamily(
+ name=str(metricsd_pb2.process_max_fds),
+ type=metrics_pb2.HISTOGRAM,
+ )
+ metric1.label.add(
+ name=str(metricsd_pb2.result),
+ value='success',
+ )
+ metric2.label.add(
+ name=str(metricsd_pb2.result),
+ value='failure',
+ )
+ family.metric.extend([metric1, metric2])
+
+ with unittest.mock.patch('time.time') as mock_time:
+ mock_time.side_effect = lambda: 1234
+ self.assertCountEqual(
+ list(metrics_export.get_metrics(self.registry))[0].metric,
+ family.metric,
+ )
+
+ def test_converted_enums(self):
+ """ Test that metric names and labels are auto converted """
+ # enum values (from metricsd.proto):
+ # mme_new_association => 500, result => 0
+ c = Counter(
+ 'mme_new_association', 'A counter', ['result'],
+ registry=self.registry,
+ )
+
+ c.labels('success').inc(1.23)
+
+ metric_family = list(metrics_export.get_metrics(self.registry))[0]
+
+ self.assertEqual(
+ metric_family.name,
+ str(metricsd_pb2.mme_new_association),
+ )
+ metric_labels = metric_family.metric[0].label
+        # Order not guaranteed
+ self.assertEqual(metric_labels[0].name, str(metricsd_pb2.result))
+ self.assertEqual(metric_labels[0].value, 'success')
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/common/tests/service303_tests.py b/common/tests/service303_tests.py
new file mode 100644
index 0000000..49175e4
--- /dev/null
+++ b/common/tests/service303_tests.py
@@ -0,0 +1,90 @@
+"""
+Copyright 2020 The Magma Authors.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import asyncio
+from unittest import TestCase, main, mock
+
+from common.service import MagmaService
+from common.service_registry import ServiceRegistry
+from orc8r.protos.common_pb2 import Void
+from orc8r.protos.mconfig import mconfigs_pb2
+from orc8r.protos.service303_pb2 import ServiceInfo
+from orc8r.protos.service303_pb2_grpc import Service303Stub
+
+
+class Service303Tests(TestCase):
+ """
+ Tests for the MagmaService and the Service303 interface
+ """
+
+ @mock.patch('time.time', mock.MagicMock(return_value=12345))
+ def setUp(self):
+ ServiceRegistry.add_service('test', '0.0.0.0', 0)
+ self._stub = None
+
+        # Use a new event loop to ensure isolated tests
+        self._loop = asyncio.new_event_loop()
+ self._service = MagmaService(
+ name='test',
+ empty_mconfig=mconfigs_pb2.MagmaD(),
+ loop=self._loop,
+ )
+ asyncio.set_event_loop(self._service.loop)
+
+ @mock.patch(
+        'common.service_registry.ServiceRegistry.get_proxy_config',
+ )
+ def test_service_run(self, mock_get_proxy_config):
+ """
+ Test if the service starts and stops gracefully.
+ """
+
+ self.assertEqual(self._service.state, ServiceInfo.STARTING)
+
+ mock_get_proxy_config.return_value = {
+ 'cloud_address': '127.0.0.1',
+ 'proxy_cloud_connections': True,
+ }
+
+ # Start the service and pause the loop
+ self._service.loop.stop()
+ self._service.run()
+ asyncio.set_event_loop(self._service.loop)
+ self._service.log_counter._periodic_task.cancel()
+ self.assertEqual(self._service.state, ServiceInfo.ALIVE)
+
+ # Create a rpc stub and query the Service303 interface
+ ServiceRegistry.add_service('test', '0.0.0.0', self._service.port)
+ channel = ServiceRegistry.get_rpc_channel(
+ 'test',
+ ServiceRegistry.LOCAL,
+ )
+ self._stub = Service303Stub(channel)
+
+ info = ServiceInfo(
+ name='test',
+ version='0.0.0',
+ state=ServiceInfo.ALIVE,
+ health=ServiceInfo.APP_HEALTHY,
+ start_time_secs=12345,
+ )
+ self.assertEqual(self._stub.GetServiceInfo(Void()), info)
+
+ # Stop the service
+ self._stub.StopService(Void())
+ self._service.loop.run_forever()
+ self.assertEqual(self._service.state, ServiceInfo.STOPPED)
+
+
+if __name__ == "__main__":
+ main()