This commit consists of:
1) Add session management to netconf
2) Modularize the rpc call
3) Improve the error handling
4) Small bug fixes
Change-Id: I023edb76e3743b633ac87be4967d656e09e2b970
diff --git a/netconf/session/__init__.py b/netconf/session/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/netconf/session/__init__.py
diff --git a/netconf/session/nc_connection.py b/netconf/session/nc_connection.py
new file mode 100644
index 0000000..d8a2afe
--- /dev/null
+++ b/netconf/session/nc_connection.py
@@ -0,0 +1,129 @@
+#!/usr/bin/env python
+#
+# Copyright 2016 the original author or authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import structlog
+from hexdump import hexdump
+from twisted.internet import protocol
+from twisted.internet.defer import inlineCallbacks, returnValue
+from common.utils.message_queue import MessageQueue
+from netconf.constants import Constants as C
+
+log = structlog.get_logger()
+
+from netconf import MAXSSHBUF
+
+
class NetconfConnection(protocol.Protocol):
    """Twisted protocol that carries NETCONF messages for one client.

    Raw reads from the transport are queued on ``self.rx``; the protocol
    handler pulls them via ``receive_msg_any``. Outgoing messages are framed
    for NETCONF 1.0 (end-of-message delimiter) or 1.1 (chunked framing,
    RFC 6242) and written in slices of at most ``max_chunk - 64`` bytes.
    """

    def __init__(self, data=None, avatar=None, max_chunk=MAXSSHBUF):
        """
        :param data: unused.
        :param avatar: object providing ``get_nc_server()``. It is required
            even though it defaults to None, because it is dereferenced
            immediately.
        :param max_chunk: upper bound used to size transport writes.
        """
        self.avatar = avatar
        self.nc_server = self.avatar.get_nc_server()
        self.rx = MessageQueue()
        self.max_chunk = max_chunk
        self.connected = True
        # Attached later (presumably by the server once the client is
        # registered -- verify); it can still be None if the peer drops early.
        self.proto_handler = None
        # True once we initiated the close ourselves (see close_connection).
        self.exiting = False

    def connectionLost(self, reason):
        """Mark the connection down; stop the handler unless we closed it."""
        log.info('connection-lost')
        self.connected = False
        if not self.exiting and self.proto_handler is not None:
            self.proto_handler.stop('Connection-Lost')

    def connectionMade(self):
        """Register this connection with the NETCONF server."""
        log.info('connection-made')
        self.nc_server.client_connected(self)

    def dataReceived(self, data):
        """Queue one raw, non-empty read from the transport."""
        log.debug('data-received', len=len(data),
                  received=hexdump(data, result='return'))
        assert len(data)
        self.rx.put(data)

    def processEnded(self, reason=None):
        """Mark the connection down when the backing process ends."""
        log.info('process-ended', reason=reason)
        self.connected = False

    def chunkit(self, msg, maxsend):
        """Yield consecutive slices of ``msg`` of at most ``maxsend`` items.

        No empty trailing slice is produced, and an empty ``msg`` yields
        nothing.

        :raises ValueError: if ``maxsend`` is not positive.
        """
        if maxsend <= 0:
            raise ValueError(
                'maxsend must be positive, got {}'.format(maxsend))
        for start in range(0, len(msg), maxsend):
            yield msg[start:start + maxsend]

    @staticmethod
    def _to_bytes(text):
        """Return ``text`` UTF-8 encoded, unless it is already bytes."""
        if isinstance(text, bytes):
            return text
        return text.encode('utf-8')

    def send_msg(self, msg, new_framing):
        """Frame ``msg`` and write it to the transport.

        :param msg: message text, or already-encoded bytes.
        :param new_framing: True for NETCONF 1.1 chunked framing, False for
            1.0 end-of-message delimiter framing.
        """
        assert self.connected
        data = self._to_bytes(msg)
        if new_framing:
            # RFC 6242 chunk-size counts octets, so measure the encoded data.
            # NOTE: the leading LF required before '#' is not emitted here.
            # It is supplied by the LF written after every message, including
            # the 1.0-framed hello, which keeps the existing wire format.
            data = (b'#' + str(len(data)).encode('ascii') + b'\n' +
                    data + b'\n##\n')
        else:
            data += self._to_bytes(C.DELIMITER)
        # One trailing LF per message (the previous code wrote one per chunk,
        # which injected LFs into the middle of multi-chunk messages).
        data += b'\n'
        framing = "1.1" if new_framing else "1.0"
        # Apparently ssh has a bug that requires minimum of 64 bytes?
        # This may not be sufficient to fix this.
        for chunk in self.chunkit(data, self.max_chunk - 64):
            log.info('sending', chunk=chunk, framing=framing)
            # Slices are written verbatim; the payload must not be altered.
            self.transport.write(chunk)

    @inlineCallbacks
    def receive_msg_any(self, new_framing):
        """Wait for the next queued read and strip its NETCONF framing.

        NOTE(review): each queued read appears to be treated as one complete
        message. Messages split across, or coalesced within, transport reads
        are not reassembled.
        """
        assert self.connected
        msg = yield self.recv(lambda _: True)
        if new_framing:
            returnValue(self._receive_11(msg))
        else:
            returnValue(self._receive_10(msg))

    def _receive_10(self, msg):
        """Return ``msg`` truncated at the 1.0 end-of-message delimiter.

        If no delimiter is present, an error is logged and ``msg`` is
        returned unchanged.
        """
        eomidx = msg.find(C.DELIMITER)
        if eomidx == -1:
            log.error('no-message-end-indicators', msg=msg)
            return msg
        log.info('received-msg', msg=msg[:eomidx])
        return msg[:eomidx]

    def _receive_11(self, msg):
        """Decode a NETCONF 1.1 chunked-framed message (RFC 6242 sec. 4.2).

        The expected input is one or more chunks (LF '#' <size> LF <data>)
        followed by LF '##' LF. The chunk data is concatenated verbatim, so
        line breaks inside the message are preserved. A missing leading LF
        is tolerated. If the framing cannot be parsed, the older line-based
        heuristic is used as a fallback.

        :return: the decoded message, or None for empty input or input the
            fallback cannot handle.
        """
        if not msg:
            return None
        log.info('received-msg-full', msg=msg)
        decoded = self._decode_chunks(msg)
        if decoded is None:
            log.error('malformed-chunked-framing', msg=msg)
            decoded = self._legacy_receive_11(msg)
            if decoded is None:
                return None
        log.info('parsed-msg\n', msg=decoded)
        return decoded

    @staticmethod
    def _decode_chunks(msg):
        """Return the concatenated chunk data of ``msg``, or None if malformed."""
        if msg.startswith('#'):
            # Tolerate a peer that omits the LF in front of the first chunk.
            msg = '\n' + msg
        parts = []
        pos = 0
        while not msg.startswith('\n##\n', pos):
            if not msg.startswith('\n#', pos):
                return None
            header_end = msg.find('\n', pos + 2)
            if header_end == -1:
                return None
            size = msg[pos + 2:header_end]
            if not size.isdigit():
                return None
            start = header_end + 1
            end = start + int(size)
            if end > len(msg):
                # Chunk truncated (e.g. the message arrived in several reads).
                return None
            parts.append(msg[start:end])
            pos = end
        return ''.join(parts)

    @staticmethod
    def _legacy_receive_11(msg):
        """Fallback: drop the first two and last two lines and join the rest."""
        lines = msg.split('\n')
        if len(lines) > 2:
            return ''.join(lines[2:len(lines) - 2])
        return None

    def close_connection(self):
        """Close the transport, marking the close as initiated locally."""
        log.info('closing-connection')
        self.exiting = True
        self.transport.loseConnection()

    def recv(self, predicate):
        """Return ``self.rx.get(predicate)`` (presumably a Deferred)."""
        assert self.connected
        return self.rx.get(predicate)

    def recv_any(self, new_framing):
        """Wait for any queued item; ``new_framing`` is ignored."""
        return self.recv(lambda _: True)

    def recv_xid(self, xid):
        """Wait for a queued item whose ``xid`` attribute equals ``xid``.

        NOTE(review): dataReceived queues raw data, which has no ``xid``
        attribute. Confirm whether this is still used.
        """
        return self.recv(lambda msg: msg.xid == xid)
