Skip to content

Commit

Permalink
more type check fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
haliphax committed Feb 24, 2023
1 parent 9965e24 commit 1c36846
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 34 deletions.
2 changes: 1 addition & 1 deletion xthulu/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(self, proc, encoding="utf-8"):
self.events: EventQueue = EventQueue(self.sid)
"""Events queue"""

self.env: dict = proc.env
self.env: dict = proc.env.copy()
"""Environment variables"""

# set up logging
Expand Down
34 changes: 23 additions & 11 deletions xthulu/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@
class SSHServer(asyncssh.SSHServer):
"xthulu SSH Server"

_username = None
_username: str | None = None

def connection_made(self, conn: asyncssh.SSHServerConnection):
"Connection opened"

self._peername = conn.get_extra_info("peername")
log.info(conn._extra)
self._peername: list[str] = conn.get_extra_info("peername")
self._sid = "{}:{}".format(*self._peername)
EventQueues.q[self._sid] = aio.Queue()
log.info(f"{self._peername[0]} connecting")
Expand Down Expand Up @@ -88,22 +89,31 @@ async def validate_password(self, username: str, password: str):

return True

def session_requested(self):
return super().session_requested()


async def handle_client(proc: asyncssh.SSHServerProcess):
"Client connected"

cx = Context(proc=proc)
await cx._init()

if "LANG" not in cx.proc.env or "UTF-8" not in cx.proc.env["LANG"]:
if "LANG" not in proc.env or "UTF-8" not in proc.env["LANG"]:
cx.encoding = "cp437"

termtype = proc.get_terminal_type()

if "TERM" not in cx.env:
cx.env["TERM"] = termtype

w, h, pw, ph = proc.get_terminal_size()
cx.env["COLUMNS"] = w
cx.env["LINES"] = h
proxy_pipe, subproc_pipe = Pipe()
session_stdin = aio.Queue()
timeout = int(config.get("ssh", {}).get("session", {}).get("timeout", 120))
await cx.user.update(last=datetime.utcnow()).apply()
await cx.user.update(last=datetime.utcnow()).apply() # type: ignore

async def input_loop():
"Catch exceptions on stdin and convert to EventData"
Expand All @@ -128,6 +138,8 @@ async def input_loop():
return

except asyncssh.misc.TerminalSizeChanged as sz:
cx.env["COLUMNS"] = sz.width
cx.env["LINES"] = sz.height
cx.term._width = sz.width
cx.term._height = sz.height
cx.term._pixel_width = sz.pixwidth
Expand Down Expand Up @@ -181,13 +193,13 @@ async def start_server():
"Run init tasks and throw SSH server into asyncio event loop"

await db.set_bind(config["db"]["bind"])
log.info(
"SSH listening on " f"{config['ssh']['host']}:{config['ssh']['port']}"
)
await asyncssh.create_server(
SSHServer,
config["ssh"]["host"],
int(config["ssh"]["port"]),
host: str = config["ssh"]["host"]
port = int(config["ssh"]["port"])
log.info(f"SSH listening on {host}:{port}")
await asyncssh.listen(
host=host,
port=port,
server_factory=SSHServer,
server_host_keys=config["ssh"]["host_keys"],
process_factory=handle_client,
encoding=None,
Expand Down
38 changes: 16 additions & 22 deletions xthulu/terminal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# https://github.com/jquast

# type checking
from typing import Callable, Optional, Protocol
from typing import Any, Callable, Optional

# stdlib
import asyncio as aio
Expand Down Expand Up @@ -73,37 +73,31 @@ def __call__(self, *args, **kwargs):
return self.pipe_master.recv()


class IntOrNoneReturnsStr(Protocol):
def __call__(self, value: int = 1) -> str:
...


class ProxyTerminal(object):
class ProxyTerminal:
_kbdbuf = []

# context manager attribs
_ctxattrs = (
"location",
"keypad",
"raw",
"cbreak",
"hidden_cursor",
"fullscreen",
"hidden_cursor",
"keypad",
"location",
"raw",
)

# type hints
clear_eol: Callable[[], str]
move_down: IntOrNoneReturnsStr
move_left: IntOrNoneReturnsStr
move_right: IntOrNoneReturnsStr
move_up: IntOrNoneReturnsStr
move_x: IntOrNoneReturnsStr
move_y: IntOrNoneReturnsStr
# their type hints
cbreak: contextlib._GeneratorContextManager[Any]
fullscreen: contextlib._GeneratorContextManager[Any]
hidden_cursor: contextlib._GeneratorContextManager[Any]
keypad: contextlib._GeneratorContextManager[Any]
location: contextlib._GeneratorContextManager[Any]
raw: contextlib._GeneratorContextManager[Any]

def __init__(
self,
stdin: aio.Queue[bytes],
stdout: StringIO,
stdout: Any,
encoding: str,
pipe_master: Connection,
width: int = 0,
Expand All @@ -119,7 +113,7 @@ def __init__(
self._pixel_width = pixel_width
self._pixel_height = pixel_height

def __getattr__(self, attr: str):
def __getattr__(self, attr: str) -> Callable[..., str]:
@contextlib.contextmanager
def proxy_contextmanager(*args, **kwargs):
# we send special '!CTX' header, which means we
Expand Down Expand Up @@ -153,7 +147,7 @@ def proxy_contextmanager(*args, **kwargs):
self.stdout.write(exit_side_effect)

if attr in self._ctxattrs:
return proxy_contextmanager
return proxy_contextmanager # type: ignore

blessed_attr = getattr(BlessedTerminal, attr, None)

Expand Down

0 comments on commit 1c36846

Please sign in to comment.