Skip to content
This repository has been archived by the owner on May 27, 2024. It is now read-only.

Commit

Permalink
refactor(async): Various edits
Browse files Browse the repository at this point in the history
Remove reader/writer parameters from AAA methods
Refactor flow_control decorator, as we are no more passing reader/writer instances
Refactor tests
  • Loading branch information
nekonekun committed Aug 1, 2023
1 parent 85cf61f commit fa4527c
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 108 deletions.
56 changes: 18 additions & 38 deletions tacacs_plus/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,32 +89,24 @@ def version(self):
return (self.version_max * 0x10) + self.version_min

@contextlib.asynccontextmanager
async def flow_control(self, reader=None, writer=None):
# we do not need to create new reader/writer instances
# if both are provided
existing = bool(reader and writer)
if existing:
yielded_reader, yielded_writer = reader, writer
async def flow_control(self):
if self.family == socket.AF_INET:
conn = (self.host, self.port)
else:
if self.family == socket.AF_INET:
conn = (self.host, self.port)
else:
# For AF_INET6 address family, a four-tuple (host, port,
# flowinfo, scopeid) is used
conn = (self.host, self.port, 0, 0)
sock = socket.socket(self.family, socket.SOCK_STREAM)
sock.settimeout(self.timeout)
sock.connect(conn)
yielded_reader, yielded_writer = await asyncio.open_connection(sock=sock)
# For AF_INET6 address family, a four-tuple (host, port,
# flowinfo, scopeid) is used
conn = (self.host, self.port, 0, 0)
sock = socket.socket(self.family, socket.SOCK_STREAM)
sock.settimeout(self.timeout)
sock.connect(conn)

reader, writer = await asyncio.open_connection(sock=sock)
try:
yield yielded_reader, yielded_writer
yield reader, writer
finally:
if not existing:
# if reader/writer were not provided -
# we need to properly close socket and writer
yielded_writer.close()
await yielded_writer.wait_closed()
sock.close()
writer.close()
await writer.wait_closed()
sock.close()

async def send(self, body, req_type, seq_no=1, reader=None, writer=None):
"""
Expand Down Expand Up @@ -191,8 +183,6 @@ async def authenticate(
chap_challenge=None,
rem_addr=TAC_PLUS_VIRTUAL_REM_ADDR,
port=TAC_PLUS_VIRTUAL_PORT,
reader=None,
writer=None
):
"""
Authenticate to a TACACS+ server with a username and password.
Expand All @@ -207,8 +197,6 @@ async def authenticate(
:param chap_challenge: challenge value when authen_type == 'chap'
:param rem_addr: AAA request source, default to TAC_PLUS_VIRTUAL_REM_ADDR
:param port: AAA port, default to TAC_PLUS_VIRTUAL_PORT
:param reader: asyncio.StreamReader instance
:param writer: asyncio.StreamWriter instance
:return: TACACSAuthenticationReply
:raises: socket.timeout, socket.error
"""
Expand All @@ -234,7 +222,7 @@ async def authenticate(
start_data += chap_challenge.encode()
data_to_md5 = (chap_ppp_id + password + chap_challenge).encode()
start_data += md5(data_to_md5).digest()
async with self.flow_control(reader, writer) as (reader, writer):
async with self.flow_control() as (reader, writer):
packet = await self.send(
TACACSAuthenticationStart(
username,
Expand Down Expand Up @@ -283,8 +271,6 @@ async def authorize(
priv_lvl=TAC_PLUS_PRIV_LVL_MIN,
rem_addr=TAC_PLUS_VIRTUAL_REM_ADDR,
port=TAC_PLUS_VIRTUAL_PORT,
reader=None,
writer=None
):
"""
Authorize with a TACACS+ server.
Expand All @@ -297,14 +283,12 @@ async def authorize(
:param priv_lvl: Minimal Required priv_lvl.
:param rem_addr: AAA request source, default to TAC_PLUS_VIRTUAL_REM_ADDR
:param port: AAA port, default to TAC_PLUS_VIRTUAL_PORT
:param reader: asyncio.StreamReader instance
:param writer: asyncio.StreamWriter instance
:return: TACACSAuthenticationReply
:raises: socket.timeout, socket.error
"""
if arguments is None:
arguments = []
async with self.flow_control(reader, writer) as (reader, writer):
async with self.flow_control() as (reader, writer):
packet = await self.send(
TACACSAuthorizationStart(
username,
Expand Down Expand Up @@ -349,8 +333,6 @@ async def account(
priv_lvl=TAC_PLUS_PRIV_LVL_MIN,
rem_addr=TAC_PLUS_VIRTUAL_REM_ADDR,
port=TAC_PLUS_VIRTUAL_PORT,
reader=None,
writer=None
):
"""
Account with a TACACS+ server.
Expand All @@ -366,14 +348,12 @@ async def account(
:param priv_lvl: Minimal Required priv_lvl.
:param rem_addr: AAA request source, default to TAC_PLUS_VIRTUAL_REM_ADDR
:param port: AAA port, default to TAC_PLUS_VIRTUAL_PORT
:param reader: asyncio.StreamReader instance
:param writer: asyncio.StreamWriter instance
:return: TACACSAccountingReply
:raises: socket.timeout, socket.error
"""
if arguments is None:
arguments = []
async with self.flow_control(reader, writer) as (reader, writer):
async with self.flow_control() as (reader, writer):
packet = await self.send(
TACACSAccountingStart(
username,
Expand Down
Loading

0 comments on commit fa4527c

Please sign in to comment.