diff --git a/netconf/session/nc_protocol_handler.py b/netconf/session/nc_protocol_handler.py
new file mode 100644
index 0000000..e0a4baf
--- /dev/null
+++ b/netconf/session/nc_protocol_handler.py
@@ -0,0 +1,246 @@
+#!/usr/bin/env python
+#
+# Copyright 2016 the original author or authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import structlog
+import io
+from lxml import etree
+from lxml.builder import E
+import netconf.nc_common.error as ncerror
+from netconf import NSMAP, qmap
+from utils import elm
+from twisted.internet.defer import inlineCallbacks, returnValue, Deferred
+from capabilities import Capabilities
+from netconf.nc_rpc.rpc_factory import get_rpc_factory_instance
+from netconf.constants import Constants as C
+
+log = structlog.get_logger()
+
+
class NetconfProtocolError(Exception):
    """Exception type for NETCONF protocol-level failures."""
+
+
class NetconfProtocolHandler:
    """Drive one NETCONF session over a connection object.

    Performs the hello/capability exchange and picks the framing version.
    It then reads requests, dispatches each <rpc> to a handler obtained from
    the rpc factory, and sends back <rpc-reply> or error messages.
    """

    def __init__(self, nc_server, nc_conn, session, grpc_stub):
        self.started = True
        self.conn = nc_conn
        self.nc_server = nc_server
        self.grpc_stub = grpc_stub
        # False -> NETCONF 1.0 end-of-message framing; True -> 1.1 chunked.
        self.new_framing = False
        self.capabilities = Capabilities()
        self.session = session
        # Guards teardown so it runs only once.
        self.exiting = False
        # Fired on teardown; calls
        # nc_server.client_disconnected(None, self, None).
        self.connected = Deferred()
        self.connected.addCallback(self.nc_server.client_disconnected,
                                   self, None)

    def send_message(self, msg):
        """Send ``msg`` prefixed with the XML header, using current framing."""
        self.conn.send_msg(C.XML_HEADER + msg, self.new_framing)

    def receive_message(self):
        """Return the connection's next message (with framing stripped)."""
        return self.conn.receive_msg_any(self.new_framing)

    def send_hello(self, caplist, session=None):
        """Send a <hello> listing ``caplist`` and, optionally, the session id."""
        msg = elm(C.HELLO, attrib={C.XMLNS: NSMAP[C.NC]})
        caps = E.capabilities(*[E.capability(x) for x in caplist])
        msg.append(caps)

        if session is not None:
            msg.append(E(C.SESSION_ID, str(session.session_id)))
        msg = etree.tostring(msg)
        log.info("Sending HELLO", msg=msg)
        msg = msg.decode('utf-8')
        self.send_message(msg)

    def send_rpc_reply(self, rpc_reply, origmsg):
        """Wrap ``rpc_reply`` in <rpc-reply> echoing ``origmsg``'s attributes.

        ``rpc_reply`` may be a single element, or an iterable of elements
        (anything without ``getchildren``).
        """
        reply = etree.Element(qmap(C.NC) + C.RPC_REPLY, attrib=origmsg.attrib,
                              nsmap=origmsg.nsmap)
        try:
            rpc_reply.getchildren
            reply.append(rpc_reply)
        except AttributeError:
            reply.extend(rpc_reply)
        ucode = etree.tounicode(reply, pretty_print=True)
        log.info("RPC-Reply", reply=ucode)
        self.send_message(ucode)

    def set_framing_version(self):
        """Select the framing from the client's advertised base capabilities.

        :raises NetconfProtocolError: if the client advertises neither
            base:1.0 nor base:1.1.
        """
        if C.NETCONF_BASE_11 in self.capabilities.client_caps:
            self.new_framing = True
        elif C.NETCONF_BASE_10 not in self.capabilities.client_caps:
            raise NetconfProtocolError(
                "Client doesn't implement 1.0 or 1.1 of netconf")

    @inlineCallbacks
    def open_session(self):
        """Exchange hellos, record client capabilities and choose framing.

        On any failure the session is stopped and the exception re-raised.
        """
        # The transport should be connected at this point.
        try:
            self.send_hello(self.capabilities.server_caps, self.session)
            reply = yield self.receive_message()
            log.info("reply-received", reply=reply)

            tree = etree.parse(io.BytesIO(reply.encode('utf-8')))
            root = tree.getroot()
            caps = root.xpath(C.CAPABILITY_XPATH, namespaces=NSMAP)
            for cap in caps:
                self.capabilities.add_client_capability(cap.text)

            self.set_framing_version()
            self.session.session_opened = True

            log.info('session-opened', session_id=self.session.session_id,
                     framing="1.1" if self.new_framing else "1.0")
        except Exception as e:
            log.error('hello-failure', exception=repr(e))
            self.stop(repr(e))
            raise

    @inlineCallbacks
    def start(self):
        """Open the session, then serve requests until it is closed.

        Any exception stops the session. The Deferred fires with ``self``.
        """
        log.info('starting')
        try:
            yield self.open_session()
            while self.session.session_opened:
                msg = yield self.receive_message()
                yield self.handle_request(msg)
        except Exception as e:
            log.exception('exception', exception=repr(e))
            self.stop(repr(e))

        log.info('shutdown')
        returnValue(self)

    @inlineCallbacks
    def handle_request(self, msg):
        """Parse one client message and execute every <rpc> it contains.

        Malformed XML closes the session. Per-rpc errors are answered with
        an error reply, except a BadMsg on a 1.0 session, which closes it.

        :raises ncerror.SessionError: if the message contains no <rpc>.
        """
        if not self.session.session_opened:
            return

        # Any error with XML encoding here is going to cause a session close
        try:
            tree = etree.parse(io.BytesIO(msg.encode('utf-8')))
            if not tree:
                raise ncerror.SessionError(msg, "Invalid XML from client.")
        except etree.XMLSyntaxError:
            log.error("malformed-message", msg=msg)
            try:
                error = ncerror.BadMsg(msg)
                self.send_message(error.get_reply_msg())
            except AttributeError:
                log.error("attribute-error", msg=msg)
            self.close()
            return

        rpcs = tree.xpath(C.RPC_XPATH, namespaces=NSMAP)
        if not rpcs:
            raise ncerror.SessionError(msg, "No rpc found")

        # A message can have multiple rpc requests
        rpc_factory = get_rpc_factory_instance()
        for rpc in rpcs:
            # An earlier rpc in this message may have closed the session.
            if not self.session.session_opened:
                break
            try:
                # Element.get() returns None for a missing attribute; it
                # does not raise.
                msg_id = rpc.get(C.MESSAGE_ID)
                if msg_id is None:
                    log.error('no-message-id', rpc=rpc)
                    raise ncerror.MissingElement(msg, C.MESSAGE_ID)
                log.info("Received-rpc-message-id", msg_id=msg_id)

                rpc_handler = rpc_factory.get_rpc_handler(rpc,
                                                          msg,
                                                          self.session)
                if rpc_handler:
                    response = yield rpc_handler.execute()
                    log.info('handler',
                             rpc_handler=rpc_handler,
                             is_error=response.is_error,
                             response=response)
                    self.send_rpc_reply(response.node, rpc)
                    if response.close_session:
                        log.info('response-closing-session',
                                 response=response)
                        self.close()
                else:
                    log.error('no-rpc-handler',
                              request=msg,
                              session_id=self.session.session_id)
                    raise ncerror.NotImpl(msg)

            except ncerror.BadMsg as err:
                log.info('ncerror.BadMsg')
                if self.new_framing:
                    self.send_message(err.get_reply_msg())
                else:
                    # If we are 1.0 we have to simply close the connection
                    # as we are not allowed to send this error
                    log.error("Closing-1-0-session--malformed-message")
                    self.close()
            except (ncerror.NotImpl, ncerror.MissingElement) as e:
                # structlog takes the event plus keyword args only; an extra
                # positional argument would raise inside this handler.
                log.info('error', exception=repr(e))
                self.send_message(e.get_reply_msg())
            except Exception as ex:
                log.info('Exception', exception=repr(ex))
                error = ncerror.ServerException(rpc, ex)
                self.send_message(error.get_reply_msg())

    def stop(self, reason):
        """Tear the session down after an error or connection loss.

        :param reason: description of why the session stops (logged).
        """
        self._teardown('stopping', 'stopped', reason=reason)

    def close(self):
        """Close the session (client request or fatal request error)."""
        self._teardown('closing-client', 'closing-client')

    def _teardown(self, start_event, end_event, **log_kw):
        """Release the session once; later calls are no-ops.

        Closes the transport if it is still up, deregisters an opened
        session from the session manager and fires ``self.connected``.
        """
        if self.exiting:
            return
        log.debug(start_event, **log_kw)
        self.exiting = True
        # Drop the transport even if the hello exchange never completed, so
        # a failed handshake does not leave the connection open.
        # TODO: send a closing message to the far end
        if getattr(self.conn, 'connected', True):
            self.conn.close_connection()
        if self.session.session_opened:
            self.nc_server.session_mgr.remove_session(self.session)
            self.session.session_opened = False
        self.connected.callback(None)
        log.info(end_event)
