aboutsummaryrefslogtreecommitdiffstats
path: root/src/libcharon/plugins/vici/python
diff options
context:
space:
mode:
Diffstat (limited to 'src/libcharon/plugins/vici/python')
-rw-r--r--src/libcharon/plugins/vici/python/.gitignore5
-rw-r--r--src/libcharon/plugins/vici/python/LICENSE19
-rw-r--r--src/libcharon/plugins/vici/python/MANIFEST.in1
-rw-r--r--src/libcharon/plugins/vici/python/Makefile.am33
-rw-r--r--src/libcharon/plugins/vici/python/setup.py.in34
-rw-r--r--src/libcharon/plugins/vici/python/vici/__init__.py1
-rw-r--r--src/libcharon/plugins/vici/python/vici/compat.py14
-rw-r--r--src/libcharon/plugins/vici/python/vici/exception.py10
-rw-r--r--src/libcharon/plugins/vici/python/vici/protocol.py196
-rw-r--r--src/libcharon/plugins/vici/python/vici/session.py327
-rw-r--r--src/libcharon/plugins/vici/python/vici/test/__init__.py0
-rw-r--r--src/libcharon/plugins/vici/python/vici/test/test_protocol.py144
12 files changed, 784 insertions, 0 deletions
diff --git a/src/libcharon/plugins/vici/python/.gitignore b/src/libcharon/plugins/vici/python/.gitignore
new file mode 100644
index 000000000..5c4589841
--- /dev/null
+++ b/src/libcharon/plugins/vici/python/.gitignore
@@ -0,0 +1,5 @@
+*.pyc
+build
+dist
+vici.egg-info
+setup.py
diff --git a/src/libcharon/plugins/vici/python/LICENSE b/src/libcharon/plugins/vici/python/LICENSE
new file mode 100644
index 000000000..111523ca8
--- /dev/null
+++ b/src/libcharon/plugins/vici/python/LICENSE
@@ -0,0 +1,19 @@
+Copyright (c) 2015 Björn Schuberg
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
diff --git a/src/libcharon/plugins/vici/python/MANIFEST.in b/src/libcharon/plugins/vici/python/MANIFEST.in
new file mode 100644
index 000000000..1aba38f67
--- /dev/null
+++ b/src/libcharon/plugins/vici/python/MANIFEST.in
@@ -0,0 +1 @@
+include LICENSE
diff --git a/src/libcharon/plugins/vici/python/Makefile.am b/src/libcharon/plugins/vici/python/Makefile.am
new file mode 100644
index 000000000..f51737870
--- /dev/null
+++ b/src/libcharon/plugins/vici/python/Makefile.am
@@ -0,0 +1,33 @@
+EXTRA_DIST = LICENSE MANIFEST.in \
+ setup.py.in \
+ vici/test/__init__.py \
+ vici/test/test_protocol.py \
+ vici/__init__.py \
+ vici/compat.py \
+ vici/exception.py \
+ vici/protocol.py \
+ vici/session.py
+
+setup.py: $(srcdir)/setup.py.in
+ $(AM_V_GEN) sed \
+ -e "s:@EGG_VERSION@:$(PACKAGE_VERSION):" \
+ $(srcdir)/setup.py.in > $@
+
+all-local: dist/vici-$(PACKAGE_VERSION)-py$(PYTHON_VERSION).egg
+
+dist/vici-$(PACKAGE_VERSION)-py$(PYTHON_VERSION).egg: $(EXTRA_DIST) setup.py
+ (cd $(srcdir); $(PYTHON) setup.py bdist_egg \
+ -b $(shell readlink -f $(builddir))/build \
+ -d $(shell readlink -f $(builddir))/dist)
+
+clean-local: setup.py
+ $(PYTHON) setup.py clean -a
+ rm -rf vici.egg-info dist setup.py
+
+install-exec-local: dist/vici-$(PACKAGE_VERSION)-py$(PYTHON_VERSION).egg
+ $(EASY_INSTALL) $(PYTHONEGGINSTALLDIR) \
+ dist/vici-$(PACKAGE_VERSION)-py$(PYTHON_VERSION).egg
+
+if USE_PY_TEST
+ TESTS = $(PY_TEST)
+endif
diff --git a/src/libcharon/plugins/vici/python/setup.py.in b/src/libcharon/plugins/vici/python/setup.py.in
new file mode 100644
index 000000000..0e4ad8236
--- /dev/null
+++ b/src/libcharon/plugins/vici/python/setup.py.in
@@ -0,0 +1,34 @@
+from setuptools import setup
+
+
+long_description = (
+ "The strongSwan VICI protocol allows external application to monitor, "
+ "configure and control the IKE daemon charon. This python package provides "
+ "a native client side implementation of the VICI protocol, well suited to "
+ "script automated tasks in a reliable way."
+)
+
+setup(
+ name="vici",
+ version="@EGG_VERSION@",
+ description="Native python interface for strongSwan VICI",
+ author="Bjorn Schuberg",
+ url="https://wiki.strongswan.org/projects/strongswan/wiki/Vici",
+ license="MIT",
+ packages=["vici"],
+ long_description=long_description,
+ include_package_data=True,
+ classifiers=(
+ "Development Status :: 3 - Alpha",
+ "Intended Audience :: Developers",
+ "Intended Audience :: System Administrators",
+ "License :: OSI Approved :: MIT License",
+ "Natural Language :: English",
+ "Programming Language :: Python :: 2.7",
+ "Programming Language :: Python :: 3.2",
+ "Programming Language :: Python :: 3.3",
+ "Programming Language :: Python :: 3.4",
+ "Topic :: Security",
+ "Topic :: Software Development :: Libraries",
+ )
+)
diff --git a/src/libcharon/plugins/vici/python/vici/__init__.py b/src/libcharon/plugins/vici/python/vici/__init__.py
new file mode 100644
index 000000000..d314325b6
--- /dev/null
+++ b/src/libcharon/plugins/vici/python/vici/__init__.py
@@ -0,0 +1 @@
+from .session import Session
diff --git a/src/libcharon/plugins/vici/python/vici/compat.py b/src/libcharon/plugins/vici/python/vici/compat.py
new file mode 100644
index 000000000..b5f46992e
--- /dev/null
+++ b/src/libcharon/plugins/vici/python/vici/compat.py
@@ -0,0 +1,14 @@
+# Help functions for compatibility between python version 2 and 3
+
+
+# From http://legacy.python.org/dev/peps/pep-0469
+try:
+ dict.iteritems
+except AttributeError:
+ # python 3
+ def iteritems(d):
+ return iter(d.items())
+else:
+ # python 2
+ def iteritems(d):
+ return d.iteritems()
diff --git a/src/libcharon/plugins/vici/python/vici/exception.py b/src/libcharon/plugins/vici/python/vici/exception.py
new file mode 100644
index 000000000..36384e556
--- /dev/null
+++ b/src/libcharon/plugins/vici/python/vici/exception.py
@@ -0,0 +1,10 @@
+"""Exception types that may be thrown by this library."""
+
+class DeserializationException(Exception):
+ """Encountered an unexpected byte sequence or missing element type."""
+
+class SessionException(Exception):
+ """Session request exception."""
+
+class CommandException(Exception):
+ """Command result exception."""
diff --git a/src/libcharon/plugins/vici/python/vici/protocol.py b/src/libcharon/plugins/vici/python/vici/protocol.py
new file mode 100644
index 000000000..855a7b2e2
--- /dev/null
+++ b/src/libcharon/plugins/vici/python/vici/protocol.py
@@ -0,0 +1,196 @@
+import io
+import socket
+import struct
+
+from collections import namedtuple
+from collections import OrderedDict
+
+from .compat import iteritems
+from .exception import DeserializationException
+
+
+class Transport(object):
+ HEADER_LENGTH = 4
+ MAX_SEGMENT = 512 * 1024
+
+ def __init__(self, sock):
+ self.socket = sock
+
+ def send(self, packet):
+ self.socket.sendall(struct.pack("!I", len(packet)) + packet)
+
+ def receive(self):
+ raw_length = self.socket.recv(self.HEADER_LENGTH)
+ length, = struct.unpack("!I", raw_length)
+ payload = self.socket.recv(length)
+ return payload
+
+ def close(self):
+ self.socket.shutdown(socket.SHUT_RDWR)
+ self.socket.close()
+
+
+class Packet(object):
+ CMD_REQUEST = 0 # Named request message
+ CMD_RESPONSE = 1 # Unnamed response message for a request
+ CMD_UNKNOWN = 2 # Unnamed response if requested command is unknown
+ EVENT_REGISTER = 3 # Named event registration request
+ EVENT_UNREGISTER = 4 # Named event de-registration request
+ EVENT_CONFIRM = 5 # Unnamed confirmation for event (de-)registration
+ EVENT_UNKNOWN = 6 # Unnamed response if event (de-)registration failed
+ EVENT = 7 # Named event message
+
+ ParsedPacket = namedtuple(
+ "ParsedPacket",
+ ["response_type", "payload"]
+ )
+
+ ParsedEventPacket = namedtuple(
+ "ParsedEventPacket",
+ ["response_type", "event_type", "payload"]
+ )
+
+ @classmethod
+ def _named_request(cls, request_type, request, message=None):
+ request = request.encode()
+ payload = struct.pack("!BB", request_type, len(request)) + request
+ if message is not None:
+ return payload + message
+ else:
+ return payload
+
+ @classmethod
+ def request(cls, command, message=None):
+ return cls._named_request(cls.CMD_REQUEST, command, message)
+
+ @classmethod
+ def register_event(cls, event_type):
+ return cls._named_request(cls.EVENT_REGISTER, event_type)
+
+ @classmethod
+ def unregister_event(cls, event_type):
+ return cls._named_request(cls.EVENT_UNREGISTER, event_type)
+
+ @classmethod
+ def parse(cls, packet):
+ stream = FiniteStream(packet)
+ response_type, = struct.unpack("!B", stream.read(1))
+
+ if response_type == cls.EVENT:
+ length, = struct.unpack("!B", stream.read(1))
+ event_type = stream.read(length)
+ return cls.ParsedEventPacket(response_type, event_type, stream)
+ else:
+ return cls.ParsedPacket(response_type, stream)
+
+
+class Message(object):
+ SECTION_START = 1 # Begin a new section having a name
+ SECTION_END = 2 # End a previously started section
+ KEY_VALUE = 3 # Define a value for a named key in the section
+ LIST_START = 4 # Begin a named list for list items
+ LIST_ITEM = 5 # Define an unnamed item value in the current list
+ LIST_END = 6 # End a previously started list
+
+ @classmethod
+ def serialize(cls, message):
+ def encode_named_type(marker, name):
+ name = name.encode()
+ return struct.pack("!BB", marker, len(name)) + name
+
+ def encode_blob(value):
+ if not isinstance(value, bytes):
+ value = str(value).encode()
+ return struct.pack("!H", len(value)) + value
+
+ def serialize_list(lst):
+ segment = bytes()
+ for item in lst:
+ segment += struct.pack("!B", cls.LIST_ITEM) + encode_blob(item)
+ return segment
+
+ def serialize_dict(d):
+ segment = bytes()
+ for key, value in iteritems(d):
+ if isinstance(value, dict):
+ segment += (
+ encode_named_type(cls.SECTION_START, key)
+ + serialize_dict(value)
+ + struct.pack("!B", cls.SECTION_END)
+ )
+ elif isinstance(value, list):
+ segment += (
+ encode_named_type(cls.LIST_START, key)
+ + serialize_list(value)
+ + struct.pack("!B", cls.LIST_END)
+ )
+ else:
+ segment += (
+ encode_named_type(cls.KEY_VALUE, key)
+ + encode_blob(value)
+ )
+ return segment
+
+ return serialize_dict(message)
+
+ @classmethod
+ def deserialize(cls, stream):
+ def decode_named_type(stream):
+ length, = struct.unpack("!B", stream.read(1))
+ return stream.read(length).decode()
+
+ def decode_blob(stream):
+ length, = struct.unpack("!H", stream.read(2))
+ return stream.read(length)
+
+ def decode_list_item(stream):
+ marker, = struct.unpack("!B", stream.read(1))
+ while marker == cls.LIST_ITEM:
+ yield decode_blob(stream)
+ marker, = struct.unpack("!B", stream.read(1))
+
+ if marker != cls.LIST_END:
+ raise DeserializationException(
+ "Expected end of list at {pos}".format(pos=stream.tell())
+ )
+
+ section = OrderedDict()
+ section_stack = []
+ while stream.has_more():
+ element_type, = struct.unpack("!B", stream.read(1))
+ if element_type == cls.SECTION_START:
+ section_name = decode_named_type(stream)
+ new_section = OrderedDict()
+ section[section_name] = new_section
+ section_stack.append(section)
+ section = new_section
+
+ elif element_type == cls.LIST_START:
+ list_name = decode_named_type(stream)
+ section[list_name] = [item for item in decode_list_item(stream)]
+
+ elif element_type == cls.KEY_VALUE:
+ key = decode_named_type(stream)
+ section[key] = decode_blob(stream)
+
+ elif element_type == cls.SECTION_END:
+ if len(section_stack):
+ section = section_stack.pop()
+ else:
+ raise DeserializationException(
+ "Unexpected end of section at {pos}".format(
+ pos=stream.tell()
+ )
+ )
+
+ if len(section_stack):
+ raise DeserializationException("Expected end of section")
+ return section
+
+
+class FiniteStream(io.BytesIO):
+ def __len__(self):
+ return len(self.getvalue())
+
+ def has_more(self):
+ return self.tell() < len(self)
diff --git a/src/libcharon/plugins/vici/python/vici/session.py b/src/libcharon/plugins/vici/python/vici/session.py
new file mode 100644
index 000000000..dee58699d
--- /dev/null
+++ b/src/libcharon/plugins/vici/python/vici/session.py
@@ -0,0 +1,327 @@
+import collections
+import socket
+
+from .exception import SessionException, CommandException
+from .protocol import Transport, Packet, Message
+
+
+class Session(object):
+ def __init__(self, sock=None):
+ if sock is None:
+ sock = socket.socket(socket.AF_UNIX)
+ sock.connect("/var/run/charon.vici")
+ self.handler = SessionHandler(Transport(sock))
+
+ def version(self):
+ """Retrieve daemon and system specific version information.
+
+ :return: daemon and system specific version information
+ :rtype: dict
+ """
+ return self.handler.request("version")
+
+ def stats(self):
+ """Retrieve IKE daemon statistics and load information.
+
+ :return: IKE daemon statistics and load information
+ :rtype: dict
+ """
+ return self.handler.request("stats")
+
+ def reload_settings(self):
+ """Reload strongswan.conf settings and any plugins supporting reload.
+ """
+ self.handler.request("reload-settings")
+
+ def initiate(self, sa):
+ """Initiate an SA.
+
+ :param sa: the SA to initiate
+ :type sa: dict
+ :return: generator for logs emitted as dict
+ :rtype: generator
+ """
+ return self.handler.streamed_request("initiate", "control-log", sa)
+
+ def terminate(self, sa):
+ """Terminate an SA.
+
+ :param sa: the SA to terminate
+ :type sa: dict
+ :return: generator for logs emitted as dict
+ :rtype: generator
+ """
+ return self.handler.streamed_request("terminate", "control-log", sa)
+
+ def install(self, policy):
+ """Install a trap, drop or bypass policy defined by a CHILD_SA config.
+
+ :param policy: policy to install
+ :type policy: dict
+ """
+ self.handler.request("install", policy)
+
+ def uninstall(self, policy):
+ """Uninstall a trap, drop or bypass policy defined by a CHILD_SA config.
+
+ :param policy: policy to uninstall
+ :type policy: dict
+ """
+ self.handler.request("uninstall", policy)
+
+ def list_sas(self, filters=None):
+ """Retrieve active IKE_SAs and associated CHILD_SAs.
+
+ :param filters: retrieve only matching IKE_SAs (optional)
+ :type filters: dict
+ :return: generator for active IKE_SAs and associated CHILD_SAs as dict
+ :rtype: generator
+ """
+ return self.handler.streamed_request("list-sas", "list-sa", filters)
+
+ def list_policies(self, filters=None):
+ """Retrieve installed trap, drop and bypass policies.
+
+ :param filters: retrieve only matching policies (optional)
+ :type filters: dict
+ :return: generator for installed trap, drop and bypass policies as dict
+ :rtype: generator
+ """
+ return self.handler.streamed_request("list-policies", "list-policy",
+ filters)
+
+ def list_conns(self, filters=None):
+ """Retrieve loaded connections.
+
+ :param filters: retrieve only matching configuration names (optional)
+ :type filters: dict
+ :return: generator for loaded connections as dict
+ :rtype: generator
+ """
+ return self.handler.streamed_request("list-conns", "list-conn",
+ filters)
+
+ def get_conns(self):
+ """Retrieve connection names loaded exclusively over vici.
+
+ :return: connection names
+ :rtype: dict
+ """
+ return self.handler.request("get-conns")
+
+ def list_certs(self, filters=None):
+ """Retrieve loaded certificates.
+
+ :param filters: retrieve only matching certificates (optional)
+ :type filters: dict
+ :return: generator for loaded certificates as dict
+ :rtype: generator
+ """
+ return self.handler.streamed_request("list-certs", "list-cert", filters)
+
+ def load_conn(self, connection):
+ """Load a connection definition into the daemon.
+
+ :param connection: connection definition
+ :type connection: dict
+ """
+ self.handler.request("load-conn", connection)
+
+ def unload_conn(self, name):
+ """Unload a connection definition.
+
+ :param name: connection definition name
+ :type name: dict
+ """
+ self.handler.request("unload-conn", name)
+
+ def load_cert(self, certificate):
+ """Load a certificate into the daemon.
+
+ :param certificate: PEM or DER encoded certificate
+ :type certificate: dict
+ """
+ self.handler.request("load-cert", certificate)
+
+ def load_key(self, private_key):
+ """Load a private key into the daemon.
+
+ :param private_key: PEM or DER encoded key
+ """
+ self.handler.request("load-key", private_key)
+
+ def load_shared(self, secret):
+ """Load a shared IKE PSK, EAP or XAuth secret into the daemon.
+
+ :param secret: shared IKE PSK, EAP or XAuth secret
+ :type secret: dict
+ """
+ self.handler.request("load-shared", secret)
+
+ def clear_creds(self):
+ """Clear credentials loaded over vici.
+
+ Clear all loaded certificate, private key and shared key credentials.
+ This affects only credentials loaded over vici, but additionally
+ flushes the credential cache.
+ """
+ self.handler.request("clear-creds")
+
+ def load_pool(self, pool):
+ """Load a virtual IP pool.
+
+ Load an in-memory virtual IP and configuration attribute pool.
+ Existing pools with the same name get updated, if possible.
+
+ :param pool: virtual IP and configuration attribute pool
+ :type pool: dict
+ """
+ return self.handler.request("load-pool", pool)
+
+ def unload_pool(self, pool_name):
+ """Unload a virtual IP pool.
+
+ Unload a previously loaded virtual IP and configuration attribute pool.
+ Unloading fails for pools with leases currently online.
+
+ :param pool_name: pool by name
+ :type pool_name: dict
+ """
+ self.handler.request("unload-pool", pool_name)
+
+ def get_pools(self):
+ """Retrieve loaded pools.
+
+ :return: loaded pools
+ :rtype: dict
+ """
+ return self.handler.request("get-pools")
+
+
+class SessionHandler(object):
+ """Handles client command execution requests over vici."""
+
+ def __init__(self, transport):
+ self.transport = transport
+
+ def _communicate(self, packet):
+ """Send packet over transport and parse response.
+
+ :param packet: packet to send
+ :type packet: :py:class:`vici.protocol.Packet`
+ :return: parsed packet in a tuple with message type and payload
+ :rtype: :py:class:`collections.namedtuple`
+ """
+ self.transport.send(packet)
+ return Packet.parse(self.transport.receive())
+
+ def request(self, command, message=None):
+ """Send request with an optional message.
+
+ :param command: command to send
+ :type command: str
+ :param message: message (optional)
+ :type message: str
+ :return: command result
+ :rtype: dict
+ """
+ if message is not None:
+ message = Message.serialize(message)
+ packet = Packet.request(command, message)
+ response = self._communicate(packet)
+
+ if response.response_type != Packet.CMD_RESPONSE:
+ raise SessionException(
+ "Unexpected response type {type}, "
+ "expected '{response}' (CMD_RESPONSE)".format(
+ type=response.response_type,
+ response=Packet.CMD_RESPONSE
+ )
+ )
+
+ command_response = Message.deserialize(response.payload)
+ if "success" in command_response:
+ if command_response["success"] != b"yes":
+ raise CommandException(
+ "Command failed: {errmsg}".format(
+ errmsg=command_response["errmsg"]
+ )
+ )
+
+ return command_response
+
+ def streamed_request(self, command, event_stream_type, message=None):
+ """Send command request and collect and return all emitted events.
+
+ :param command: command to send
+ :type command: str
+ :param event_stream_type: event type emitted on command execution
+ :type event_stream_type: str
+ :param message: message (optional)
+ :type message: str
+ :return: generator for streamed event responses as dict
+ :rtype: generator
+ """
+ if message is not None:
+ message = Message.serialize(message)
+
+ # subscribe to event stream
+ packet = Packet.register_event(event_stream_type)
+ response = self._communicate(packet)
+
+ if response.response_type != Packet.EVENT_CONFIRM:
+ raise SessionException(
+ "Unexpected response type {type}, "
+ "expected '{confirm}' (EVENT_CONFIRM)".format(
+ type=response.response_type,
+ confirm=Packet.EVENT_CONFIRM,
+ )
+ )
+
+ # issue command, and read any event messages
+ packet = Packet.request(command, message)
+ self.transport.send(packet)
+ exited = False
+ while True:
+ response = Packet.parse(self.transport.receive())
+ if response.response_type == Packet.EVENT:
+ if not exited:
+ try:
+ yield Message.deserialize(response.payload)
+ except GeneratorExit:
+ exited = True
+ pass
+ else:
+ break
+
+ if response.response_type == Packet.CMD_RESPONSE:
+ command_response = Message.deserialize(response.payload)
+ else:
+ raise SessionException(
+ "Unexpected response type {type}, "
+ "expected '{response}' (CMD_RESPONSE)".format(
+ type=response.response_type,
+ response=Packet.CMD_RESPONSE
+ )
+ )
+
+ # unsubscribe from event stream
+ packet = Packet.unregister_event(event_stream_type)
+ response = self._communicate(packet)
+ if response.response_type != Packet.EVENT_CONFIRM:
+ raise SessionException(
+ "Unexpected response type {type}, "
+ "expected '{confirm}' (EVENT_CONFIRM)".format(
+ type=response.response_type,
+ confirm=Packet.EVENT_CONFIRM,
+ )
+ )
+
+ # evaluate command result, if any
+ if "success" in command_response:
+ if command_response["success"] != b"yes":
+ raise CommandException(
+ "Command failed: {errmsg}".format(
+ errmsg=command_response["errmsg"]
+ )
+ )
diff --git a/src/libcharon/plugins/vici/python/vici/test/__init__.py b/src/libcharon/plugins/vici/python/vici/test/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/src/libcharon/plugins/vici/python/vici/test/__init__.py
diff --git a/src/libcharon/plugins/vici/python/vici/test/test_protocol.py b/src/libcharon/plugins/vici/python/vici/test/test_protocol.py
new file mode 100644
index 000000000..a1f202d79
--- /dev/null
+++ b/src/libcharon/plugins/vici/python/vici/test/test_protocol.py
@@ -0,0 +1,144 @@
+import pytest
+
+from ..protocol import Packet, Message, FiniteStream
+from ..exception import DeserializationException
+
+
+class TestPacket(object):
+ # test data definitions for outgoing packet types
+ cmd_request = b"\x00\x0c" b"command_type"
+ cmd_request_msg = b"\x00\x07" b"command" b"payload"
+ event_register = b"\x03\x0a" b"event_type"
+ event_unregister = b"\x04\x0a" b"event_type"
+
+ # test data definitions for incoming packet types
+ cmd_response = b"\x01" b"reply"
+ cmd_unknown = b"\x02"
+ event_confirm = b"\x05"
+ event_unknown = b"\x06"
+ event = b"\x07\x03" b"log" b"message"
+
+ def test_request(self):
+ assert Packet.request("command_type") == self.cmd_request
+ assert Packet.request("command", b"payload") == self.cmd_request_msg
+
+ def test_register_event(self):
+ assert Packet.register_event("event_type") == self.event_register
+
+ def test_unregister_event(self):
+ assert Packet.unregister_event("event_type") == self.event_unregister
+
+ def test_parse(self):
+ parsed_cmd_response = Packet.parse(self.cmd_response)
+ assert parsed_cmd_response.response_type == Packet.CMD_RESPONSE
+ assert parsed_cmd_response.payload.getvalue() == self.cmd_response
+
+ parsed_cmd_unknown = Packet.parse(self.cmd_unknown)
+ assert parsed_cmd_unknown.response_type == Packet.CMD_UNKNOWN
+ assert parsed_cmd_unknown.payload.getvalue() == self.cmd_unknown
+
+ parsed_event_confirm = Packet.parse(self.event_confirm)
+ assert parsed_event_confirm.response_type == Packet.EVENT_CONFIRM
+ assert parsed_event_confirm.payload.getvalue() == self.event_confirm
+
+ parsed_event_unknown = Packet.parse(self.event_unknown)
+ assert parsed_event_unknown.response_type == Packet.EVENT_UNKNOWN
+ assert parsed_event_unknown.payload.getvalue() == self.event_unknown
+
+ parsed_event = Packet.parse(self.event)
+ assert parsed_event.response_type == Packet.EVENT
+ assert parsed_event.payload.getvalue() == self.event
+
+
+class TestMessage(object):
+ """Message (de)serialization test."""
+
+ # data definitions for test of de(serialization)
+ # serialized messages holding a section
+ ser_sec_unclosed = b"\x01\x08unclosed"
+ ser_sec_single = b"\x01\x07section\x02"
+ ser_sec_nested = b"\x01\x05outer\x01\x0asubsection\x02\x02"
+
+ # serialized messages holding a list
+ ser_list_invalid = b"\x04\x07invalid\x05\x00\x02e1\x02\x03sec\x06"
+ ser_list_0_item = b"\x04\x05empty\x06"
+ ser_list_1_item = b"\x04\x01l\x05\x00\x02e1\x06"
+ ser_list_2_item = b"\x04\x01l\x05\x00\x02e1\x05\x00\x02e2\x06"
+
+ # serialized messages with key value pairs
+ ser_kv_pair = b"\x03\x03key\x00\x05value"
+ ser_kv_zero = b"\x03\x0azerolength\x00\x00"
+
+ # deserialized messages holding a section
+ des_sec_single = { "section": {} }
+ des_sec_nested = { "outer": { "subsection": {} } }
+
+ # deserialized messages holding a list
+ des_list_0_item = { "empty": [] }
+ des_list_1_item = { "l": [ b"e1" ] }
+ des_list_2_item = { "l": [ b"e1", b"e2" ] }
+
+ # deserialized messages with key value pairs
+ des_kv_pair = { "key": b"value" }
+ des_kv_zero = { "zerolength": b"" }
+
+ def test_section_serialization(self):
+ assert Message.serialize(self.des_sec_single) == self.ser_sec_single
+ assert Message.serialize(self.des_sec_nested) == self.ser_sec_nested
+
+ def test_list_serialization(self):
+ assert Message.serialize(self.des_list_0_item) == self.ser_list_0_item
+ assert Message.serialize(self.des_list_1_item) == self.ser_list_1_item
+ assert Message.serialize(self.des_list_2_item) == self.ser_list_2_item
+
+ def test_key_serialization(self):
+ assert Message.serialize(self.des_kv_pair) == self.ser_kv_pair
+ assert Message.serialize(self.des_kv_zero) == self.ser_kv_zero
+
+ def test_section_deserialization(self):
+ single = Message.deserialize(FiniteStream(self.ser_sec_single))
+ nested = Message.deserialize(FiniteStream(self.ser_sec_nested))
+
+ assert single == self.des_sec_single
+ assert nested == self.des_sec_nested
+
+ with pytest.raises(DeserializationException):
+ Message.deserialize(FiniteStream(self.ser_sec_unclosed))
+
+ def test_list_deserialization(self):
+ l0 = Message.deserialize(FiniteStream(self.ser_list_0_item))
+ l1 = Message.deserialize(FiniteStream(self.ser_list_1_item))
+ l2 = Message.deserialize(FiniteStream(self.ser_list_2_item))
+
+ assert l0 == self.des_list_0_item
+ assert l1 == self.des_list_1_item
+ assert l2 == self.des_list_2_item
+
+ with pytest.raises(DeserializationException):
+ Message.deserialize(FiniteStream(self.ser_list_invalid))
+
+ def test_key_deserialization(self):
+ pair = Message.deserialize(FiniteStream(self.ser_kv_pair))
+ zerolength = Message.deserialize(FiniteStream(self.ser_kv_zero))
+
+ assert pair == self.des_kv_pair
+ assert zerolength == self.des_kv_zero
+
+ def test_roundtrip(self):
+ message = {
+ "key1": "value1",
+ "section1": {
+ "sub-section": {
+ "key2": b"value2",
+ },
+ "list1": [ "item1", "item2" ],
+ },
+ }
+ serialized_message = FiniteStream(Message.serialize(message))
+ deserialized_message = Message.deserialize(serialized_message)
+
+ # ensure that list items and key values remain as undecoded bytes
+ deserialized_section = deserialized_message["section1"]
+ assert deserialized_message["key1"] == b"value1"
+ assert deserialized_section["sub-section"]["key2"] == b"value2"
+ assert deserialized_section["list1"] == [ b"item1", b"item2" ]