diff --git a/fasthtml/core.py b/fasthtml/core.py index bf9ada68..ea326b24 100644 --- a/fasthtml/core.py +++ b/fasthtml/core.py @@ -228,6 +228,7 @@ def _find_wsp(ws, data, hdrs, arg:str, p:Parameter): if isinstance(anno, type): if issubclass(anno, HtmxHeaders): return _get_htmx(hdrs) if issubclass(anno, Starlette): return ws.scope['app'] + if issubclass(anno, WebSocket): return ws if anno is empty: if arg.lower()=='ws': return ws if arg.lower()=='scope': return dict2obj(ws.scope) @@ -642,6 +643,9 @@ def ws(self:FastHTML, path:str, conn=None, disconn=None, name=None, middleware=N def f(func=noop): return self.wss.append((func, path, conn, disconn, name, middleware)) return f +# %% ../nbs/api/00_core.ipynb +for o in all_meths: setattr(APIRouter, o, partialmethod(APIRouter.__call__, methods=o)) + # %% ../nbs/api/00_core.ipynb def cookie(key: str, value="", max_age=None, expires=None, path="/", domain=None, secure=False, httponly=False, samesite="lax",): "Create a 'set-cookie' `HttpHeader`" diff --git a/fasthtml/starlette.py b/fasthtml/starlette.py index deb20795..295be18a 100644 --- a/fasthtml/starlette.py +++ b/fasthtml/starlette.py @@ -20,5 +20,5 @@ from starlette.types import ASGIApp, Receive, Scope, Send from starlette.concurrency import run_in_threadpool from starlette.background import BackgroundTask, BackgroundTasks -from starlette.websockets import WebSocketDisconnect +from starlette.websockets import WebSocketDisconnect, WebSocket diff --git a/nbs/api/00_core.ipynb b/nbs/api/00_core.ipynb index 1cc3a800..d9d0875f 100644 --- a/nbs/api/00_core.ipynb +++ b/nbs/api/00_core.ipynb @@ -786,6 +786,7 @@ " if isinstance(anno, type):\n", " if issubclass(anno, HtmxHeaders): return _get_htmx(hdrs)\n", " if issubclass(anno, Starlette): return ws.scope['app']\n", + " if issubclass(anno, WebSocket): return ws\n", " if anno is empty:\n", " if arg.lower()=='ws': return ws\n", " if arg.lower()=='scope': return dict2obj(ws.scope)\n", @@ -1628,14 +1629,23 @@ "execution_count": null, "id": "72428702", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Message text was: Hi!, from client: Address(host='testclient', port=50000)\n" + ] + } + ], "source": [ "@app.ws(\"/ws\")\n", - "def ws(self, msg:str): return f\"Message text was: {msg}\"\n", + "def ws(self, msg:str, ws:WebSocket): return f\"Message text was: {msg}, from client: {ws.client}\"\n", "with cli.websocket_connect('/ws') as ws:\n", " ws.send_text('{\"msg\":\"Hi!\"}')\n", " data = ws.receive_text()\n", - " assert data == 'Message text was: Hi!'" + "assert 'Message text was: Hi!' in data\n", + "print(data)" ] }, { @@ -2291,13 +2301,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "Set to 2024-10-24 04:46:03.100239\n" + "Set to 2024-10-24 07:04:10.173305\n" ] }, { "data": { "text/plain": [ - "'Session time: 2024-10-24 04:46:03.100239'" + "'Session time: 2024-10-24 07:04:10.173305'" ] }, "execution_count": null, @@ -2525,6 +2535,8 @@ "def get(): return 'Hi there'\n", "@ar(\"/hi\")\n", "def post(): return 'Postal'\n", + "@ar\n", + "def ho(): return 'Ho ho'\n", "@ar(\"/hostie\")\n", "def show_host(req): return req.headers['host']\n", "@ar\n", @@ -2555,7 +2567,10 @@ "test_eq(cli.get('/hi').text, 'Hi there')\n", "test_eq(cli.post('/hi').text, 'Postal')\n", "test_eq(cli.get('/hostie').text, 'testserver')\n", - "test_eq(cli.post('/yoyo').text, 'a yoyo')" + "test_eq(cli.post('/yoyo').text, 'a yoyo')\n", + "\n", + "test_eq(cli.get('/ho').text, 'Ho ho')\n", + "test_eq(cli.post('/ho').text, 'Ho ho')" ] }, { @@ -2571,6 +2586,32 @@ " assert data == 'Message text was: Hi!'" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "259a0f53", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "for o in all_meths: setattr(APIRouter, o, partialmethod(APIRouter.__call__, methods=o))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f61d110c", + "metadata": {}, + "outputs": [], + "source": [ + "@ar.get\n", + "def hi2(): return 'Hi there'\n", + "@ar.get(\"/hi3\")\n", + "def _(): return 'Hi there'\n", + "@ar.post(\"/post2\")\n", + "def _(): return 'Postal'" + ] + }, { "cell_type": "markdown", "id": "dfa6f859", @@ -2631,7 +2672,7 @@ { "data": { "text/plain": [ - "'Cookie was set at time 04:46:03.784300'" + "'Cookie was set at time 07:04:14.878397'" ] }, "execution_count": null,