blob: 1d1341cd09a47c56b5d84acfa89b1d7b602edf27 [file] [log] [blame]
Chip Bolingf5af85d2019-02-12 15:36:17 -06001# Copyright 2017-present Adtran, Inc.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import sys
16import structlog
17
18from twisted.internet.defer import succeed
19from twisted.internet import threads
20
21from txzmq import ZmqEndpoint, ZmqFactory
22from txzmq.connection import ZmqConnection
23
24import zmq
25from zmq import constants
26from zmq.utils import jsonapi
27from zmq.utils.strtypes import b, u
28from zmq.auth.base import Authenticator
29
30from threading import Thread, Event
31
# Module-wide txzmq factory created at import time. Every ZmqPairConnection in
# this module is built on it, and TwistedZmqAuthenticator uses its ZeroMQ
# context (zmq_factory.context) for the inproc pipe to the auth thread.
zmq_factory = ZmqFactory()
33
34
class AdtranZmqClient(object):
    """
    Adtran ZeroMQ Client for PON Agent and/or packet in/out service.

    Wraps a ZmqPairConnection that connects to ``tcp://<ip_address>:<port>``.
    Received messages are handed to ``rx_callback``. Optional PLAIN transport
    security can be enabled with :meth:`setup_plain_security`.
    """
    def __init__(self, ip_address, rx_callback, port):
        """
        :param ip_address: address of the remote peer, used in the tcp:// endpoint
        :param rx_callback: callable invoked with each received message; if it is
                            None (or otherwise falsy), messages are discarded
        :param port: TCP port of the remote peer
        """
        self.log = structlog.get_logger()

        external_conn = 'tcp://{}:{}'.format(ip_address, port)

        self.zmq_endpoint = ZmqEndpoint('connect', external_conn)
        self._socket = ZmqPairConnection(zmq_factory, self.zmq_endpoint)
        self._socket.onReceive = rx_callback or AdtranZmqClient.rx_nop
        self.auth = None

    def send(self, data):
        """
        Send a message to the remote peer.

        Exceptions raised by the socket are logged and not propagated.

        :param data: message data (str/bytes) or a list of parts for a
                     multipart message
        """
        try:
            self._socket.send(data)

        except Exception as e:
            self.log.exception('send', e=e)

    def shutdown(self):
        """
        Stop delivering received messages and shut down the underlying connection.
        """
        # Detach the user callback first so nothing is delivered during teardown.
        self._socket.onReceive = AdtranZmqClient.rx_nop
        self._socket.shutdown()
        # NOTE(review): self.auth (if any) is not stopped here; the
        # authenticator only stops itself from its __del__ when collected.

    @property
    def socket(self):
        """The underlying ZmqPairConnection."""
        return self._socket

    @staticmethod
    def rx_nop(_):
        """Receive callback that silently discards the message."""
        pass

    def setup_plain_security(self, username, password):
        """
        Enable ZeroMQ PLAIN security for this connection.

        Starts a TwistedZmqAuthenticator. Once it has started, this registers
        ``{username: password}`` for domain '*', sets the same credentials
        as the socket's PLAIN username/password, and finally adds the
        client endpoint.

        :param username: PLAIN user name
        :param password: PLAIN password (never written to the log)
        :return: (Deferred) result of the setup chain; it errbacks if any
                 step fails
        """
        self.log.debug('setup-plain-security')

        def configure_plain(_):
            # Log only the user name; the password must never reach the logs.
            self.log.debug('plain-security', username=username)

            self.auth.configure_plain(domain='*', passwords={username: password})
            self._socket.socket.plain_username = username
            self._socket.socket.plain_password = password

        def add_endpoints(_results):
            self._socket.addEndpoints([self.zmq_endpoint])

        def config_failure(reason):
            # Record the original Failure before replacing it, otherwise the
            # root cause of the failure is lost.
            self.log.error('plain-security-config-failed', reason=reason)
            raise Exception('Failed to configure plain-text security')

        def endpoint_failure(reason):
            self.log.error('plain-security-endpoint-failed', reason=reason)
            raise Exception('Failed to complete endpoint setup')

        self.auth = TwistedZmqAuthenticator()

        d = self.auth.start()
        d.addCallbacks(configure_plain, config_failure)
        d.addCallbacks(add_endpoints, endpoint_failure)

        return d

    def setup_curve_security(self):
        """
        CURVE transport security (not implemented yet).

        :raises NotImplementedError: always
        """
        self.log.debug('setup-curve-security')
        raise NotImplementedError('TODO: curve transport security is not yet supported')
