Skip to content

Commit

Permalink
test and type improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
cgevans committed Dec 3, 2021
1 parent d2af2be commit 18acac5
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 40 deletions.
38 changes: 15 additions & 23 deletions src/qslib/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ class BaseStatus(ABC):
@classmethod
@property
@abstractmethod
def _comlist(cls: Type[T]) -> dict[str, tuple[bytes, Callable[[Any], Any]]]:
def _comlist(
cls: Type[T],
) -> dict[str, tuple[bytes, Callable[[Any], Any]]]: # pragma: no cover
...

@classmethod
@property
@abstractmethod
def _com(cls: Type[T]) -> bytes:
@abstractmethod # pragma: no cover
def _com(cls: Type[T]) -> bytes: # pragma: no cover
...

@classmethod
Expand Down Expand Up @@ -105,39 +107,29 @@ class AccessLevel(Enum):
Full = "Full"

def __gt__(self, other: object) -> bool:
if isinstance(other, str):
if not isinstance(other, AccessLevel):
other = AccessLevel(other)
if isinstance(other, AccessLevel):
return _accesslevel_order[self.value] > _accesslevel_order[other.value]
raise NotImplementedError
return _accesslevel_order[self.value] > _accesslevel_order[other.value]

def __ge__(self, other: object) -> bool:
if isinstance(other, str):
if not isinstance(other, AccessLevel):
other = AccessLevel(other)
if isinstance(other, AccessLevel):
return _accesslevel_order[self.value] >= _accesslevel_order[other.value]
raise NotImplementedError
return _accesslevel_order[self.value] >= _accesslevel_order[other.value]

def __lt__(self, other: object) -> bool:
if isinstance(other, str):
if not isinstance(other, AccessLevel):
other = AccessLevel(other)
if isinstance(other, AccessLevel):
return _accesslevel_order[self.value] < _accesslevel_order[other.value]
raise NotImplementedError
return _accesslevel_order[self.value] < _accesslevel_order[other.value]

def __le__(self, other: object) -> bool:
if isinstance(other, str):
if not isinstance(other, AccessLevel):
other = AccessLevel(other)
if isinstance(other, AccessLevel):
return _accesslevel_order[self.value] <= _accesslevel_order[other.value]
raise NotImplementedError
return _accesslevel_order[self.value] <= _accesslevel_order[other.value]

def __eq__(self, other: object) -> bool:
if isinstance(other, str):
if not isinstance(other, AccessLevel):
other = AccessLevel(other)
if isinstance(other, AccessLevel):
return _accesslevel_order[self.value] == _accesslevel_order[other.value]
raise NotImplementedError
return _accesslevel_order[self.value] == _accesslevel_order[other.value]

def __str__(self) -> str:
return self.value
30 changes: 18 additions & 12 deletions src/qslib/qs_is_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import re
import io
from dataclasses import dataclass
from typing import Any, Coroutine, Optional, Protocol
from typing import Any, Coroutine, Literal, Optional, Protocol
import time


Expand Down Expand Up @@ -34,7 +34,7 @@ class ReplyError(IOError):
class SubHandler(Protocol):
def __call__(
self, topic: bytes, message: bytes, timestamp: float | None = None
) -> Coroutine[None, None, None]:
) -> Coroutine[None, None, None]: # pragma: no cover
...


