blob: 5fcbbf9d7fc82b042f25b12d0dca3a95dddbc492 [file] [log] [blame]
Wei-Yu Chen49950b92021-11-08 19:19:18 +08001"""
2Copyright 2020 The Magma Authors.
3
4This source code is licensed under the BSD-style license found in the
5LICENSE file in the root directory of this source tree.
6
7Unless required by applicable law or agreed to in writing, software
8distributed under the License is distributed on an "AS IS" BASIS,
9WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10See the License for the specific language governing permissions and
11limitations under the License.
12"""
13# pylint: disable=broad-except
14
15import asyncio
16import logging
17from enum import Enum
18
19import grpc
20from google.protobuf import message as proto_message
21from google.protobuf.json_format import MessageToJson
22from common.service_registry import ServiceRegistry
23from orc8r.protos import common_pb2
24
25
class RetryableGrpcErrorDetails(Enum):
    """
    Detail-message fragments that mark a gRPC failure as retryable.

    Each member's value is a piece of text that is searched for, as a
    case-sensitive substring, in the details of an RpcError
    (see is_grpc_error_retryable).
    """
    # Fragment matched when the connection socket was closed.
    SOCKET_CLOSED = "Socket closed"
    # Fragment matched when the connection attempt failed.
    CONNECT_FAILED = "Connect Failed"
32
33
def return_void(func):
    """
    Reusable decorator for returning common_pb2.Void() message.

    The decorated function is called with the caller's positional and
    keyword arguments, its own return value is discarded, and a fresh
    common_pb2.Void() is returned in its place. Exceptions raised by the
    decorated function propagate unchanged.

    Args:
        func: callable to decorate

    Returns:
        A wrapper that keeps func's name, docstring and other metadata.
    """
    # functools is only needed by this decorator, so it is imported locally.
    import functools

    # Preserve __name__/__doc__ so logs and introspection still show the
    # decorated function rather than "wrapper".
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        func(*args, **kwargs)
        return common_pb2.Void()

    return wrapper
44
45
def grpc_wrapper(func):
    """
    Wraps a function with a gRPC wrapper which creates a RPC client to
    the local service and handles any RPC Exceptions.

    The returned callable takes (args, stub_cls, service) positionally.
    It obtains a channel via
    ServiceRegistry.get_rpc_channel(service, ServiceRegistry.LOCAL),
    builds the client as stub_cls(channel) and calls func(client, args).
    The return value of func is discarded. If func raises grpc.RpcError,
    the status code and details are printed and SystemExit(1) is raised.

    Usage:
        @grpc_wrapper
        def func(client, args):
            pass
        func(args, ProtoStubClass, 'service')
    """
    # Only this decorator needs these modules, so they are imported locally.
    import functools
    import sys

    # Keep the decorated function's name/docstring on the wrapper.
    @functools.wraps(func)
    def wrapper(*alist):
        args = alist[0]
        stub_cls = alist[1]
        service = alist[2]
        chan = ServiceRegistry.get_rpc_channel(service, ServiceRegistry.LOCAL)
        client = stub_cls(chan)
        try:
            func(client, args)
        except grpc.RpcError as err:
            print("Error! [%s] %s" % (err.code(), err.details()))
            # Use sys.exit: the bare `exit` helper is injected by the
            # `site` module for interactive use and is not guaranteed to
            # exist (e.g. under `python -S`).
            sys.exit(1)

    return wrapper
71
72
def cloud_grpc_wrapper(func):
    """
    Wraps a function with a gRPC wrapper which creates a RPC client to
    the cloud service and handles any RPC Exceptions.

    The returned callable takes (args, stub_cls, service) positionally.
    It obtains a channel via
    ServiceRegistry.get_rpc_channel(service, ServiceRegistry.CLOUD),
    builds the client as stub_cls(channel) and calls func(client, args).
    The return value of func is discarded. If func raises grpc.RpcError,
    the status code and details are printed and SystemExit(1) is raised.

    Usage:
        @cloud_grpc_wrapper
        def func(client, args):
            pass
        func(args, ProtoStubClass, 'service')
    """
    # Only this decorator needs these modules, so they are imported locally.
    import functools
    import sys

    # Keep the decorated function's name/docstring on the wrapper.
    @functools.wraps(func)
    def wrapper(*alist):
        args = alist[0]
        stub_cls = alist[1]
        service = alist[2]
        chan = ServiceRegistry.get_rpc_channel(service, ServiceRegistry.CLOUD)
        client = stub_cls(chan)
        try:
            func(client, args)
        except grpc.RpcError as err:
            print("Error! [%s] %s" % (err.code(), err.details()))
            # Use sys.exit: the bare `exit` helper is injected by the
            # `site` module for interactive use and is not guaranteed to
            # exist (e.g. under `python -S`).
            sys.exit(1)

    return wrapper
