blob: 99cff85fad9ba62d0ae8240fa9fab19608f5651d [file] [log] [blame]
Khen Nursimulu3869d8d2016-11-28 20:44:28 -05001#
2# Copyright 2016 the original author or authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16
17import structlog
18import sys
19from twisted.conch import avatar
20from twisted.cred import portal
21from twisted.conch.checkers import SSHPublicKeyChecker, InMemorySSHKeyDB
22from twisted.conch.ssh import factory, userauth, connection, keys, session
23from twisted.conch.ssh.transport import SSHServerTransport
24
25from twisted.cred.checkers import FilePasswordDB
26from twisted.internet import reactor
27from twisted.internet.defer import Deferred, inlineCallbacks
28# from twisted.python import log as logp
29from zope.interface import implementer
30from nc_protocol_handler import NetconfProtocolHandler
31
32from nc_connection import NetconfConnection
33
34# logp.startLogging(sys.stderr)
35
# Module-wide structlog logger used by every class in this file.
log = structlog.get_logger()

# Secure credentials directories. These are relative paths, so they are
# resolved against the process's current working directory.
# TODO: In a production environment these locations require better
# protection. For now the user_passwords file is just a plain text file.
KEYS_DIRECTORY = 'security/keys'  # server host key files
CERTS_DIRECTORY = 'security/certificates'
CLIENT_CRED_DIRECTORY = 'security/client_credentials'  # client passwords and public keys
44
45
class NetconfAvatar(avatar.ConchUser):
    """SSH avatar handed out for an authenticated NETCONF client.

    Registers the SSH ``session`` channel type and maps the ``netconf``
    subsystem name to NetconfConnection.
    """

    def __init__(self, username, nc_server, grpc_stub):
        """Create the avatar.

        :param username: avatar id produced by the cred portal
        :param nc_server: the NCServer instance that built the realm
        :param grpc_stub: gRPC stub exposed through get_grpc_stub()
        """
        avatar.ConchUser.__init__(self)
        self.username = username
        self.nc_server = nc_server
        self.grpc_stub = grpc_stub
        # Twisted looks channel types and subsystem names up by the raw bytes
        # received from the client, so both keys are bytes literals. On
        # Python 2 a bytes literal is a plain str, so nothing changes there.
        self.channelLookup.update({b'session': session.SSHSession})
        self.subsystemLookup.update({b'netconf': NetconfConnection})

    def get_grpc_stub(self):
        """Return the gRPC stub this avatar was created with."""
        return self.grpc_stub

    def get_nc_server(self):
        """Return the NCServer this avatar was created with."""
        return self.nc_server

    def logout(self):
        """Logout hook returned to the portal by NetconfRealm.requestAvatar."""
        log.info('netconf-avatar-logout', username=self.username)
65
66
@implementer(portal.IRealm)
class NetconfRealm(object):
    """twisted.cred realm that creates a NetconfAvatar per authenticated id."""

    def __init__(self, nc_server, grpc_stub):
        # Every avatar produced by this realm shares the same server and stub.
        self.nc_server = nc_server
        self.grpc_stub = grpc_stub

    def requestAvatar(self, avatarId, mind, *interfaces):
        """Return the ``(interface, avatar, logout)`` triple for ``avatarId``.

        ``mind`` is ignored, and the first requested interface is echoed
        back unchanged.
        """
        conch_user = NetconfAvatar(avatarId, self.nc_server, self.grpc_stub)
        requested_iface = interfaces[0]
        return requested_iface, conch_user, conch_user.logout
