diff --git a/tacacs_plus/async_client.py b/tacacs_plus/async_client.py index 268e1d1..f640d30 100644 --- a/tacacs_plus/async_client.py +++ b/tacacs_plus/async_client.py @@ -89,32 +89,24 @@ def version(self): return (self.version_max * 0x10) + self.version_min @contextlib.asynccontextmanager - async def flow_control(self, reader=None, writer=None): - # we do not need to create new reader/writer instances - # if both are provided - existing = bool(reader and writer) - if existing: - yielded_reader, yielded_writer = reader, writer + async def flow_control(self): + if self.family == socket.AF_INET: + conn = (self.host, self.port) else: - if self.family == socket.AF_INET: - conn = (self.host, self.port) - else: - # For AF_INET6 address family, a four-tuple (host, port, - # flowinfo, scopeid) is used - conn = (self.host, self.port, 0, 0) - sock = socket.socket(self.family, socket.SOCK_STREAM) - sock.settimeout(self.timeout) - sock.connect(conn) - yielded_reader, yielded_writer = await asyncio.open_connection(sock=sock) + # For AF_INET6 address family, a four-tuple (host, port, + # flowinfo, scopeid) is used + conn = (self.host, self.port, 0, 0) + sock = socket.socket(self.family, socket.SOCK_STREAM) + sock.settimeout(self.timeout) + sock.connect(conn) + + reader, writer = await asyncio.open_connection(sock=sock) try: - yield yielded_reader, yielded_writer + yield reader, writer finally: - if not existing: - # if reader/writer were not provided - - # we need to properly close socket and writer - yielded_writer.close() - await yielded_writer.wait_closed() - sock.close() + writer.close() + await writer.wait_closed() + sock.close() async def send(self, body, req_type, seq_no=1, reader=None, writer=None): """ @@ -191,8 +183,6 @@ async def authenticate( chap_challenge=None, rem_addr=TAC_PLUS_VIRTUAL_REM_ADDR, port=TAC_PLUS_VIRTUAL_PORT, - reader=None, - writer=None ): """ Authenticate to a TACACS+ server with a username and password. @@ -207,8 +197,6 @@ async def authenticate( :param chap_challenge: challenge value when authen_type == 'chap' :param rem_addr: AAA request source, default to TAC_PLUS_VIRTUAL_REM_ADDR :param port: AAA port, default to TAC_PLUS_VIRTUAL_PORT - :param reader: asyncio.StreamReader instance - :param writer: asyncio.StreamWriter instance :return: TACACSAuthenticationReply :raises: socket.timeout, socket.error """ @@ -234,7 +222,7 @@ async def authenticate( start_data += chap_challenge.encode() data_to_md5 = (chap_ppp_id + password + chap_challenge).encode() start_data += md5(data_to_md5).digest() - async with self.flow_control(reader, writer) as (reader, writer): + async with self.flow_control() as (reader, writer): packet = await self.send( TACACSAuthenticationStart( username, @@ -283,8 +271,6 @@ async def authorize( priv_lvl=TAC_PLUS_PRIV_LVL_MIN, rem_addr=TAC_PLUS_VIRTUAL_REM_ADDR, port=TAC_PLUS_VIRTUAL_PORT, - reader=None, - writer=None ): """ Authorize with a TACACS+ server. @@ -297,14 +283,12 @@ async def authorize( :param priv_lvl: Minimal Required priv_lvl. :param rem_addr: AAA request source, default to TAC_PLUS_VIRTUAL_REM_ADDR :param port: AAA port, default to TAC_PLUS_VIRTUAL_PORT - :param reader: asyncio.StreamReader instance - :param writer: asyncio.StreamWriter instance :return: TACACSAuthenticationReply :raises: socket.timeout, socket.error """ if arguments is None: arguments = [] - async with self.flow_control(reader, writer) as (reader, writer): + async with self.flow_control() as (reader, writer): packet = await self.send( TACACSAuthorizationStart( username, @@ -349,8 +333,6 @@ async def account( priv_lvl=TAC_PLUS_PRIV_LVL_MIN, rem_addr=TAC_PLUS_VIRTUAL_REM_ADDR, port=TAC_PLUS_VIRTUAL_PORT, - reader=None, - writer=None ): """ Account with a TACACS+ server. @@ -366,14 +348,12 @@ async def account( :param priv_lvl: Minimal Required priv_lvl. :param rem_addr: AAA request source, default to TAC_PLUS_VIRTUAL_REM_ADDR :param port: AAA port, default to TAC_PLUS_VIRTUAL_PORT - :param reader: asyncio.StreamReader instance - :param writer: asyncio.StreamWriter instance :return: TACACSAccountingReply :raises: socket.timeout, socket.error """ if arguments is None: arguments = [] - async with self.flow_control(reader, writer) as (reader, writer): + async with self.flow_control() as (reader, writer): packet = await self.send( TACACSAccountingStart( username, diff --git a/tests/test_async_client.py b/tests/test_async_client.py index 3602808..0221357 100644 --- a/tests/test_async_client.py +++ b/tests/test_async_client.py @@ -1,3 +1,4 @@ +import asyncio import pytest from hashlib import md5 import socket @@ -48,20 +49,30 @@ def write(self, data): async def drain(self): pass - async def read(self, size: int = 0): - return self.buff.read(size or None) + def close(self): + pass + + async def wait_closed(self): + pass -class FakePair: - def __init__(self, response_packets): - self.reader = FakeReader(io.BytesIO(response_packets)) - self.writer = FakeWriter(io.BytesIO()) +@pytest.fixture +def patch_connection(monkeypatch, packets): + async def open_connection(*args, **kwargs): + return fake_reader, fake_writer + reader_buff = io.BytesIO(packets) + reader_buff.seek(0) + fake_reader = FakeReader(reader_buff) -@pytest.fixture(scope='function') -def fake_pair(request): - packets = request.node.callspec.params.get('packets') - return FakePair(packets) + writer_buff = io.BytesIO() + fake_writer = FakeWriter(writer_buff) + + monkeypatch.setattr(asyncio, 'open_connection', open_connection) + monkeypatch.setattr(socket.socket, 'connect', lambda self, conn: None) + monkeypatch.setattr(socket.socket, 'close', lambda self: None) + + return writer_buff # test client send @@ -81,22 +92,22 @@ def fake_pair(request): ], ) @pytest.mark.asyncio -async def test_client_socket_send(fake_pair, packets, state): +async def test_client_socket_send(patch_connection, packets, state): body = TACACSAuthenticationStart('user123', TAC_PLUS_AUTHEN_TYPE_ASCII) client = TACACSClient('127.0.0.1', 49, None, session_id=12345) - fake_pair.reader.buff.seek(0) - packet = await client.send( - body, TAC_PLUS_AUTHEN, reader=fake_pair.reader, writer=fake_pair.writer - ) + async with client.flow_control() as (reader, writer): + packet = await client.send( + body, TAC_PLUS_AUTHEN, reader=reader, writer=writer + ) assert isinstance(packet, TACACSPacket) reply = TACACSAuthenticationReply.unpacked(packet.body) assert getattr(reply, state) is True # the first 12 bytes of the packet represent the header - fake_pair.writer.buff.seek(0) + patch_connection.seek(0) sent_header, sent_body = ( - await fake_pair.writer.read(12), - await fake_pair.writer.read(), + patch_connection.read(12), + patch_connection.read(), ) body_length = TACACSHeader.unpacked(sent_header).length assert len(sent_body) == body_length @@ -108,15 +119,13 @@ async def test_client_socket_send(fake_pair, packets, state): [AUTHENTICATE_HEADER_WRONG + b'\x06\x07\x00\x00\x00\x00\x00'], ) @pytest.mark.asyncio -async def test_client_socket_send_wrong_headers(fake_pair, packets): +async def test_client_socket_send_wrong_headers(patch_connection, packets): body = TACACSAuthenticationStart('user123', TAC_PLUS_AUTHEN_TYPE_ASCII) client = TACACSClient('127.0.0.1', 49, None, session_id=12345) with pytest.raises(socket.error): await client.send( body, TAC_PLUS_AUTHEN, - reader=fake_pair.reader, - writer=fake_pair.writer, ) @@ -131,7 +140,7 @@ async def test_client_socket_send_wrong_headers(fake_pair, packets): ], ) @pytest.mark.asyncio -async def test_authenticate_ascii(fake_pair, packets): +async def test_authenticate_ascii(patch_connection, packets): """ client -> AUTHSTART (username) STATUS_GETPASS <- server @@ -140,14 +149,14 @@ async def test_authenticate_ascii(fake_pair, packets): """ client = TACACSClient('127.0.0.1', 49, None, session_id=12345) reply = await client.authenticate( - 'username', 'pass', reader=fake_pair.reader, writer=fake_pair.writer + 'username', 'pass', ) assert reply.valid - fake_pair.writer.buff.seek(0) - first_header = TACACSHeader.unpacked(await fake_pair.writer.read(12)) + patch_connection.seek(0) + first_header = TACACSHeader.unpacked(patch_connection.read(12)) assert (first_header.version_max, first_header.version_min) == (12, 0) - first_body = fake_pair.writer.buff.read(first_header.length) + first_body = patch_connection.read(first_header.length) assert ( TACACSAuthenticationStart( 'username', TAC_PLUS_AUTHEN_TYPE_ASCII @@ -155,11 +164,11 @@ async def test_authenticate_ascii(fake_pair, packets): == first_body ) - second_header = TACACSHeader.unpacked(await fake_pair.writer.read(12)) + second_header = TACACSHeader.unpacked(patch_connection.read(12)) assert (first_header.version_max, first_header.version_min) == (12, 0) assert second_header.seq_no > first_header.seq_no - second_body = await fake_pair.writer.read() + second_body = patch_connection.read() assert TACACSAuthenticationContinue('pass').packed == second_body @@ -170,7 +179,7 @@ async def test_authenticate_ascii(fake_pair, packets): ], # auth_valid ) @pytest.mark.asyncio -async def test_authenticate_pap(fake_pair, packets): +async def test_authenticate_pap(patch_connection, packets): """ client -> AUTHSTART (user+pass) STATUS_PASS <- server @@ -180,15 +189,13 @@ async def test_authenticate_pap(fake_pair, packets): 'username', 'pass', authen_type=TAC_PLUS_AUTHEN_TYPE_PAP, - reader=fake_pair.reader, - writer=fake_pair.writer, ) assert reply.valid - fake_pair.writer.buff.seek(0) - first_header = TACACSHeader.unpacked(await fake_pair.writer.read(12)) + patch_connection.seek(0) + first_header = TACACSHeader.unpacked(patch_connection.read(12)) assert (first_header.version_max, first_header.version_min) == (12, 1) - first_body = await fake_pair.writer.read(first_header.length) + first_body = patch_connection.read(first_header.length) assert ( TACACSAuthenticationStart( 'username', TAC_PLUS_AUTHEN_TYPE_PAP, data='pass'.encode() @@ -204,7 +211,7 @@ async def test_authenticate_pap(fake_pair, packets): ], # auth_valid ) @pytest.mark.asyncio -async def test_authenticate_chap(fake_pair, packets): +async def test_authenticate_chap(patch_connection, packets): """ client -> AUTHSTART user+md5challenge(pass) STATUS_PASS <- server @@ -216,15 +223,13 @@ async def test_authenticate_chap(fake_pair, packets): authen_type=TAC_PLUS_AUTHEN_TYPE_CHAP, chap_ppp_id='A', chap_challenge='challenge', - reader=fake_pair.reader, - writer=fake_pair.writer, ) assert reply.valid - fake_pair.writer.buff.seek(0) - first_header = TACACSHeader.unpacked(await fake_pair.writer.read(12)) + patch_connection.seek(0) + first_header = TACACSHeader.unpacked(patch_connection.read(12)) assert (first_header.version_max, first_header.version_min) == (12, 1) - first_body = await fake_pair.writer.read(first_header.length) + first_body = patch_connection.read(first_header.length) assert ( TACACSAuthenticationStart( 'username', @@ -239,20 +244,18 @@ async def test_authenticate_chap(fake_pair, packets): 'packets', [AUTHORIZE_HEADER + b'\x06\x01\x00\x00\x00\x00\x00'] ) @pytest.mark.asyncio -async def test_authorize_ascii(fake_pair, packets): +async def test_authorize_ascii(patch_connection, packets): client = TACACSClient('127.0.0.1', 49, None, session_id=12345) reply = await client.authorize( 'username', arguments=[b'service=shell', b'cmd=show', b'cmdargs=version'], - reader=fake_pair.reader, - writer=fake_pair.writer, ) assert reply.valid - fake_pair.writer.buff.seek(0) - first_header = TACACSHeader.unpacked(await fake_pair.writer.read(12)) + patch_connection.seek(0) + first_header = TACACSHeader.unpacked(patch_connection.read(12)) assert (first_header.version_max, first_header.version_min) == (12, 0) - first_body = await fake_pair.writer.read(first_header.length) + first_body = patch_connection.read(first_header.length) assert ( TACACSAuthorizationStart( 'username', @@ -269,21 +272,19 @@ async def test_authorize_ascii(fake_pair, packets): 'packets', [AUTHORIZE_HEADER + b'\x06\x01\x00\x00\x00\x00\x00'] ) @pytest.mark.asyncio -async def test_authorize_pap(fake_pair, packets): +async def test_authorize_pap(patch_connection, packets): client = TACACSClient('127.0.0.1', 49, None, session_id=12345) reply = await client.authorize( 'username', arguments=[b'service=shell', b'cmd=show', b'cmdargs=version'], authen_type=TAC_PLUS_AUTHEN_TYPE_PAP, - reader=fake_pair.reader, - writer=fake_pair.writer, ) assert reply.valid - fake_pair.writer.buff.seek(0) - first_header = TACACSHeader.unpacked(await fake_pair.writer.read(12)) + patch_connection.seek(0) + first_header = TACACSHeader.unpacked(patch_connection.read(12)) assert (first_header.version_max, first_header.version_min) == (12, 0) - first_body = await fake_pair.writer.read(first_header.length) + first_body = patch_connection.read(first_header.length) assert ( TACACSAuthorizationStart( 'username', @@ -300,21 +301,19 @@ async def test_authorize_pap(fake_pair, packets): 'packets', [AUTHORIZE_HEADER + b'\x06\x01\x00\x00\x00\x00\x00'] ) @pytest.mark.asyncio -async def test_authorize_chap(fake_pair, packets): +async def test_authorize_chap(patch_connection, packets): client = TACACSClient('127.0.0.1', 49, None, session_id=12345) reply = await client.authorize( 'username', arguments=[b'service=shell', b'cmd=show', b'cmdargs=version'], authen_type=TAC_PLUS_AUTHEN_TYPE_CHAP, - reader=fake_pair.reader, - writer=fake_pair.writer, ) assert reply.valid - fake_pair.writer.buff.seek(0) - first_header = TACACSHeader.unpacked(await fake_pair.writer.read(12)) + patch_connection.seek(0) + first_header = TACACSHeader.unpacked(patch_connection.read(12)) assert (first_header.version_max, first_header.version_min) == (12, 0) - first_body = await fake_pair.writer.read(first_header.length) + first_body = patch_connection.read(first_header.length) assert ( TACACSAuthorizationStart( 'username', @@ -332,21 +331,19 @@ async def test_authorize_chap(fake_pair, packets): 'packets', [ACCOUNT_HEADER + b'\x06\x00\x00\x00\x00\x01\x00\x00'] ) @pytest.mark.asyncio -async def test_account_start(fake_pair, packets): +async def test_account_start(patch_connection, packets): client = TACACSClient('127.0.0.1', 49, None, session_id=12345) reply = await client.account( 'username', TAC_PLUS_ACCT_FLAG_START, arguments=[b'service=shell', b'cmd=show', b'cmdargs=version'], - reader=fake_pair.reader, - writer=fake_pair.writer, ) assert reply.valid - fake_pair.writer.buff.seek(0) - first_header = TACACSHeader.unpacked(await fake_pair.writer.read(12)) + patch_connection.seek(0) + first_header = TACACSHeader.unpacked(patch_connection.read(12)) assert (first_header.version_max, first_header.version_min) == (12, 0) - first_body = await fake_pair.writer.read(first_header.length) + first_body = patch_connection.read(first_header.length) assert ( TACACSAccountingStart( 'username', @@ -365,15 +362,13 @@ async def test_account_start(fake_pair, packets): [AUTHORIZE_HEADER + b'\x12\x01\x01\x00\x00\x00\x00\x0bpriv-lvl=15'], ) @pytest.mark.asyncio -async def test_authorize_equal_priv_lvl(fake_pair, packets): +async def test_authorize_equal_priv_lvl(patch_connection, packets): client = TACACSClient('127.0.0.1', 49, None, session_id=12345) reply = await client.authorize( 'username', arguments=[b'service=shell', b'cmd=show', b'cmdargs=version'], authen_type=TAC_PLUS_AUTHEN_TYPE_PAP, priv_lvl=TAC_PLUS_PRIV_LVL_MAX, - reader=fake_pair.reader, - writer=fake_pair.writer, ) assert ( reply.valid @@ -385,15 +380,13 @@ async def test_authorize_equal_priv_lvl(fake_pair, packets): [AUTHORIZE_HEADER + b'\x11\x01\x01\x00\x00\x00\x00\x0bpriv-lvl=1'], ) @pytest.mark.asyncio -async def test_authorize_lesser_priv_lvl(fake_pair, packets): +async def test_authorize_lesser_priv_lvl(patch_connection, packets): client = TACACSClient('127.0.0.1', 49, None, session_id=12345) reply = await client.authorize( 'username', arguments=[b'service=shell', b'cmd=show', b'cmdargs=version'], authen_type=TAC_PLUS_AUTHEN_TYPE_PAP, priv_lvl=TAC_PLUS_PRIV_LVL_MAX, - reader=fake_pair.reader, - writer=fake_pair.writer, ) assert ( not reply.valid