diff --git a/app.py b/app.py index 927a59f..c529b94 100644 --- a/app.py +++ b/app.py @@ -20,7 +20,6 @@ # TODO: get server stats # TODO: figure out if there is a way to simplify some of the queries using triggers and views instead -# TODO: figure out socket authentication # TODO: support markdown in messages? # TODO: maybe a login that is more like a login? (email link or is this too much) # TODO: user roles + admin mode (when you are the first guy in) @@ -263,16 +262,24 @@ def by_id(message_id: int) -> 'ChannelMessageWCtx': class Socket(TsRec): sid: str; mid: int @dataclass -class Command: +class Cmd: cmd: str @staticmethod - def from_json(cmd: str, data: str) -> 'Command': - if cmd == "ping": return PingCommand(cmd=cmd, **json.loads(data)) + def from_json(j: str) -> 'Cmd': + data = json.loads(j) + cmd, data = data.get("cmd"), data.get("d") + if cmd == "ping": return PingCmd(cmd=cmd, **data) raise ValueError(f"Invalid command: {cmd}") @dataclass -class PingCommand(Command): cid: int +class PingCmd(Cmd): + cid: int + + def to_json(self) -> str: return json.dumps({"cmd": self.cmd, "d": {"cid": self.cid}}) + + @staticmethod + def for_cid(cid: int) -> 'PingCmd': return PingCmd(cmd="ping", cid=cid) @dataclass class ChannelPlaceholder: member: Member @@ -598,10 +605,9 @@ def list_channel_messages(req: Request, cid: int): def channel(req: Request, cid: int): is_mobile, m, w, frm_id, msgs_id, channel = user_agents.parse(req.headers.get('User-Agent')).is_mobile, req.scope['m'], req.scope['w'], f"f-{cid}", f"channel-{cid}", channels[cid] channel_name = f"#{channel.name}" if not channel.is_direct else channel_members(where=f"channel={cid} and member!={m.id}")[0].name - ping_cmd = { "command": "ping", "d": { "cid": cid }, "auth": { "mid": m.id } } convo = [ - Div(cls="hidden", hx_trigger=f"load, every {settings.ping_interval_in_seconds}s", hx_vals=f'{json.dumps(ping_cmd)}', **{"ws_send": "true"}), + Div(cls="hidden", hx_trigger=f"load, every {settings.ping_interval_in_seconds}s", hx_vals=PingCmd.for_cid(cid).to_json(), **{"ws_send": "true"}), Div(cls='border-b flex md:px-6 py-2 items-center flex-none', style="position: fixed; width: 100%; background-color: white;" if is_mobile else "")( Div(cls='flex flex-row items-center')( Button(variant="ghost", **{ "data-testid":"show-mobile-sidebar", "onclick": "document.getElementById('mobile-menu').click()"})(I_ARROW_LEFT(cls="h-6 w-6")) if is_mobile else None, @@ -641,6 +647,7 @@ def direct(req: Request, to_m: int): return RedirectResponse(f'/c/{direct_channel.id}', status_code=303) def ws_connect(ws, send): + if not ws.session: raise WebSocketException(400, "Missing session") try: m, sid = members[int(ws.query_params.get("mid"))], str(id(ws)) connections[sid] = send @@ -654,13 +661,16 @@ def ws_disconnect(ws): except NotFoundError: pass connections.pop(sid, None) -async def process_ping(cmd: PingCommand, member: Member, current_channel: Channel): +async def process_ping(cmd: PingCmd, member: Member, current_channel: Channel): ChannelForMember.from_channel_member(channel_members(where=f"channel={cmd.cid} and member={member.id}")[0]).mark_all_as_read() await ws_send_to_member(member.id, ListOfChannelsForMember(member=member, current_channel=current_channel)) @app.ws('/ws', conn=ws_connect, disconn=ws_disconnect) -async def ws(command:str, auth:dict, d: dict, ws, sess): - mid, channel = int(auth['mid']), channels[int(d['cid'])] +async def ws(cmd:str, d: dict, ws): + if not ws.session: raise WebSocketException(400, "Missing session") + + mid, channel = int(ws.session['mid']), channels[int(d['cid'])] + logger.debug(f"socket ID is {str(id(ws))}") try: socket = sockets[str(id(ws))] @@ -669,9 +679,10 @@ async def ws(command:str, auth:dict, d: dict, ws, sess): return logger.debug(f"got socket {socket}") - logger.debug(f"got command {command} with payload {json.dumps(d)}") + logger.debug(f"got command {cmd} with payload {json.dumps(d)}") - cmd = Command.from_json(command, json.dumps(d)) + cmd = Cmd.from_json(json.dumps({ "cmd": cmd, "d": d })) + await { "ping": process_ping }[cmd.cmd](cmd, members[mid], channel) # ================================================================================================================================================================================================================================ @@ -710,8 +721,9 @@ def run_in_thread(self): with TestServer(config=uvicorn.Config("app:app", host="0.0.0.0", port=5002, log_level="info")).run_in_thread(): yield def test_commands(): - cmd = Command.from_json("ping", '{"cid": 1}') - assert isinstance(cmd, PingCommand) and cmd.cid == 1 + cmd = Cmd.from_json('{ "cmd": "ping", "d": {"cid": 1} }') + assert isinstance(cmd, PingCmd) and cmd.cid == 1 + assert PingCmd.for_cid(1).to_json() == '{"cmd": "ping", "d": {"cid": 1}}' def test_healthcheck(client): response = client.get('/healthcheck')