diff --git a/fasthtml/core.py b/fasthtml/core.py
index bf9ada68..ea326b24 100644
--- a/fasthtml/core.py
+++ b/fasthtml/core.py
@@ -228,6 +228,7 @@ def _find_wsp(ws, data, hdrs, arg:str, p:Parameter):
if isinstance(anno, type):
if issubclass(anno, HtmxHeaders): return _get_htmx(hdrs)
if issubclass(anno, Starlette): return ws.scope['app']
+ if issubclass(anno, WebSocket): return ws
if anno is empty:
if arg.lower()=='ws': return ws
if arg.lower()=='scope': return dict2obj(ws.scope)
@@ -642,6 +643,9 @@ def ws(self:FastHTML, path:str, conn=None, disconn=None, name=None, middleware=N
def f(func=noop): return self.wss.append((func, path, conn, disconn, name, middleware))
return f
+# %% ../nbs/api/00_core.ipynb
+for o in all_meths: setattr(APIRouter, o, partialmethod(APIRouter.__call__, methods=o))
+
# %% ../nbs/api/00_core.ipynb
def cookie(key: str, value="", max_age=None, expires=None, path="/", domain=None, secure=False, httponly=False, samesite="lax",):
"Create a 'set-cookie' `HttpHeader`"
diff --git a/fasthtml/starlette.py b/fasthtml/starlette.py
index deb20795..295be18a 100644
--- a/fasthtml/starlette.py
+++ b/fasthtml/starlette.py
@@ -20,5 +20,5 @@
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.concurrency import run_in_threadpool
from starlette.background import BackgroundTask, BackgroundTasks
-from starlette.websockets import WebSocketDisconnect
+from starlette.websockets import WebSocketDisconnect, WebSocket
diff --git a/nbs/api/00_core.ipynb b/nbs/api/00_core.ipynb
index 1cc3a800..d9d0875f 100644
--- a/nbs/api/00_core.ipynb
+++ b/nbs/api/00_core.ipynb
@@ -786,6 +786,7 @@
" if isinstance(anno, type):\n",
" if issubclass(anno, HtmxHeaders): return _get_htmx(hdrs)\n",
" if issubclass(anno, Starlette): return ws.scope['app']\n",
+ " if issubclass(anno, WebSocket): return ws\n",
" if anno is empty:\n",
" if arg.lower()=='ws': return ws\n",
" if arg.lower()=='scope': return dict2obj(ws.scope)\n",
@@ -1628,14 +1629,23 @@
"execution_count": null,
"id": "72428702",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Message text was: Hi!, from client: Address(host='testclient', port=50000)\n"
+ ]
+ }
+ ],
"source": [
"@app.ws(\"/ws\")\n",
- "def ws(self, msg:str): return f\"Message text was: {msg}\"\n",
+ "def ws(self, msg:str, ws:WebSocket): return f\"Message text was: {msg}, from client: {ws.client}\"\n",
"with cli.websocket_connect('/ws') as ws:\n",
" ws.send_text('{\"msg\":\"Hi!\"}')\n",
" data = ws.receive_text()\n",
- " assert data == 'Message text was: Hi!'"
+ "assert 'Message text was: Hi!' in data\n",
+ "print(data)"
]
},
{
@@ -2291,13 +2301,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Set to 2024-10-24 04:46:03.100239\n"
+ "Set to 2024-10-24 07:04:10.173305\n"
]
},
{
"data": {
"text/plain": [
- "'Session time: 2024-10-24 04:46:03.100239'"
+ "'Session time: 2024-10-24 07:04:10.173305'"
]
},
"execution_count": null,
@@ -2525,6 +2535,8 @@
"def get(): return 'Hi there'\n",
"@ar(\"/hi\")\n",
"def post(): return 'Postal'\n",
+ "@ar\n",
+ "def ho(): return 'Ho ho'\n",
"@ar(\"/hostie\")\n",
"def show_host(req): return req.headers['host']\n",
"@ar\n",
@@ -2555,7 +2567,10 @@
"test_eq(cli.get('/hi').text, 'Hi there')\n",
"test_eq(cli.post('/hi').text, 'Postal')\n",
"test_eq(cli.get('/hostie').text, 'testserver')\n",
- "test_eq(cli.post('/yoyo').text, 'a yoyo')"
+ "test_eq(cli.post('/yoyo').text, 'a yoyo')\n",
+ "\n",
+ "test_eq(cli.get('/ho').text, 'Ho ho')\n",
+ "test_eq(cli.post('/ho').text, 'Ho ho')"
]
},
{
@@ -2571,6 +2586,32 @@
" assert data == 'Message text was: Hi!'"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "259a0f53",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "for o in all_meths: setattr(APIRouter, o, partialmethod(APIRouter.__call__, methods=o))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f61d110c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@ar.get\n",
+ "def hi2(): return 'Hi there'\n",
+ "@ar.get(\"/hi3\")\n",
+ "def _(): return 'Hi there'\n",
+ "@ar.post(\"/post2\")\n",
+ "def _(): return 'Postal'"
+ ]
+ },
{
"cell_type": "markdown",
"id": "dfa6f859",
@@ -2631,7 +2672,7 @@
{
"data": {
"text/plain": [
- "'Cookie was set at time 04:46:03.784300'"
+ "'Cookie was set at time 07:04:14.878397'"
]
},
"execution_count": null,