diff --git a/netconf/session/session.py b/netconf/session/session.py
new file mode 100644
index 0000000..51979f7
--- /dev/null
+++ b/netconf/session/session.py
@@ -0,0 +1,28 @@
+#!/usr/bin/env python
+#
+# Copyright 2016 the original author or authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from time import time
+import structlog
+
+log = structlog.get_logger()
+
+
class Session:
    """Book-keeping record for a single NETCONF session.

    Attributes:
        session_id: identifier supplied by the creator.
        user: user the session belongs to.
        started_at: creation timestamp from ``time()``.
        session_opened: False until the session is marked open.
    """

    def __init__(self, session_id, user):
        self.started_at = time()
        self.session_opened = False
        self.session_id, self.user = session_id, user
diff --git a/netconf/session/session_mgr.py b/netconf/session/session_mgr.py
new file mode 100644
index 0000000..ee4382f
--- /dev/null
+++ b/netconf/session/session_mgr.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python
+#
+# Copyright 2016 the original author or authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from session import Session
+import structlog
+
+log = structlog.get_logger()
+
class SessionManager:
    """Allocate NETCONF session ids and track the registered sessions.

    Use ``get_session_manager_instance()`` to obtain the shared instance.
    """

    # Shared instance, created lazily by get_session_manager_instance().
    instance = None

    def __init__(self):
        self.next_session_id = 1
        # session_id -> Session
        self.sessions = {}

    def create_session(self, user):
        """Create, register and return a Session with the next free id."""
        session = Session(self.next_session_id, user)
        self.sessions[session.session_id] = session
        self.next_session_id += 1
        return session

    def remove_session(self, session):
        """Deregister ``session``; log an error if it is not registered."""
        session_id = session.session_id
        # Direct dict membership; .keys() builds a list in Python 2.
        if session_id in self.sessions:
            del self.sessions[session_id]
            log.info('remove-session', session_id=session_id)
        else:
            log.error('invalid-session', session_id=session_id)
+
+
def get_session_manager_instance():
    """Return the shared SessionManager, creating it on first use."""
    if SessionManager.instance is None:
        SessionManager.instance = SessionManager()
    return SessionManager.instance
\ No newline at end of file