99
100
class ZmqPairConnection(ZmqConnection):
    """
    Bidirectional messages to/from the socket.

    Wrapper around a ZeroMQ PAIR socket. Incoming messages are forwarded to
    :meth:`onReceive`, which must be overridden or replaced on the instance.
    """
    socketType = constants.PAIR

    def messageReceived(self, message):
        """
        Called on incoming message from ZeroMQ.

        Forwards the message to :meth:`onReceive`.

        :param message: message data
        """
        self.onReceive(message)

    def onReceive(self, message):
        """
        Called on incoming message received from other end of the pair.

        Users override this method or assign a callable to the instance
        attribute. The default implementation raises.

        :param message: message data
        :raises NotImplementedError: unless overridden
        """
        raise NotImplementedError(self)

    def send(self, message):
        """
        Send message via ZeroMQ socket.

        Sending is performed directly to ZeroMQ without queueing. If HWM is
        reached on ZeroMQ side, sending operation is aborted with exception
        from ZeroMQ (EAGAIN).

        After writing read is scheduled as ZeroMQ may not signal incoming
        messages after we touched socket with write request.

        :param message: message data, could be either list of str (multipart
                        message) or just str
        :type message: str or list of str
        """
        from txzmq.compat import is_nonstr_iter
        from twisted.internet import reactor

        if not is_nonstr_iter(message):
            self.socket.send(message, constants.NOBLOCK)
        else:
            # send_multipart sends every part but the last with SNDMORE.
            self.socket.send_multipart(message, flags=constants.NOBLOCK)

        if self.read_scheduled is None:
            self.read_scheduled = reactor.callLater(0, self.doRead)
153
154###############################################################################################
155###############################################################################################
156###############################################################################################
157###############################################################################################
158
159
def _inherit_docstrings(cls):
    """
    Class decorator: inherit docstrings from Authenticator, so we don't duplicate them.

    Applies to every public (not ``_``-prefixed), callable attribute of
    ``cls`` that has no docstring of its own. Such an attribute receives the
    docstring of the same-named attribute on ``Authenticator``. Attributes
    that are already documented, not callable, or have no documented
    upstream counterpart are left untouched.

    :param cls: class to patch in place
    :return: ``cls`` itself
    """
    for name, method in list(vars(cls).items()):
        # Skip private names and non-callables (their __doc__ may be
        # read-only), and never overwrite an existing docstring.
        if name.startswith('_') or not callable(method) or method.__doc__:
            continue

        upstream_method = getattr(Authenticator, name, None)
        if upstream_method is None:
            continue

        upstream_doc = upstream_method.__doc__
        if upstream_doc:
            method.__doc__ = upstream_doc
    return cls