Expand Down Expand Up @@ -75,7 +75,13 @@ def connection_made(self, transport: Any) -> None:
self.transport = transport
self.waiting_commands: list[
tuple[
bytes, None | Future[tuple[bytes, bytes] | tuple[bytes, asyncio.Future]]
bytes,
None
| Future[
tuple[
bytes, None | bytes, None | asyncio.Future[tuple[bytes, bytes]]
]
],
]
] = []
self.buffer = io.BytesIO()
Expand Down Expand Up @@ -107,16 +113,16 @@ async def parse_message(self, ds: bytes) -> None:
if ds.startswith((b"ERRor", b"OK", b"NEXT")):
ms = ds.index(b" ")
r = None
comfut_new = None
if ds.startswith(b"NEXT"):
loop = asyncio.get_running_loop()
comfut_new = loop.create_future()
for i, (commref, comfut) in enumerate(self.waiting_commands):
if ds.startswith(commref, ms + 1):
if comfut is not None:
if ds.startswith(b"NEXT"):
comfut.set_result((ds[:ms], comfut_new))
else:
comfut.set_result((ds[:ms], ds[ms + len(commref) + 2 :]))
comfut.set_result(
(ds[:ms], ds[ms + len(commref) + 2 :], comfut_new)
)
else:
log.info(f"{commref!r} complete: {ds!r}")
r = i
Expand Down Expand Up @@ -183,7 +189,7 @@ async def run_command(
log.debug(f"Running command {comm.decode()}")
loop = asyncio.get_running_loop()

comfut: Future[tuple[bytes, bytes]] = loop.create_future()
comfut = loop.create_future()
if uid:
import random

Expand All @@ -202,17 +208,17 @@ async def run_command(
except asyncio.CancelledError:
raise ConnectionError

state, msg = comfut.result()
state, msg, comnext = comfut.result()
log.debug(f"Received ({state!r}, {msg!r})")

if state == b"NEXT":
if just_ack:
# self.waiting_commands.append((commref, None))
return b""
else:
comnext = msg
await comnext
state, msg = comnext.result()
state, msg, comnext2 = comnext.result()
assert comnext2 is None
log.debug(f"Received ({state!r}, {msg!r})")

if state == b"OK":
Expand All @@ -221,7 +227,7 @@ async def run_command(
raise CommandError(
comm.decode(), commref.decode(), msg.decode().rstrip()
) from None
else:
else: # pragma: no cover
raise CommandError(
comm.decode(), commref.decode(), (state + b" " + msg).decode()
)
10 changes: 5 additions & 5 deletions tests/test_accesslevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ def test_access():
assert (AccessLevel(l1) > l2) == (levels.index(l1) > levels.index(l2))
assert (AccessLevel(l1) >= l2) == (levels.index(l1) >= levels.index(l2))
assert (AccessLevel(l1) == l2) == (levels.index(l1) == levels.index(l2))
with pytest.raises(NotImplementedError):
with pytest.raises(ValueError):
AccessLevel(l1) > invalid # type: ignore
with pytest.raises(NotImplementedError):
with pytest.raises(ValueError):
AccessLevel(l1) >= invalid # type: ignore
with pytest.raises(NotImplementedError):
with pytest.raises(ValueError):
AccessLevel(l1) < invalid # type: ignore
with pytest.raises(NotImplementedError):
with pytest.raises(ValueError):
AccessLevel(l1) <= invalid # type: ignore
with pytest.raises(NotImplementedError):
with pytest.raises(ValueError):
AccessLevel(l1) == invalid # type: ignore
assert str(AccessLevel(l1)) == l1
23 changes: 23 additions & 0 deletions tests/test_fakeserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ async def _fakeserver_runner(sr: asyncio.StreamReader, sw: asyncio.StreamWriter)
await sw.drain()
sw.write(b"MESSage testservermessage ueao\n")
await sw.drain()
sw.write(b"MESSage testservermessage 123456789.021 ueao\n")
await sw.drain()
sw.write(b"OK " + x.group(1) + b"\n")
await sw.drain()

Expand Down Expand Up @@ -159,3 +161,24 @@ async def test_runtitle_not_running():
async with srv:
with Machine("localhost", port=53533) as m:
assert m.current_run_name == None


@pytest.mark.asyncio
async def test_quote():
msg = "<quote>a\nu\n<quote.2>C\n</quote.2>\n </quote>"
srv = await asyncio.start_server(crcb({"TESTQUOTE": msg}), "localhost", 53533)

async with srv:
with Machine("localhost", port=53533) as m:
assert m.run_command("TESTQUOTE") == msg


@pytest.mark.asyncio
async def test_invalid_quote():
msg = "<quote>a\nu\n</quote.2>\n </quote>"
srv = await asyncio.start_server(crcb({"TESTQUOTE": msg}), "localhost", 53533)

async with srv:
with Machine("localhost", port=53533) as m:
with pytest.raises(ConnectionError):
m.run_command("TESTQUOTE")

0 comments on commit 18acac5

Please sign in to comment.