76
77
class NCServer(factory.SSHFactory):
    """SSH factory that serves NETCONF sessions to authenticated clients.

    Clients may log in with a user id and password (checked by
    FilePasswordDB) or with a public key (checked by SSHPublicKeyChecker).
    Both checkers are registered on the portal built by setup_secure_access().
    """

    # SSH services offered to clients. Twisted looks services up by the name
    # the client sends, which is bytes on Python 3. On Python 2 a bytes
    # literal is a plain str, so nothing changes there.
    services = {
        b'ssh-userauth': userauth.SSHUserAuthServer,
        b'ssh-connection': connection.SSHConnection
    }

    def __init__(self,
                 netconf_port,
                 server_private_key_file,
                 server_public_key_file,
                 client_public_keys_file,
                 client_passwords_file,
                 grpc_stub):
        """Store the server configuration; nothing is opened until start().

        :param netconf_port: TCP port the SSH server listens on
        :param server_private_key_file: host private key file name, relative
            to KEYS_DIRECTORY
        :param server_public_key_file: host public key file name, relative
            to KEYS_DIRECTORY
        :param client_public_keys_file: file name, relative to
            CLIENT_CRED_DIRECTORY, listing one ``username:key_file`` entry
            per line (key files are also relative to CLIENT_CRED_DIRECTORY)
        :param client_passwords_file: file name, relative to
            CLIENT_CRED_DIRECTORY, read by twisted's FilePasswordDB
        :param grpc_stub: gRPC stub passed to each NetconfProtocolHandler
        """
        self.netconf_port = netconf_port
        self.server_private_key_file = server_private_key_file
        self.server_public_key_file = server_public_key_file
        self.client_public_keys_file = client_public_keys_file
        self.client_passwords_file = client_passwords_file
        self.grpc_stub = grpc_stub
        # Listening port returned by reactor.listenTCP (None until listening).
        self.connector = None
        # Cred portal used to authenticate clients (set when the server starts).
        self.portal = None
        # Fired by stop(); created once the server is listening.
        self.d_stopped = None
        self.nc_client_map = {}
        self.running = False
        self.exiting = False

    def start(self):
        """Schedule the SSH server to start on the next reactor iteration.

        Calling start() on a server that is already running does nothing.

        :return: self, in both cases, so calls can be chained
        """
        log.debug('starting')
        if self.running:
            return self
        self.running = True
        reactor.callLater(0, self.start_ssh_server)
        log.info('started')
        return self

    def stop(self):
        """Stop accepting connections and fire ``d_stopped``.

        This is safe to call before the server has finished starting, and
        safe to call more than once.
        """
        log.debug('stopping')
        self.exiting = True
        if self.connector is not None:
            # listenTCP returns an IListeningPort. Its shutdown method is
            # stopListening(); it has no disconnect() method.
            self.connector.stopListening()
            self.connector = None
        if self.d_stopped is not None and not self.d_stopped.called:
            self.d_stopped.callback(None)
        log.info('stopped')

    def client_disconnected(self, result, handler, reason):
        """Log a client disconnect and close its protocol handler."""
        assert isinstance(handler, NetconfProtocolHandler)

        log.info('client-disconnected', reason=reason)

        # For now just nullify the handler
        handler.close()

    def client_connected(self, client_conn):
        """Attach a NetconfProtocolHandler to a new NetconfConnection.

        The handler is started on the next reactor iteration.
        """
        assert isinstance(client_conn, NetconfConnection)
        log.info('client-connected')
        handler = NetconfProtocolHandler(self, client_conn,
                                         self.grpc_stub)
        client_conn.proto_handler = handler
        reactor.callLater(0, handler.start)

    def setup_secure_access(self):
        """Build the cred portal used to authenticate SSH clients.

        Registers a FilePasswordDB checker for password logins and an
        SSHPublicKeyChecker for public-key logins.

        :return: the configured twisted.cred Portal
        :raises Exception: re-raised after logging if a credentials file
            cannot be read or parsed, so the server does not start
            listening without a working portal
        """
        try:
            nc_portal = portal.Portal(NetconfRealm(self, self.grpc_stub))

            # setup userid-password access
            password_file = '{}/{}'.format(CLIENT_CRED_DIRECTORY,
                                           self.client_passwords_file)
            nc_portal.registerChecker(FilePasswordDB(password_file))

            # setup access when client uses keys
            keys_file = '{}/{}'.format(CLIENT_CRED_DIRECTORY,
                                       self.client_public_keys_file)
            users_dict = self._load_client_keys(keys_file)
            nc_portal.registerChecker(
                SSHPublicKeyChecker(InMemorySSHKeyDB(users_dict)))
            return nc_portal
        except Exception as e:
            log.error('setup-secure-access-fail', exception=repr(e))
            raise

    @staticmethod
    def _load_client_keys(keys_file):
        """Parse a ``username:key_file`` list into ``{username: [Key]}``.

        Blank lines are skipped. Surrounding whitespace, including a
        trailing CR from CRLF files, is ignored. Everything after the first
        ':' is taken as the key file name.

        :raises ValueError: if a non-blank line has no ':' separator
        """
        users_dict = {}
        with open(keys_file) as f:
            for line in f:
                entry = line.strip()
                if not entry:
                    continue
                username, sep, key_file = entry.partition(':')
                if not sep:
                    raise ValueError(
                        'malformed client key entry: {!r}'.format(entry))
                users_dict[username] = [
                    keys.Key.fromFile('{}/{}'.format(CLIENT_CRED_DIRECTORY,
                                                     key_file))]
        return users_dict

    @inlineCallbacks
    def start_ssh_server(self):
        """Set up authentication, listen on netconf_port, then wait for stop().

        Errors are logged rather than raised, because this runs from
        reactor.callLater and no caller is waiting to handle them.
        """
        if self.exiting:
            # stop() ran before this scheduled start; do not start listening.
            return
        try:
            log.debug('starting', port=self.netconf_port)
            self.portal = self.setup_secure_access()
            self.connector = reactor.listenTCP(self.netconf_port, self)
            log.debug('started', port=self.netconf_port)
            # Left unfired until stop() so it tracks the server's lifetime.
            # Firing it here would make stop()'s callback raise
            # AlreadyCalledError.
            self.d_stopped = Deferred()
            yield self.d_stopped
        except Exception as e:
            log.error('netconf-server-not-started', port=self.netconf_port,
                      exception=repr(e))

    # Methods from SSHFactory
    #

    def protocol(self):
        """Create the SSH server transport for a new connection."""
        return SSHServerTransport()

    def getPublicKeys(self):
        """Return the server public host key as ``{b'ssh-rsa': Key}``.

        Returns None (after logging) if the key file cannot be loaded.
        """
        return self._load_server_key(self.server_public_key_file,
                                     'cannot-retrieve-server-public-key')

    def getPrivateKeys(self):
        """Return the server private host key as ``{b'ssh-rsa': Key}``.

        Returns None (after logging) if the key file cannot be loaded.
        """
        return self._load_server_key(self.server_private_key_file,
                                     'cannot-retrieve-server-private-key')

    @staticmethod
    def _load_server_key(file_name, error_event):
        """Load a host key from KEYS_DIRECTORY under the ssh-rsa name.

        :param file_name: key file name relative to KEYS_DIRECTORY
        :param error_event: structlog event name logged on failure
        :return: ``{b'ssh-rsa': Key}``, or None if the key cannot be loaded
        """
        key_file_name = '{}/{}'.format(KEYS_DIRECTORY, file_name)
        try:
            return {b'ssh-rsa': keys.Key.fromFile(key_file_name)}
        except Exception as e:
            log.error(error_event, filename=key_file_name, exception=repr(e))
            return None