diff --git a/fasthtml/_modidx.py b/fasthtml/_modidx.py
index 45ae2d98..62189411 100644
--- a/fasthtml/_modidx.py
+++ b/fasthtml/_modidx.py
@@ -160,7 +160,8 @@
'fasthtml/oauth.py'),
'fasthtml.oauth._AppClient.retr_id': ('api/oauth.html#_appclient.retr_id', 'fasthtml/oauth.py'),
'fasthtml.oauth._AppClient.retr_info': ('api/oauth.html#_appclient.retr_info', 'fasthtml/oauth.py'),
- 'fasthtml.oauth.redir_url': ('api/oauth.html#redir_url', 'fasthtml/oauth.py')},
+ 'fasthtml.oauth.redir_url': ('api/oauth.html#redir_url', 'fasthtml/oauth.py'),
+ 'fasthtml.oauth.url_match': ('api/oauth.html#url_match', 'fasthtml/oauth.py')},
'fasthtml.pico': { 'fasthtml.pico.Card': ('api/pico.html#card', 'fasthtml/pico.py'),
'fasthtml.pico.Container': ('api/pico.html#container', 'fasthtml/pico.py'),
'fasthtml.pico.DialogX': ('api/pico.html#dialogx', 'fasthtml/pico.py'),
diff --git a/fasthtml/oauth.py b/fasthtml/oauth.py
index adaa336c..8b16f1ae 100644
--- a/fasthtml/oauth.py
+++ b/fasthtml/oauth.py
@@ -3,7 +3,8 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/api/08_oauth.ipynb.
# %% auto 0
-__all__ = ['GoogleAppClient', 'GitHubAppClient', 'HuggingFaceClient', 'DiscordAppClient', 'redir_url', 'OAuth']
+__all__ = ['http_patterns', 'GoogleAppClient', 'GitHubAppClient', 'HuggingFaceClient', 'DiscordAppClient', 'redir_url',
+ 'url_match', 'OAuth']
# %% ../nbs/api/08_oauth.ipynb
from .common import *
@@ -95,9 +96,9 @@ def login_link(self:WebApplicationClient, redirect_uri, scope=None, state=None):
return self.prepare_request_uri(self.base_url, redirect_uri, scope, state=state)
# %% ../nbs/api/08_oauth.ipynb
-def redir_url(request, redir_path):
+def redir_url(request, redir_path, scheme='https'):
"Get the redir url for the host in `request`"
- return f"{request.url.scheme}://{request.url.netloc}{redir_path}"
+ return f"{scheme}://{request.url.netloc}{redir_path}"
# %% ../nbs/api/08_oauth.ipynb
@patch
@@ -130,12 +131,16 @@ def retr_id(self:_AppClient, code, redirect_uri):
"Call `retr_info` and then return id/subscriber value"
return self.retr_info(code, redirect_uri)[self.id_key]
+# %% ../nbs/api/08_oauth.ipynb
+http_patterns = (r'^(localhost|127\.0\.0\.1)(:\d+)?$',)
+def url_match(url, patterns=http_patterns):
+ return any(re.match(pattern, url.netloc.split(':')[0]) for pattern in patterns)
+
# %% ../nbs/api/08_oauth.ipynb
class OAuth:
- def __init__(self, app, cli, skip=None, redir_path='/redirect', logout_path='/logout', login_path='/login'):
+ def __init__(self, app, cli, skip=None, redir_path='/redirect', logout_path='/logout', login_path='/login', https=True, http_patterns=http_patterns):
if not skip: skip = [redir_path,login_path]
store_attr()
-
def before(req, session):
auth = req.scope['auth'] = session.get('auth')
if not auth: return RedirectResponse(self.login_path, status_code=303)
@@ -147,8 +152,8 @@ def before(req, session):
@app.get(redir_path)
def redirect(code:str, req, session, state:str=None):
if not code: return "No code provided!"
- base_url = f"{req.url.scheme}://{req.url.netloc}"
- print(base_url)
+ scheme = 'http' if url_match(req.url,self.http_patterns) or not self.https else 'https'
+ base_url = f"{scheme}://{req.url.netloc}"
info = AttrDictDefault(cli.retr_info(code, base_url+redir_path))
if not self._chk_auth(info, session): return RedirectResponse(self.login_path, status_code=303)
session['auth'] = cli.token['access_token']
@@ -159,7 +164,9 @@ def logout(session):
session.pop('auth', None)
return self.logout(session)
- def redir_url(self, req): return redir_url(req, self.redir_path)
+ def redir_url(self, req):
+ scheme = 'http' if url_match(req.url,self.http_patterns) or not self.https else 'https'
+ return redir_url(req, self.redir_path, scheme)
def login_link(self, req, scope=None, state=None): return self.cli.login_link(self.redir_url(req), scope=scope, state=state)
def login(self, info, state): raise NotImplementedError()
diff --git a/nbs/api/08_oauth.ipynb b/nbs/api/08_oauth.ipynb
index d28d533e..7b8c3b17 100644
--- a/nbs/api/08_oauth.ipynb
+++ b/nbs/api/08_oauth.ipynb
@@ -247,9 +247,9 @@
"outputs": [],
"source": [
"#| export\n",
- "def redir_url(request, redir_path):\n",
+ "def redir_url(request, redir_path, scheme='https'):\n",
" \"Get the redir url for the host in `request`\"\n",
- " return f\"{request.url.scheme}://{request.url.netloc}{redir_path}\""
+ " return f\"{scheme}://{request.url.netloc}{redir_path}\""
]
},
{
@@ -394,6 +394,19 @@
"After either of these calls, you can also access the access token (used to revoke access, for example) with `client.token[\"access_token\"]`."
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7b96e009",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "http_patterns = (r'^(localhost|127\\.0\\.0\\.1)(:\\d+)?$',)\n",
+ "def url_match(url, patterns=http_patterns):\n",
+ " return any(re.match(pattern, url.netloc.split(':')[0]) for pattern in patterns)"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -403,10 +416,9 @@
"source": [
"#| export\n",
"class OAuth:\n",
- " def __init__(self, app, cli, skip=None, redir_path='/redirect', logout_path='/logout', login_path='/login'):\n",
+ " def __init__(self, app, cli, skip=None, redir_path='/redirect', logout_path='/logout', login_path='/login', https=True, http_patterns=http_patterns):\n",
" if not skip: skip = [redir_path,login_path]\n",
" store_attr()\n",
- "\n",
" def before(req, session):\n",
" auth = req.scope['auth'] = session.get('auth')\n",
" if not auth: return RedirectResponse(self.login_path, status_code=303)\n",
@@ -418,8 +430,8 @@
" @app.get(redir_path)\n",
" def redirect(code:str, req, session, state:str=None):\n",
" if not code: return \"No code provided!\"\n",
- " base_url = f\"{req.url.scheme}://{req.url.netloc}\"\n",
- " print(base_url)\n",
+ " scheme = 'http' if url_match(req.url,self.http_patterns) or not self.https else 'https'\n",
+ " base_url = f\"{scheme}://{req.url.netloc}\"\n",
" info = AttrDictDefault(cli.retr_info(code, base_url+redir_path))\n",
" if not self._chk_auth(info, session): return RedirectResponse(self.login_path, status_code=303)\n",
" session['auth'] = cli.token['access_token']\n",
@@ -430,7 +442,9 @@
" session.pop('auth', None)\n",
" return self.logout(session)\n",
"\n",
- " def redir_url(self, req): return redir_url(req, self.redir_path)\n",
+ " def redir_url(self, req): \n",
+ " scheme = 'http' if url_match(req.url,self.http_patterns) or not self.https else 'https'\n",
+ " return redir_url(req, self.redir_path, scheme)\n",
" def login_link(self, req, scope=None, state=None): return self.cli.login_link(self.redir_url(req), scope=scope, state=state)\n",
"\n",
" def login(self, info, state): raise NotImplementedError()\n",