169
170
@_inherit_docstrings
class TwistedZmqAuthenticator(object):
    """
    Run ZAP authentication in a background thread but communicate via Twisted ZMQ.

    Commands (ALLOW, DENY, PLAIN, CURVE, TERMINATE) are sent as multipart
    messages over an inproc PAIR connection to a LocalAuthenticationThread,
    which owns the zmq Authenticator.
    """

    def __init__(self, encoding='utf-8'):
        """
        :param encoding: encoding used to turn text arguments (addresses,
                         domains, locations) into bytes for the pipe
        """
        self.log = structlog.get_logger()
        self.context = zmq_factory.context
        self.encoding = encoding
        self.pipe = None
        # Derived from id(self) so authenticators alive at the same time get
        # distinct inproc endpoints.
        self.pipe_endpoint = "inproc://{0}.inproc".format(id(self))
        self.thread = None

    def allow(self, *addresses):
        # Docstring is inherited from Authenticator.allow via _inherit_docstrings.
        try:
            self.pipe.send([b'ALLOW'] + [b(a, self.encoding) for a in addresses])

        except Exception as e:
            self.log.exception('allow', e=e)

    def deny(self, *addresses):
        # Docstring is inherited from Authenticator.deny via _inherit_docstrings.
        try:
            self.pipe.send([b'DENY'] + [b(a, self.encoding) for a in addresses])

        except Exception as e:
            self.log.exception('deny', e=e)

    def configure_plain(self, domain='*', passwords=None):
        # Docstring is inherited from Authenticator.configure_plain via _inherit_docstrings.
        # Passwords travel to the auth thread as a JSON-encoded frame.
        try:
            self.pipe.send([b'PLAIN', b(domain, self.encoding), jsonapi.dumps(passwords or {})])

        except Exception as e:
            self.log.exception('configure-plain', e=e)

    def configure_curve(self, domain='*', location=''):
        # Docstring is inherited from Authenticator.configure_curve via _inherit_docstrings.
        try:
            domain = b(domain, self.encoding)
            location = b(location, self.encoding)
            self.pipe.send([b'CURVE', domain, location])

        except Exception as e:
            self.log.exception('configure-curve', e=e)

    def start(self, rx_callback=AdtranZmqClient.rx_nop):
        """
        Start the authentication thread.

        Binds the inproc pipe (this side is the server; the thread connects as
        client) and creates a LocalAuthenticationThread. The thread is then
        started via deferToThread, which waits up to 10 seconds for it to
        signal that it has started.

        :param rx_callback: callable invoked with messages received on the pipe
        :return: (Deferred) fires once the thread has started. It errbacks if the
                 thread does not start in time (Python 2.7+) or if setting up
                 the pipe/thread raised.
        """
        try:
            # create a socket to communicate with auth thread.
            endpoint = ZmqEndpoint('bind', self.pipe_endpoint)  # We are server, thread will be client
            self.pipe = ZmqPairConnection(zmq_factory, endpoint)
            self.pipe.onReceive = rx_callback

            self.thread = LocalAuthenticationThread(self.context,
                                                    self.pipe_endpoint,
                                                    encoding=self.encoding)

            return threads.deferToThread(TwistedZmqAuthenticator._do_thread_start,
                                         self.thread, timeout=10)

        except Exception as e:
            self.log.exception('start', e=e)
            # Callers chain callbacks on the result, so hand back a failed
            # Deferred (built from the current exception) instead of None.
            from twisted.internet.defer import fail
            return fail()

    @staticmethod
    def _do_thread_start(thread, timeout=10):
        """
        Start ``thread`` and block until its ``started`` event is set.

        Runs in a reactor pool thread (via deferToThread).

        :raises RuntimeError: if ``started`` is not set within ``timeout``
                              seconds (Python 2.7+ only)
        """
        thread.start()

        # Event.wait:Changed in version 2.7: Previously, the method always returned None.
        if sys.version_info < (2, 7):
            thread.started.wait(timeout=timeout)

        elif not thread.started.wait(timeout=timeout):
            raise RuntimeError("Authenticator thread failed to start")

    def stop(self):
        """
        Stop the authentication thread.

        Sends TERMINATE over the pipe and shuts the pipe down. Safe to call
        repeatedly and before start(); __del__ also calls it.

        :return: (Deferred) fires after a join attempt (1 second timeout) if
                 the thread was alive, otherwise fires immediately with 'done'
        """
        pipe, self.pipe = self.pipe, None
        thread, self.thread = self.thread, None

        if pipe:
            pipe.send(b'TERMINATE')
            pipe.onReceive = AdtranZmqClient.rx_nop
            pipe.shutdown()

        # thread is None after a previous stop() or when start() never ran.
        if thread is not None and thread.is_alive():
            return threads.deferToThread(TwistedZmqAuthenticator._do_thread_join,
                                         thread)
        return succeed('done')

    @staticmethod
    def _do_thread_join(thread, timeout=1):
        """Join ``thread``, waiting at most ``timeout`` seconds."""
        thread.join(timeout)

    def is_alive(self):
        """Is the ZAP thread currently running?"""
        return self.thread is not None and self.thread.is_alive()

    def __del__(self):
        # Best-effort cleanup when the authenticator is garbage-collected.
        self.stop()
