blob: b3306b9ce8954f987ac2a5876478d93616ac9ce5 [file] [log] [blame]
Wei-Yu Chenad55cb82022-02-15 20:07:01 +08001# SPDX-FileCopyrightText: 2020 The Magma Authors.
2# SPDX-FileCopyrightText: 2022 Open Networking Foundation <support@opennetworking.org>
3#
4# SPDX-License-Identifier: BSD-3-Clause
Wei-Yu Chen49950b92021-11-08 19:19:18 +08005
Wei-Yu Chen49950b92021-11-08 19:19:18 +08006# pylint: disable=broad-except
7
import asyncio
import functools
import logging
import sys
from enum import Enum

import grpc
from google.protobuf import message as proto_message
from google.protobuf.json_format import MessageToJson
from common.service_registry import ServiceRegistry
from orc8r.protos import common_pb2
17
18
class RetryableGrpcErrorDetails(Enum):
    """
    Known gRPC error-detail texts that mark a failed call as retryable.

    ``is_grpc_error_retryable`` checks whether any member's value occurs as
    a substring of the details of an ``UNAVAILABLE`` error.
    """

    # member order is kept stable; iteration follows definition order
    SOCKET_CLOSED = "Socket closed"
    CONNECT_FAILED = "Connect Failed"
25
26
def return_void(func):
    """
    Decorator that discards ``func``'s return value and returns a fresh
    ``common_pb2.Void()`` message instead.

    The decorated function therefore only needs to perform its side
    effects. ``functools.wraps`` preserves the wrapped function's name,
    docstring and ``__wrapped__`` attribute, so logs, tracebacks and
    introspection still refer to the original function.

    Args:
        func: callable to wrap; called with the wrapper's arguments unchanged

    Returns:
        A wrapper returning ``common_pb2.Void()`` after calling ``func``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        func(*args, **kwargs)  # result intentionally discarded
        return common_pb2.Void()

    return wrapper
37
38
def grpc_wrapper(func):
    """
    Wraps a function with a gRPC wrapper which creates a RPC client to
    the local service and handles any RPC exceptions.

    The wrapper obtains a channel via
    ``ServiceRegistry.get_rpc_channel(service, ServiceRegistry.LOCAL)``,
    builds ``stub_cls(chan)`` and calls ``func(client, args)``. On
    ``grpc.RpcError`` it prints the status code and details and exits the
    process with status 1.

    Usage:
        @grpc_wrapper
        def func(client, args):
            pass
        func(args, ProtoStubClass, 'service')
    """

    @functools.wraps(func)
    def wrapper(args, stub_cls, service, *_ignored):
        # Positional arguments beyond the first three are accepted and
        # ignored, matching the previous ``*alist`` signature.
        chan = ServiceRegistry.get_rpc_channel(service, ServiceRegistry.LOCAL)
        client = stub_cls(chan)
        try:
            func(client, args)
        except grpc.RpcError as err:
            print("Error! [%s] %s" % (err.code(), err.details()))
            # sys.exit instead of the site-injected builtin ``exit``, which
            # does not exist under ``python -S`` or in frozen executables.
            sys.exit(1)

    return wrapper
64
65
def cloud_grpc_wrapper(func):
    """
    Wraps a function with a gRPC wrapper which creates a RPC client to
    the cloud service and handles any RPC exceptions.

    The wrapper obtains a channel via
    ``ServiceRegistry.get_rpc_channel(service, ServiceRegistry.CLOUD)``,
    builds ``stub_cls(chan)`` and calls ``func(client, args)``. On
    ``grpc.RpcError`` it prints the status code and details and exits the
    process with status 1.

    Usage:
        @cloud_grpc_wrapper
        def func(client, args):
            pass
        func(args, ProtoStubClass, 'service')
    """

    @functools.wraps(func)
    def wrapper(args, stub_cls, service, *_ignored):
        # Positional arguments beyond the first three are accepted and
        # ignored, matching the previous ``*alist`` signature.
        chan = ServiceRegistry.get_rpc_channel(service, ServiceRegistry.CLOUD)
        client = stub_cls(chan)
        try:
            func(client, args)
        except grpc.RpcError as err:
            print("Error! [%s] %s" % (err.code(), err.details()))
            # sys.exit instead of the site-injected builtin ``exit``, which
            # does not exist under ``python -S`` or in frozen executables.
            sys.exit(1)

    return wrapper
91
92
def set_grpc_err(
    context: grpc.ServicerContext,
    code: grpc.StatusCode,
    details: str,
):
    """
    Record an error status on a gRPC servicer context.

    The status code is stored as given. Every comma is stripped from
    ``details`` before it is stored, as a workaround for
    https://github.com/grpc/grpc-node/issues/769.

    Args:
        context: servicer context of the RPC being handled
        code: status code to report
        details: error description to report (commas are removed)
    """
    context.set_code(code)
    comma_free_details = details.replace(',', '')
    context.set_details(comma_free_details)
104
105
def _grpc_async_wrapper(f, gf):
    """
    Copy the outcome of a finished gRPC future onto an asyncio future.

    Scheduled on the event loop by ``grpc_async_wrapper`` via
    ``call_soon_threadsafe``. If ``f`` is already done (e.g. the coroutine
    awaiting it was cancelled, or ``asyncio.wait_for`` timed out), the
    outcome is dropped. Calling ``set_result``/``set_exception`` on a done
    asyncio future raises ``asyncio.InvalidStateError``, which would
    otherwise escape into the event loop's exception handler.

    Args:
        f: asyncio future to resolve
        gf: completed gRPC future whose result or exception is transferred
    """
    if f.done():
        return
    try:
        result = gf.result()
    except Exception as e:
        f.set_exception(e)
    else:
        f.set_result(result)
111
112
def grpc_async_wrapper(gf, loop=None):
    """
    Wraps a GRPC result in a future that can be yielded by asyncio.

    Usage:

        async def my_fn(param):
            result = await grpc_async_wrapper(
                stub.function_name.future(param, timeout),
            )

    Args:
        gf: gRPC future (e.g. from ``stub.method.future(...)``); its result
            or exception is transferred to the returned future on completion
        loop: event loop that owns the returned future and delivers the
            result; defaults to ``asyncio.get_event_loop()``

    Returns:
        An asyncio future bound to ``loop``.

    Code taken and modified from:
    https://github.com/grpc/grpc/wiki/Integration-with-tornado-(python)
    """
    if loop is None:
        loop = asyncio.get_event_loop()
    # Create the future on ``loop`` itself. A bare ``asyncio.Future()`` binds
    # to the current event loop, which differs from an explicitly passed
    # ``loop``, and would then be resolved from another loop's callback.
    f = loop.create_future()
    # gRPC may run done-callbacks on a thread other than the event loop's
    # (or immediately if ``gf`` already finished), so hand the transfer to
    # the loop thread-safely.
    gf.add_done_callback(
        lambda _: loop.call_soon_threadsafe(_grpc_async_wrapper, f, gf),
    )
    return f
133
134
def is_grpc_error_retryable(error: grpc.RpcError) -> bool:
    """
    Report whether a failed RPC looks like a transient connection error.

    An error is retryable when its status code is ``UNAVAILABLE`` and its
    details contain one of the ``RetryableGrpcErrorDetails`` messages
    (e.g. the server end closed the connection).

    Args:
        error: the RPC error raised by the failed call

    Returns:
        True if the call may be retried, False otherwise.
    """
    if error.code() != grpc.StatusCode.UNAVAILABLE:
        return False
    # details() is not guaranteed to be a str (grpc.aio's
    # AioRpcError.details() is Optional[str]); treat None as empty instead
    # of failing with TypeError on the substring test below.
    error_details = error.details() or ""
    return any(
        retryable.value in error_details
        for retryable in RetryableGrpcErrorDetails
    )
146
147
def print_grpc(
    message: proto_message.Message, print_grpc_payload: bool,
    message_header: str = "",
):
    """
    Log a gRPC message as indented JSON when payload printing is enabled.

    Nothing is logged when ``print_grpc_payload`` is falsy. Otherwise a
    non-empty ``message_header`` is logged first. It is followed by one
    record of the form ``GRPC message:\\n`` plus the message's full type
    name and JSON, with every line indented by two spaces.

    Args:
        message: protobuf message to log
        print_grpc_payload: enables/disables logging of the message
        message_header: optional line logged before the message content
    """
    if not print_grpc_payload:
        return

    raw = f"{message.DESCRIPTOR.full_name} {MessageToJson(message)}"
    indent = "  "
    # keep line endings so the indented text reproduces the original breaks
    indented = "".join(
        indent + chunk for chunk in raw.splitlines(keepends=True)
    )
    record = f"GRPC message:\n{indented}"

    if message_header:
        logging.info(message_header)
    logging.info(record)
176 logging.info(log_msg)