diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..3545e18 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,13 @@ +[FORMAT] +max-line-length=180 + +[MASTER] +disable= + C0114, #missing-module-docstring + C0115, #missing-class-docstring + C0116, #missing-function-docstring + W0223, #abstract-method + E1101, #no-member + R0903, #too-few-public-methods + R0902, #too-many-instance-attributes + W0707, #too-many-statements \ No newline at end of file diff --git a/LICENSE b/LICENSE index c199998..0c0d8e6 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2023 Nasdaq +Copyright (c) 2024 Nasdaq Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..d4bb2cb --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/api_reference.rst b/docs/api_reference.rst new file mode 100644 index 0000000..db656ea --- /dev/null +++ b/docs/api_reference.rst @@ -0,0 +1,9 @@ +API Reference +============= + +.. toctree:: + :maxdepth: 1 + :caption: Contents: + + common_api_reference + soup_api_reference \ No newline at end of file diff --git a/docs/common_api_reference.rst b/docs/common_api_reference.rst new file mode 100644 index 0000000..7e3a0a6 --- /dev/null +++ b/docs/common_api_reference.rst @@ -0,0 +1,10 @@ +Common api reference +==================== + +.. automodule:: nasdaq_protocols.common.message_queue + +.. automodule:: nasdaq_protocols.common.session + +.. automodule:: nasdaq_protocols.common.types + +.. automodule:: nasdaq_protocols.common.utils \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..a9342c1 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,38 @@ +import os +import sys +sys.path.insert(0, os.path.abspath('../src')) + +# Configuration file for the Sphinx documentation builder. +# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = 'Nasdaq Protocols Python Library' +copyright = 'Copyright (c) 2024 Nasdaq' +author = 'Sam Daniel Thangarajan' +release = '0.0.1' + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration +# https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html +extensions = ['sphinx.ext.autodoc',] +templates_path = ['_templates'] +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + + +autodoc_member_order = 'bysource' +autodoc_default_options = { + 'members': True, + 'member-order': 'bysource', + 'show-inheritance': True, +} +autodoc_typehints = 'description' + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output +# https://www.sphinx-doc.org/en/master/usage/theming.html [themes] +html_theme = 'agogo' +html_static_path = ['_static'] diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..bf4d367 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,25 @@ +.. nasdaq-protocols documentation master file, created by + sphinx-quickstart on Thu Jan 4 23:40:15 2024. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to nasdaq-protocols's documentation! +============================================ + +.. automodule:: nasdaq_protocols + +.. toctree:: + :maxdepth: 1 + :caption: Contents: + + install + api_reference + user_guide + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/docs/install.rst b/docs/install.rst new file mode 100644 index 0000000..9ef5f4a --- /dev/null +++ b/docs/install.rst @@ -0,0 +1,8 @@ +Installation +============ + +To use nasdaq-protocols, install it using pip: + +.. code-block:: console + + (.venv) $ pip install nasdaq-protocols \ No newline at end of file diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..32bb245 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/soup_api_reference.rst b/docs/soup_api_reference.rst new file mode 100644 index 0000000..828f9f3 --- /dev/null +++ b/docs/soup_api_reference.rst @@ -0,0 +1,24 @@ +SOUP api reference +================== + +Soup +---- +.. automodule:: nasdaq_protocols.soup + + +Soup Session +^^^^^^^^^^^^ +.. autoclass:: nasdaq_protocols.soup.session.SoupSessionId + :show-inheritance: True + +.. autoclass:: nasdaq_protocols.soup.session.SoupClientSession + :show-inheritance: True + :undoc-members: ['send_heartbeat'] + +.. autoclass:: nasdaq_protocols.soup.session.SoupServerSession + :show-inheritance: True + + +Soup Messages +^^^^^^^^^^^^^ +.. automodule:: nasdaq_protocols.soup.core \ No newline at end of file diff --git a/docs/user_guide.rst b/docs/user_guide.rst new file mode 100644 index 0000000..554cd16 --- /dev/null +++ b/docs/user_guide.rst @@ -0,0 +1,35 @@ +User Guide +========== + +SOUP +____ + +Connect to Soup session +----------------------- +Example of connecting to a Soup session and receiving messages. + +.. code-block:: python + + import asyncio + from nasdaq_protocols import Soup + + stopped = asyncio.Event() + + async def on_msg(msg): + print(msg) + + async def on_close(): + stopped.set() + print('closed') + + async def main(): + session = await soup.connect_async( + ('host', port), 'user', 'password', + sequence=1, on_msg_coro=on_msg, on_close_coro=on_close + ) + await stopped.wait() + + if __name__ == '__main__': + asyncio.run(main()) + +*A simple soup tail program* \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 9ee5791..10dfde4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ -attrs==23.1.0 -pytest==7.4.3 -pytest-cov==4.1.0 -pytest-asyncio==0.23.0 \ No newline at end of file +attrs>=23.1 +pytest>=7.4 +pytest-cov>=4.1 +pytest-asyncio>=0.23 +sphinx \ No newline at end of file diff --git a/src/nasdaq_protocols/__init__.py b/src/nasdaq_protocols/__init__.py index e69de29..0e05291 100644 --- a/src/nasdaq_protocols/__init__.py +++ b/src/nasdaq_protocols/__init__.py @@ -0,0 +1,4 @@ +""" +nasdaq-protocols contains client side implementations of the various +publicly available protocols used by Nasdaq ecosystem. +""" diff --git a/src/nasdaq_protocols/common/message_queue.py b/src/nasdaq_protocols/common/message_queue.py index dab0376..49f3cb5 100644 --- a/src/nasdaq_protocols/common/message_queue.py +++ b/src/nasdaq_protocols/common/message_queue.py @@ -33,7 +33,11 @@ def __attrs_post_init__(self): self._msg_queue = asyncio.Queue() self.start_dispatching(self.on_msg_coro) - async def put(self, msg: Any): + async def put(self, msg: Any) -> None: + """ + put an entry into the queue. + :param msg: Any + """ await self._msg_queue.put(msg) async def get(self): @@ -51,9 +55,17 @@ async def get(self): return msg if msg else await self._blocking_read() def put_nowait(self, msg: Any): + """ + put an entry into the queue. + :param msg: Any + """ self._msg_queue.put_nowait(msg) - def get_nowait(self): + def get_nowait(self) -> Any | None: + """ + get an entry from the queue. This is a non-blocking call. + :return: entry from the queue or None if queue is empty. + """ if self._dispatcher_task: raise StateError(f'{self.session_id}-dispatcher, Dispatcher is running, cannot use get_no_wait') msg = None @@ -66,6 +78,17 @@ def get_nowait(self): @asynccontextmanager async def pause_dispatching(self): + """ + This is a context manager that pauses the dispatcher:: + + queue = DispatchableMessageQueue(session_id, on_msg_coro) + + queue.get() # will raise an exception + + async with queue.pause_dispatching(): + queue.get() # will not raise an exception + + """ if not self._dispatcher_task: raise StateError('Dispatcher is not running, cannot pause') self._dispatcher_task = await stop_task(self._dispatcher_task) @@ -76,7 +99,12 @@ async def pause_dispatching(self): self._dispatcher_task = asyncio.create_task(self._start_dispatching(), name=f'{self.session_id}-dispatcher') self.log.debug('%s> queue dispatcher resumed.', self.session_id) - def start_dispatching(self, on_msg_coro: DispatcherCoro): + def start_dispatching(self, on_msg_coro: DispatcherCoro) -> None: + """ + Start dispatching messages from the queue to the coro. + + :param on_msg_coro: + """ if self._dispatcher_task: raise StateError('Dispatcher is already running, cannot start') if on_msg_coro: @@ -84,13 +112,19 @@ def start_dispatching(self, on_msg_coro: DispatcherCoro): self._dispatcher_task = asyncio.create_task(self._start_dispatching(), name=f'{self.session_id}-dispatcher') self.log.debug('%s> queue dispatcher started.', self.session_id) - async def stop(self): + async def stop(self) -> None: + """ + Stop the queue. + """ if not self._closed: self._closed = True self._dispatcher_task = await stop_task(self._dispatcher_task) self._recv_task = await stop_task(self._recv_task) - def is_stopped(self): + def is_stopped(self) -> bool: + """ + :return: True if the queue is stopped. + """ return self._closed async def _start_dispatching(self): diff --git a/src/nasdaq_protocols/common/session.py b/src/nasdaq_protocols/common/session.py index 866bae8..3a647b0 100644 --- a/src/nasdaq_protocols/common/session.py +++ b/src/nasdaq_protocols/common/session.py @@ -31,8 +31,15 @@ @attrs.define(auto_attribs=True) class HearbeatMonitor(Stoppable): """ - Monitors for no activity for configured interval. - Activity is reported to this monitor by calling ping() + Monitor that trips the `on_no_activity_coro` if no activity is detected. + + Currently, activity is externally signalled by calling the `ping` method. + + :param session_id: The session id. + :param interval: interval in seconds at which the monitor checks for activity. + :param on_no_activity_coro: coroutine to be called when no activity is detected. + :param stop_when_no_activity: If True, the monitor stops when no activity is detected. + :param tolerate_missed_heartbeats: number of missed heartbeats to tolerate. """ session_id: Any = attrs.field(validator=Validators.not_none()) interval: float = attrs.field(validator=Validators.not_none()) @@ -46,17 +53,20 @@ def __attrs_post_init__(self): self._monitor_task = asyncio.create_task(self._start_monitor(), name=f'{self.session_id}-monitor') self.log.debug('%s> monitor started.', self.session_id) - def ping(self): + def ping(self) -> None: + """Ping the monitor.""" self._pinged = True - def is_running(self): + def is_running(self) -> bool: + """Returns True if the monitor is running.""" return self._monitor_task is not None and not self._monitor_task.done() - async def stop(self): + async def stop(self) -> None: + """Stop the monitor.""" self._monitor_task = await stop_task(self._monitor_task) - def is_stopped(self): - return self._monitor_task is None + def is_stopped(self) -> bool: + return not self.is_running() async def _start_monitor(self): missed_heartbeats = count(1) @@ -86,7 +96,9 @@ class Reader(Stoppable): A reader is responsible for parsing the received data from the transport and dispatching it to the on_msg_coro. - If the reader detects an end of session, then it signals the on_close_coro. + :param session_id: The session id. + :param on_msg_coro: coroutine to be called for every message parsed. + :param on_close_coro: coroutine to be called when the reader detects end of session. """ session_id: Any = attrs.field(validator=Validators.not_none()) on_msg_coro: OnMsgCoro = attrs.field(validator=Validators.not_none()) @@ -94,23 +106,40 @@ class Reader(Stoppable): @abc.abstractmethod async def on_data(self, data: bytes): - pass + """Called when data is received from the transport.""" @attrs.define(auto_attribs=True) class SessionId: + """ + A basic session id. + """ host: str = 'nohost' port: int = 0 - def set_transport(self, transport: asyncio.Transport): + def set_transport(self, transport: asyncio.Transport) -> None: + """Once the transport is available, the host and port are updated.""" self.host, self.port = transport.get_extra_info('peername') @logable @attrs.define(auto_attribs=True) class AsyncSession(asyncio.Protocol, abc.ABC): - """Abstract base class for async sessions.""" + """ + Abstract base class for async sessions. + + Once the transport is available, the session creates a new reader using the + `reader_factory` and starts parsing the incoming bytes. + By default, the session starts in a dispatching mode, meaning the incoming messages + are dispatched to the `on_msg_coro`. This can be changed by setting `dispatch_on_connect=False`. + + :param session_id: The session id. + :param reader_factory: A callable that returns a reader. + :param on_msg_coro: coroutine to be called when a message is received. + :param on_close_coro: coroutine to be called when the session is closed. + :param dispatch_on_connect: If True, the session starts with dispatching once connected. + """ session_id: SessionId = attrs.field(kw_only=True, validator=Validators.not_none()) reader_factory: ReaderFactory = attrs.field(kw_only=True, validator=Validators.not_none()) on_msg_coro: OnMsgCoro = attrs.field(kw_only=True, default=None) @@ -130,12 +159,23 @@ def __attrs_post_init__(self): # By default do not dispatch messages self._msg_queue = DispatchableMessageQueue(self.session_id) - async def receive_msg(self): - """Receive a message from the peer. This is a blocking call.""" + async def receive_msg(self) -> Any: + """ + Receive a message from the peer. This is a blocking call. + This call blocks until a new message is available. + + If the session is dispatching messages, then this call raises an exception. + + :return Any: The message received. + """ return await self._msg_queue.get() - def receive_msg_nowait(self): - """Receive a message from the peer. This is a non-blocking call.""" + def receive_msg_nowait(self) -> Any | None: + """ + Receive a message from the peer. This is a non-blocking call. + + :return Any: The message received. + """ return self._msg_queue.get_nowait() def is_active(self) -> bool: @@ -143,16 +183,29 @@ def is_active(self) -> bool: return not (self._closed or self._closing_task) def is_closed(self) -> bool: + """ + Returns True if the session is closed. + :return: + """ return self._closed - def initiate_close(self): - """Initiate close of the session.""" + def initiate_close(self) -> None: + """ + Initiate close of the session. + An asynchronous task is created which will close the session and all its + associates. + + Poll `is_closed` to check if the session is closed or use the + `on_close_coro` callback to be notified when the session is closed. + """ if self._closed or self._closing_task: return self._closing_task = asyncio.create_task(self.close(), name=f'asyncsession-close:{self.session_id}') async def close(self): - """Close the session, the session cannot be used after this call.""" + """ + Close the session, the session cannot be used after this call. + """ if not self._closed: self._closed = True await stop_task([ @@ -171,7 +224,7 @@ def start_heartbeats(self, local_hb_interval: int | float, remote_hb_interval: i """Starts the heartbeats for the session. - if the remote failed heartbeats, then the session is closed. - - if the local failed heartbeats, then `send_heartbeat` is called. + - if the local heartbeat timer expires, then `send_heartbeat` is called. """ self._local_hb_monitor = HearbeatMonitor( self.session_id, local_hb_interval, self.send_heartbeat, stop_when_no_activity=False @@ -181,16 +234,20 @@ def start_heartbeats(self, local_hb_interval: int | float, remote_hb_interval: i def start_dispatching(self): """ - By default, the session starts with dispatching switched-off. + By default, the session starts with dispatching switched-on. - Once application level login/handshake is established, this method - can be called to dispatch messages to the on_msg_coro. + If, the session is created with dispatching-off, then at any point in time + during the lifetime of this session, dispatching can be switched-on by calling + this method. """ self._msg_queue.start_dispatching(self.on_msg_coro) self.log.debug('%s> started dispatching', self.session_id) # asyncio.Protocol overloads. def connection_made(self, transport: asyncio.Transport): + """ + :meta private: + """ self.log.debug('%s> connected', self.session_id) self._transport = transport self.session_id.set_transport(self._transport) @@ -199,19 +256,31 @@ def connection_made(self, transport: asyncio.Transport): self.start_dispatching() def data_received(self, data): + """ + :meta private: + """ if self._remote_hb_monitor: self._remote_hb_monitor.ping() self._reader_task = asyncio.create_task(self._reader.on_data(data), name=f'asyncsession-ondata:{self.session_id}') def connection_lost(self, exc): + """ + :meta private: + """ self.log.debug('%s> connection lost', self.session_id) self.initiate_close() @abc.abstractmethod - def send_msg(self, msg: Serializable): - pass + def send_msg(self, msg: Serializable) -> None: + """ + Send a message to the peer. + :param msg: Any message that is serializable. + """ @abc.abstractmethod async def send_heartbeat(self): - pass + """ + Callback to send a heartbeat to the peer. + :meta private: + """ diff --git a/src/nasdaq_protocols/common/types.py b/src/nasdaq_protocols/common/types.py index 3736766..1bf09e0 100644 --- a/src/nasdaq_protocols/common/types.py +++ b/src/nasdaq_protocols/common/types.py @@ -35,8 +35,8 @@ def is_stopped(self): class StateError(RuntimeError): - pass + """Raised when an operation is attempted in an invalid state.""" class EndOfQueue(EOFError): - pass + """Raised when the end of the queue is reached.""" diff --git a/src/nasdaq_protocols/common/utils.py b/src/nasdaq_protocols/common/utils.py index 73fbec2..9afa112 100644 --- a/src/nasdaq_protocols/common/utils.py +++ b/src/nasdaq_protocols/common/utils.py @@ -34,7 +34,9 @@ async def _stop_task(task: asyncio.Task): def logable(target): - """decorator to add a logger to a class""" + """ + decorator that adds a log object to the class. + """ assert inspect.isclass(target) target.log = logging.getLogger(target.__name__) @@ -42,7 +44,9 @@ def logable(target): async def stop_task(tasks: _StopTaskTypes | list[_StopTaskTypes]) -> asyncio.Task | Stoppable | None: - """Cancel a task and wait for it to finish""" + """ + Cancel a task and wait for it to finish + """ if not isinstance(tasks, list): tasks = [tasks] diff --git a/src/nasdaq_protocols/soup/__init__.py b/src/nasdaq_protocols/soup/__init__.py index e5a4d83..31e5280 100644 --- a/src/nasdaq_protocols/soup/__init__.py +++ b/src/nasdaq_protocols/soup/__init__.py @@ -1,5 +1,22 @@ +""" +The SoupBinTCP protocol is a simple, lightweight, and fast protocol that +provides reliable, ordered, and error-checked delivery of messages between +client and server. It is designed for high-performance market data and order +entry applications. The protocol is based on the TCP/IP protocol suite and +uses TCP as its transport protocol. + +This module provides a SoupBinTCP client implementation that can be used to +connect to the SoupBinTCP servers. + +Though SoupBinTCP is meant for latency sensitive applications, there are +numerous times when the client application is not latency sensitive and +would like to talk to the soup server, Say in testing or writing a monitoring tool. + +In such cases, the client application can use the SoupBinTCP client provided by this module. +""" import asyncio from typing import Callable + from nasdaq_protocols import common from .core import * from .session import * @@ -9,12 +26,35 @@ async def connect_async(remote: tuple[str, int], # pylint: disable=too-many-arg user: str, passwd: str, session_id: str = '', - sequence: int = 0, + sequence: int = 1, on_msg_coro: OnSoupMsgCoro = None, on_close_coro: common.OnCloseCoro = None, session_factory: Callable[[], SoupClientSession] = None, client_heartbeat_interval: int = 10, - server_heartbeat_interval: int = 10): + server_heartbeat_interval: int = 10) -> SoupClientSession: + """ + Connect to the SoupBinTCP server and login. + + Using `:param sequence` the client can specify the sequence number of the next + message it expects to receive. The server will then send all messages with sequence + numbers greater than the specified sequence number. + + To connect to the start of the stream, specify sequence=1, which is the default. + To connect to the end of the stream, specify sequence=0, new messages will be received. + To connect to a specific message, specify the sequence number of the message. + + :param remote: tuple of host and port + :param user: Username to login + :param passwd: Password to login + :param session_id: Name of the session to join [Default=''] . + :param sequence: The sequence number. [Default=1] + :param on_msg_coro: callback, message from server. + :param on_close_coro: callback, connection closed . + :param session_factory: Factory to create a SoupClientSession. + :param client_heartbeat_interval: seconds between client heartbeats. + :param server_heartbeat_interval: seconds between server heartbeats. + :return: SoupClientSession + """ loop = asyncio.get_running_loop() def default_session_factory(): diff --git a/src/nasdaq_protocols/soup/core.py b/src/nasdaq_protocols/soup/core.py index a6e604b..559d575 100644 --- a/src/nasdaq_protocols/soup/core.py +++ b/src/nasdaq_protocols/soup/core.py @@ -1,8 +1,12 @@ +""" +nasdaq_protocols.soup.core module contains the implementation of the +soup messages. +""" import enum import struct import attrs -from nasdaq_protocols.common import logable, Serializable +from nasdaq_protocols import common __all__ = [ @@ -27,6 +31,8 @@ class InvalidSoupMessage(ValueError): class LoginRejectReason(enum.Enum): + """Login Reject Reason sent from server in case of login failure.""" + NOT_AUTHORIZED = 'A' SESSION_NOT_AVAILABLE = 'S' @@ -36,9 +42,18 @@ def get(cls, reason: str): return reason if isinstance(reason, cls) else LoginRejectReason(reason) -@logable -class SoupMessage(Serializable): - """Base class for all soup messages.""" +@common.logable +class SoupMessage(common.Serializable): + """ + Base class for all soup messages. + + Give raw bytes use this class to unpack the bytes to the corresponding soup message:: + + input_bytes = b'\x00\x1fAtest 2 ' + soup_msg = SoupMessage.from_bytes(input_bytes) + type(soup_msg) + + """ ClassByIndicator = {} Format = '!h c' @@ -83,6 +98,15 @@ def is_logout(self): @attrs.define(slots=False, auto_attribs=True) class LoginRequest(SoupMessage, indicator='L', description='Login Request'): + """ + SoupBinTCP Login Request Message. + + :param user: Username to login + :param passwd: Password to login + :param session: Name of the session to join [Default=''] . + :param sequence: The sequence number. [Default=1] + """ + Format = '!h c 6s 10s 10s 20s' Length = 47 @@ -91,7 +115,12 @@ class LoginRequest(SoupMessage, indicator='L', description='Login Request'): session: str sequence: str - def to_bytes(self): + def to_bytes(self) -> bytes: + """ + Pack the soup message to binary format + + :return: bytes + """ return struct.pack(LoginRequest.Format, LoginRequest.Length, self.Indicator.encode('ascii'), @@ -111,13 +140,23 @@ def unpack(cls, bytes_): @attrs.define(slots=False, auto_attribs=True) class LoginAccepted(SoupMessage, indicator='A', description='Login Accepted'): + """ + SoupBinTCP Login Accepted Message. + + :param session_id: Name of the session joined [Default=''] . + :param sequence: The next sequence number. + """ Format = '!h c 10s 20s' Len = 31 session_id: str sequence: int - def to_bytes(self): + def to_bytes(self) -> bytes: + """ + Pack the soup message to binary format + :return: bytes + """ return struct.pack(LoginAccepted.Format, LoginAccepted.Len, self.Indicator.encode('ascii'), _pack(self.session_id, 10), @@ -131,12 +170,21 @@ def unpack(cls, bytes_): @attrs.define(slots=False, auto_attribs=True) class LoginRejected(SoupMessage, indicator='J', description='Login Rejected'): + """ + SoupBinTCP Login Rejected Message. + + :param reason: Reason for login failure. Refer `LoginRejectReason` + """ Format = '!h c c' Length = 2 reason: LoginRejectReason = attrs.field(converter=LoginRejectReason.get) - def to_bytes(self): + def to_bytes(self) -> bytes: + """ + Pack the soup message to binary format + :return: bytes + """ return struct.pack(LoginRejected.Format, LoginRejected.Length, self.Indicator.encode('ascii'), @@ -150,9 +198,18 @@ def unpack(cls, bytes_): @attrs.define(slots=False, auto_attribs=True) class SequencedData(SoupMessage, indicator='S', description='Sequenced Data'): + """ + SoupBinTCP Sequenced Data Message. + + :param data: The application payload sent by the server. + """ data: bytes - def to_bytes(self): + def to_bytes(self) -> bytes: + """ + Pack the soup message to binary format + :return: bytes + """ msg = struct.pack(SoupMessage.Format, len(self.data)+1, self.Indicator.encode('ascii')) @@ -166,9 +223,18 @@ def unpack(cls, bytes_): @attrs.define(slots=False, auto_attribs=True) class UnSequencedData(SoupMessage, indicator='U', description='UnSequenced Data'): + """ + SoupBinTCP Unsequenced Data Message. + + :param data: The application payload to be sent to server. + """ data: bytes - def to_bytes(self): + def to_bytes(self) -> bytes: + """ + Pack the soup message to binary format + :return: bytes + """ msg = struct.pack(SoupMessage.Format, len(self.data)+1, self.Indicator.encode('ascii')) @@ -182,9 +248,18 @@ def unpack(cls, bytes_): @attrs.define(slots=False, auto_attribs=True) class Debug(SoupMessage, indicator='+', description='Debug'): + """ + SoupBinTCP Debug Message. + + :param msg: The debug message. + """ msg: str - def to_bytes(self): + def to_bytes(self) -> bytes: + """ + Pack the soup message to binary format + :return: bytes + """ msg = struct.pack(SoupMessage.Format, len(self.msg) + 1, self.Indicator.encode('ascii')) @@ -198,24 +273,42 @@ def unpack(cls, bytes_): @attrs.define(slots=False, auto_attribs=True) class ClientHeartbeat(SoupMessage, indicator='R', description='Client Heartbeat'): + """ + SoupBinTCP Client Heartbeat Message. + """ + def is_heartbeat(self): return True @attrs.define(slots=False, auto_attribs=True) class ServerHeartbeat(SoupMessage, indicator='H', description='Server Heartbeat'): + """ + SoupBinTCP Server Heartbeat Message. + """ + def is_heartbeat(self): return True @attrs.define(slots=False, auto_attribs=True) class EndOfSession(SoupMessage, indicator='Z', description='End of Session'): + """ + SoupBinTCP End of Session Message. + + This message is sent from server to indicate the soup stream is now closed. + """ def is_logout(self): return True @attrs.define(slots=False, auto_attribs=True) class LogoutRequest(SoupMessage, indicator='O', description='LogoutRequest'): + """ + SoupBinTCP Logout Request Message. + + This message is initiated by the client to sever for graceful session logoff. + """ def is_logout(self): return True diff --git a/src/nasdaq_protocols/soup/session.py b/src/nasdaq_protocols/soup/session.py index 161972e..da598d8 100644 --- a/src/nasdaq_protocols/soup/session.py +++ b/src/nasdaq_protocols/soup/session.py @@ -1,3 +1,6 @@ +""" +nasdaq_protocols.soup.session contains implementation of the soup session. +""" import abc import asyncio from typing import Any, Awaitable, Callable, ClassVar @@ -23,6 +26,13 @@ @attrs.define(auto_attribs=True) class SoupSessionId(common.SessionId): + """ + Identifier for a soup session. + + :param session_type: The type of the session, either 'client' or 'server'. + :param user: The username. + :param session: The session id. + """ session_type: str = 'norole' user: str = 'nouser' session: str = 'nosession' @@ -30,6 +40,13 @@ class SoupSessionId(common.SessionId): def update(self, msg: LoginRequest | LoginAccepted, transport: asyncio.Transport | None = None): + """ + Update the session id with the more information as and when it is available. + + :param msg: LoginRequest or LoginAccepted message. + :param transport: asyncio.Transport object. + :return: self + """ if transport: self.set_transport(transport) elif isinstance(msg, LoginRequest): @@ -44,7 +61,18 @@ def __str__(self): @attrs.define(auto_attribs=True) @common.logable class SoupSession(common.AsyncSession, abc.ABC): - SessionType: ClassVar[str] = 'base' + """ + Base class for SoupBinTCP[server, client] session. + + :param on_msg_coro: Coroutine to be called when a message is received. + :param on_close_coro: Coroutine to be called when the session is closed. + :param sequence: The sequence number. [Default=1] + :param client_heartbeat_interval: The client heartbeat interval in seconds. [Default=10] + :param server_heartbeat_interval: The server heartbeat interval in seconds. [Default=10] + :param session_id: The session id. + """ + + SessionType: ClassVar[str] = None on_msg_coro: OnSoupMsgCoro = None on_close_coro: common.OnCloseCoro = None @@ -55,16 +83,24 @@ class SoupSession(common.AsyncSession, abc.ABC): reader_factory: common.ReaderFactory = attrs.field(init=False, default=SoupMessageReader) def __init_subclass__(cls, **kwargs): # pylint: disable=arguments-renamed + if cls.SessionType: + return if 'session_type' in kwargs: cls.SessionType = kwargs['session_type'] else: + cls.log.info('Setting base') cls.SessionType = 'base' def __attrs_post_init__(self): self.session_id.session_type = self.SessionType super().__attrs_post_init__() - def send_msg(self, msg: SoupMessage): + def send_msg(self, msg: SoupMessage) -> None: + """ + Send a soup message to the server. + + :param msg: SoupMessage object. + """ bytes_ = msg.to_bytes() self._transport.write(bytes_) self.log.debug('%s> sent %s', self.session_id, str(bytes_)) @@ -75,22 +111,48 @@ def send_msg(self, msg: SoupMessage): self.log.debug('%s> sent sequenced message, seq = %d', self.session_id, self.sequence) self.sequence += 1 - def send_debug(self, text: str): + def send_debug(self, text: str) -> None: + """ + Send a debug message to the peer. + + :param text: debug text + """ self.send_msg(Debug(text)) - async def logout(self): + def logout(self) -> None: + """ + Logout. + + The session is closed after sending the logout request. + :return: + """ self.log.debug('%s> logging out', self.session_id) self.send_msg(LogoutRequest()) - await self.close() + self.initiate_close() @attrs.define(auto_attribs=True) @common.logable class SoupClientSession(SoupSession, session_type='client'): + """ + SoupBinTCP client session. + + Upon successful connecting to the soup server, the client session is instantiated. + """ dispatch_on_connect: bool = False async def login(self, msg: LoginRequest): + """ + Login to the soup server. + + This is supposed to be the first message to be sent to server upon connection successful. + + :param msg: LoginRequest message. + :return: self + :raises ConnectionRefusedError: If the server rejects the login request. + """ self.log.debug('%s> logging in', self.session_id) + self.session_id.update(msg) self.send_msg(msg) reply = await self.receive_msg() @@ -104,17 +166,37 @@ async def login(self, msg: LoginRequest): self.log.debug('%s> session established, sequence = %d', self.session_id, self.sequence) self.start_heartbeats(self.client_heartbeat_interval, self.server_heartbeat_interval) self.start_dispatching() + return self async def send_heartbeat(self): + """ + Send heartbeat to the server. + + :meta private: + """ self.send_msg(ClientHeartbeat()) def send_unseq_data(self, data: bytes): + """ + Send unsequenced data to the server. + :param data: application payload + """ self.send_msg(UnSequencedData(data)) @attrs.define(auto_attribs=True) @common.logable class SoupServerSession(SoupSession, session_type='server'): + """ + Base class for all soup server sessions. + + Any class that implements a soup server session must inherit from this class. + and implement the following methods: + + - on_login + - on_unsequenced + + """ on_msg_coro: OnSoupMsgCoro = attrs.field(init=False) _logged_in: bool = attrs.field(init=False, default=False) @@ -125,24 +207,48 @@ def __attrs_post_init__(self): @abc.abstractmethod async def on_login(self, msg: LoginRequest) -> LoginAccepted | LoginRejected: - pass + """ + Handle the login request from the client. + + :param msg: LoginRequest message. + :return: LoginAccepted or LoginRejected message. + """ @abc.abstractmethod async def on_unsequenced(self, msg: UnSequencedData) -> None: - pass + """ + Handle the unsequenced data from the client. - def send_seq_msg(self, data: bytes): + :param msg: UnSequencedData message. + """ + + def send_seq_msg(self, data: bytes) -> None: + """ + Send sequenced data to the client. + + :param data: application payload + """ if not isinstance(data, SequencedData): data = SequencedData(data) self.send_msg(data) def end_session(self): + """ + End the session. + """ self.send_msg(EndOfSession()) + self.initiate_close() async def on_debug(self, msg: Debug) -> None: self.log.info('%s> ++ client debug : %s', msg) async def send_heartbeat(self): + """ + Send heartbeat to the client. + + This is called automatically by the session when the heartbeat interval expires. + :meta private: + """ self.send_msg(ServerHeartbeat()) async def _on_msg(self, msg: SoupMessage) -> None: @@ -153,6 +259,9 @@ async def _on_msg(self, msg: SoupMessage) -> None: await self.on_unsequenced(msg) elif isinstance(msg, Debug): await self.on_debug(msg) + elif isinstance(msg, LogoutRequest): + self.log.debug('%s> client logged out', self.session_id) + await self.close() async def _handle_login(self, msg: LoginRequest) -> None: if self._logged_in: diff --git a/tests/test_common_asyncsession.py b/tests/test_common_asyncsession.py index b64e761..0cbaef1 100644 --- a/tests/test_common_asyncsession.py +++ b/tests/test_common_asyncsession.py @@ -106,8 +106,8 @@ async def test_stop_initiated_by_client(mock_server_session, client_session): async def test_server_failed_heartbeat_connection_is_closed(mock_server_session, client_session): _, server_session = mock_server_session - client_session.start_heartbeats(10, 0.1) - await asyncio.sleep(2) + client_session.start_heartbeats(10, 0.01) + await asyncio.sleep(0.1) # test client session is closed assert client_session.is_closed() diff --git a/tests/test_soup_session.py b/tests/test_soup_session.py index 4c4e79a..8f39483 100644 --- a/tests/test_soup_session.py +++ b/tests/test_soup_session.py @@ -1,50 +1,117 @@ -import attrs +import asyncio import pytest -import logging -from unittest.mock import MagicMock +from nasdaq_protocols import common from nasdaq_protocols import soup -from .mocks import mock_server_session +from nasdaq_protocols.soup import LoginRequest, LoginAccepted, LoginRejected -logger = logging.getLogger(__name__) +class SoupServerTestSession(soup.SoupServerSession, session_type='server'): + async def on_login(self, msg: LoginRequest) -> LoginAccepted | LoginRejected: + if msg.user == 'test-u' and msg.password == 'test-p': + return LoginAccepted('session', int(msg.sequence)) + else: + return LoginRejected(soup.LoginRejectReason.NOT_AUTHORIZED) + async def on_unsequenced(self, msg: soup.UnSequencedData): + reply = msg.data.decode('ascii') + '-ack' + self.send_msg(soup.SequencedData(reply.encode('ascii'))) -@attrs.define(slots=False, auto_attribs=True) -class MockServerSession(soup.SoupServerSession, session_type='mock_server'): - on_login_mock: MagicMock = None - on_unsequenced_mock: MagicMock = None + def generate_load(self, number_of_messages): + for i in range(number_of_messages): + self.send_msg(soup.SequencedData(f'msg-{i}'.encode('ascii'))) + self.send_msg(soup.UnSequencedData('end'.encode('ascii'))) - async def on_login(self, msg: soup.LoginRequest) -> soup.LoginAccepted | soup.LoginRejected: - if self.on_login_mock: - return self.on_login_mock(msg) - async def on_unsequenced(self, msg: soup.UnSequencedData) -> None: - if self.on_unsequenced_mock: - return self.on_unsequenced_mock(msg) +@pytest.fixture(scope='function') +async def test_soup_server_session(unused_tcp_port): + session = SoupServerTestSession() + server, serving_task = await common.start_server(('127.0.0.1', unused_tcp_port), lambda: session) + yield unused_tcp_port, session + retry = 0 + while not session.is_closed() and retry < 5: + await asyncio.sleep(0.001) -def match_soup_msg_type(type_): - def match(data): - soup_message = soup.SoupMessage.from_bytes(data) - return isinstance(soup_message, type_) - return match + assert session.is_closed() + await common.stop_task(serving_task) -@pytest.mark.asyncio -async def test_server_rejected_login(mock_server_session): - port, server_session = mock_server_session - - server_session.when(match_soup_msg_type(soup.LoginRequest))\ - .do(lambda x: server_session.send(soup.LoginRejected(soup.LoginRejectReason.NOT_AUTHORIZED))) +@pytest.mark.asyncio +async def test_server_rejected_login(test_soup_server_session): + port, server_session = test_soup_server_session with pytest.raises(ConnectionRefusedError): - logger.info('going to connect') client_session = await soup.connect_async( ('127.0.0.1', port), - 'test-u', - 'test-password', + 'nouser', + 'nopwd', 'session' ) assert client_session is None + + +@pytest.mark.asyncio +async def test_server_accepted_login(test_soup_server_session): + port, server_session = test_soup_server_session + + client_session = await soup.connect_async( + ('127.0.0.1', port), + 'test-u', + 'test-p', + 'session' + ) + assert client_session is not None + + client_session.logout() + + +@pytest.mark.asyncio +async def test_client_server_communicate(test_soup_server_session): + port, server_session = test_soup_server_session + + client_session = await soup.connect_async( + ('127.0.0.1', port), + 'test-u', + 'test-p', + 'session' + ) + assert client_session is not None + + for i in range(1, 10): + test_data = f'hello-{i}'.encode() + client_session.send_msg(soup.UnSequencedData(test_data)) + reply = await client_session.receive_msg() + assert isinstance(reply, soup.SequencedData) + assert reply.data == test_data + b'-ack' + + client_session.logout() + + +@pytest.mark.asyncio +async def test_server_streaming_client_uses_dispatcher(test_soup_server_session): + port, server_session = test_soup_server_session + + closed = asyncio.Event() + + async def on_msg(msg): + if msg.data == b'end': + server_session.end_session() + + async def on_close(): + closed.set() + + client_session = await soup.connect_async( + ('127.0.0.1', port), + 'test-u', + 'test-p', + 'session', + on_msg_coro=on_msg, + on_close_coro=on_close + ) + assert client_session is not None + + server_session.generate_load(100) + + await asyncio.wait_for(closed.wait(), 1)