269
270
# NOTE: The following is duplicated from zmq code since the class was not exported
class LocalAuthenticationThread(Thread):
    """A Thread for running a zmq Authenticator

    This is run in the background by TwistedZmqAuthenticator. It polls the
    command pipe (connected to ``endpoint``) and the authenticator's ZAP
    socket, dispatching to _handle_pipe / _handle_zap. It stops on a
    TERMINATE command or when polling raises ZMQError.
    """

    def __init__(self, context, endpoint, encoding='utf-8', authenticator=None):
        """
        :param context: zmq Context for the pipe and authenticator; if None,
                        zmq.Context.instance() is used
        :param endpoint: endpoint of the command pipe to connect to
        :param encoding: encoding used to decode text arguments of commands
        :param authenticator: Authenticator to run; if None, a new one is
                              created on the context
        """
        super(LocalAuthenticationThread, self).__init__(name='0mq Authenticator')
        self.log = structlog.get_logger()
        self.context = context or zmq.Context.instance()
        self.encoding = encoding
        self.started = Event()
        # Use self.context (not the raw argument) so a None context falls back
        # to the global instance everywhere instead of crashing below.
        self.authenticator = authenticator or Authenticator(self.context, encoding=encoding)

        # create a socket to communicate back to main thread.
        self.pipe = self.context.socket(zmq.PAIR)
        self.pipe.linger = 1
        self.pipe.connect(endpoint)

    def run(self):
        """Start the Authentication Agent thread task"""
        try:
            self.authenticator.start()
            # Signals _do_thread_start that the authenticator is up.
            self.started.set()
            self._poll_loop()

        except Exception as e:
            self.log.exception("run", e=e)

        finally:
            # Release the pipe and authenticator even if the loop raised.
            self._close()

    def _poll_loop(self):
        """Dispatch pipe commands and ZAP requests until TERMINATE or a poll error."""
        zap = self.authenticator.zap_socket
        poller = zmq.Poller()
        poller.register(self.pipe, zmq.POLLIN)
        poller.register(zap, zmq.POLLIN)
        while True:
            try:
                socks = dict(poller.poll())
            except zmq.ZMQError:
                break  # interrupted

            if self.pipe in socks and socks[self.pipe] == zmq.POLLIN:
                if self._handle_pipe():
                    break

            if zap in socks and socks[zap] == zmq.POLLIN:
                self._handle_zap()

    def _close(self):
        """Close the command pipe, then stop the authenticator (errors are logged)."""
        try:
            self.pipe.close()
            self.authenticator.stop()
        except Exception as e:
            self.log.exception("close", e=e)

    def _handle_zap(self):
        """
        Handle a message from the ZAP socket.
        """
        msg = self.authenticator.zap_socket.recv_multipart()
        if not msg:
            return
        self.authenticator.handle_zap_message(msg)

    def _handle_pipe(self):
        """
        Handle a message from front-end API.

        Supported commands: ALLOW/DENY <address>..., PLAIN <domain> <json
        passwords>, CURVE <domain> <location>, TERMINATE. A failure while
        applying a command is logged and does not stop the thread.

        :return: (bool) True if the thread should terminate
        """
        # Get the whole message off the pipe in one go
        msg = self.pipe.recv_multipart()

        if msg is None:
            return True

        command = msg[0]
        self.log.debug("auth received API command", command=command)

        if command == b'ALLOW':
            addresses = [u(m, self.encoding) for m in msg[1:]]
            try:
                self.authenticator.allow(*addresses)
            except Exception as e:
                self.log.exception("Failed to allow", addresses=addresses, e=e)

        elif command == b'DENY':
            addresses = [u(m, self.encoding) for m in msg[1:]]
            try:
                self.authenticator.deny(*addresses)
            except Exception as e:
                self.log.exception("Failed to deny", addresses=addresses, e=e)

        elif command == b'PLAIN':
            try:
                domain = u(msg[1], self.encoding)
                json_passwords = msg[2]
                self.authenticator.configure_plain(domain, jsonapi.loads(json_passwords))
            except Exception as e:
                # Passwords are deliberately left out of the log.
                self.log.exception("Failed to configure plain", e=e)

        elif command == b'CURVE':
            try:
                # For now we don't do anything with domains
                domain = u(msg[1], self.encoding)

                # If location is CURVE_ALLOW_ANY, allow all clients. Otherwise
                # treat location as a directory that holds the certificates.
                location = u(msg[2], self.encoding)
                self.authenticator.configure_curve(domain, location)
            except Exception as e:
                self.log.exception("Failed to configure curve", e=e)

        elif command == b'TERMINATE':
            return True

        else:
            self.log.error("Invalid auth command from API", command=command)

        return False