98
99
def set_grpc_err(
    context: grpc.ServicerContext,
    code: grpc.StatusCode,
    details: str,
):
    """
    Record an error status on a gRPC servicer context.

    The status code is set first, then the details text with every comma
    stripped out (see https://github.com/grpc/grpc-node/issues/769).

    Args:
        context: gRPC context to update
        code: status code to set on the context
        details: details message; stored without its commas
    """
    context.set_code(code)
    comma_free_details = details.replace(',', '')
    context.set_details(comma_free_details)
111
112
def _grpc_async_wrapper(f, gf):
    """
    Copy the outcome of the finished gRPC future gf onto the asyncio
    future f. Runs on f's event loop (scheduled by grpc_async_wrapper).
    """
    if f.done():
        # The awaiter was cancelled (or f was otherwise resolved) before
        # the RPC finished. Setting a result or exception now would raise
        # InvalidStateError inside the loop callback, so drop the outcome.
        return
    try:
        result = gf.result()
    except Exception as e:
        f.set_exception(e)
    else:
        f.set_result(result)


def grpc_async_wrapper(gf, loop=None):
    """
    Wraps a GRPC result in a future that can be yielded by asyncio

    Args:
        gf: gRPC future exposing add_done_callback() and result()
        loop: event loop the returned future belongs to; defaults to
            asyncio.get_event_loop()

    Returns:
        An asyncio future bound to `loop` that receives gf's result or
        exception once gf completes.

    Usage:

        async def my_fn(param):
            result = await grpc_async_wrapper(
                stub.function_name.future(param, timeout),
            )

    Code taken and modified from:
        https://github.com/grpc/grpc/wiki/Integration-with-tornado-(python)
    """
    if loop is None:
        loop = asyncio.get_event_loop()
    # Create the future on the resolved loop so it is the same loop the
    # completion callback is scheduled on below.
    f = loop.create_future()
    # gf's callback may fire on another thread, so hand the completion
    # over to the loop with call_soon_threadsafe.
    gf.add_done_callback(
        lambda _: loop.call_soon_threadsafe(_grpc_async_wrapper, f, gf),
    )
    return f
140
141
def is_grpc_error_retryable(error: grpc.RpcError) -> bool:
    """
    Tell whether a gRPC error is worth retrying.

    An error is retryable when its status code is UNAVAILABLE and its
    details text contains the value of at least one
    RetryableGrpcErrorDetails member.

    Args:
        error: the RpcError to inspect; its code() and details() are read

    Returns:
        True if the error matches the retryable pattern, False otherwise.
    """
    code, details = error.code(), error.details()
    if code != grpc.StatusCode.UNAVAILABLE:
        return False
    for retryable in RetryableGrpcErrorDetails:
        if retryable.value in details:
            # server end closed connection.
            return True
    return False
153
154
def print_grpc(
    message: proto_message.Message, print_grpc_payload: bool,
    message_header: str = "",
):
    """
    Logs the content of a grpc message at INFO level

    The message is rendered as its descriptor's full name followed by its
    JSON form, indented by two spaces per line and prefixed with a
    "GRPC message:" line. Nothing is logged unless print_grpc_payload is
    truthy.

    Args:
        message: grpc message to log
        print_grpc_payload: flag to enable/disable logging of the message
        message_header: optional header logged before the grpc content
    """
    if not print_grpc_payload:
        return

    rendered = "{} {}".format(
        message.DESCRIPTOR.full_name,
        MessageToJson(message),
    )
    # Indent every line; splitlines(True) keeps each line's own ending.
    indent = ' ' * 2
    indented_lines = [indent + line for line in rendered.splitlines(True)]
    payload = "GRPC message:\n{}".format(''.join(indented_lines))

    if message_header:
        logging.info(message_header)
    logging.info(payload)