diff --git a/fasthtml/_modidx.py b/fasthtml/_modidx.py index 45ae2d98..62189411 100644 --- a/fasthtml/_modidx.py +++ b/fasthtml/_modidx.py @@ -160,7 +160,8 @@ 'fasthtml/oauth.py'), 'fasthtml.oauth._AppClient.retr_id': ('api/oauth.html#_appclient.retr_id', 'fasthtml/oauth.py'), 'fasthtml.oauth._AppClient.retr_info': ('api/oauth.html#_appclient.retr_info', 'fasthtml/oauth.py'), - 'fasthtml.oauth.redir_url': ('api/oauth.html#redir_url', 'fasthtml/oauth.py')}, + 'fasthtml.oauth.redir_url': ('api/oauth.html#redir_url', 'fasthtml/oauth.py'), + 'fasthtml.oauth.url_match': ('api/oauth.html#url_match', 'fasthtml/oauth.py')}, 'fasthtml.pico': { 'fasthtml.pico.Card': ('api/pico.html#card', 'fasthtml/pico.py'), 'fasthtml.pico.Container': ('api/pico.html#container', 'fasthtml/pico.py'), 'fasthtml.pico.DialogX': ('api/pico.html#dialogx', 'fasthtml/pico.py'), diff --git a/fasthtml/oauth.py b/fasthtml/oauth.py index adaa336c..8b16f1ae 100644 --- a/fasthtml/oauth.py +++ b/fasthtml/oauth.py @@ -3,7 +3,8 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/api/08_oauth.ipynb. # %% auto 0 -__all__ = ['GoogleAppClient', 'GitHubAppClient', 'HuggingFaceClient', 'DiscordAppClient', 'redir_url', 'OAuth'] +__all__ = ['http_patterns', 'GoogleAppClient', 'GitHubAppClient', 'HuggingFaceClient', 'DiscordAppClient', 'redir_url', + 'url_match', 'OAuth'] # %% ../nbs/api/08_oauth.ipynb from .common import * @@ -95,9 +96,9 @@ def login_link(self:WebApplicationClient, redirect_uri, scope=None, state=None): return self.prepare_request_uri(self.base_url, redirect_uri, scope, state=state) # %% ../nbs/api/08_oauth.ipynb -def redir_url(request, redir_path): +def redir_url(request, redir_path, scheme='https'): "Get the redir url for the host in `request`" - return f"{request.url.scheme}://{request.url.netloc}{redir_path}" + return f"{scheme}://{request.url.netloc}{redir_path}" # %% ../nbs/api/08_oauth.ipynb @patch @@ -130,12 +131,16 @@ def retr_id(self:_AppClient, code, redirect_uri): "Call `retr_info` and then return id/subscriber value" return self.retr_info(code, redirect_uri)[self.id_key] +# %% ../nbs/api/08_oauth.ipynb +http_patterns = (r'^(localhost|127\.0\.0\.1)(:\d+)?$',) +def url_match(url, patterns=http_patterns): + return any(re.match(pattern, url.netloc.split(':')[0]) for pattern in patterns) + # %% ../nbs/api/08_oauth.ipynb class OAuth: - def __init__(self, app, cli, skip=None, redir_path='/redirect', logout_path='/logout', login_path='/login'): + def __init__(self, app, cli, skip=None, redir_path='/redirect', logout_path='/logout', login_path='/login', https=True, http_patterns=http_patterns): if not skip: skip = [redir_path,login_path] store_attr() - def before(req, session): auth = req.scope['auth'] = session.get('auth') if not auth: return RedirectResponse(self.login_path, status_code=303) @@ -147,8 +152,8 @@ def before(req, session): @app.get(redir_path) def redirect(code:str, req, session, state:str=None): if not code: return "No code provided!" - base_url = f"{req.url.scheme}://{req.url.netloc}" - print(base_url) + scheme = 'http' if url_match(req.url,self.http_patterns) or not self.https else 'https' + base_url = f"{scheme}://{req.url.netloc}" info = AttrDictDefault(cli.retr_info(code, base_url+redir_path)) if not self._chk_auth(info, session): return RedirectResponse(self.login_path, status_code=303) session['auth'] = cli.token['access_token'] @@ -159,7 +164,9 @@ def logout(session): session.pop('auth', None) return self.logout(session) - def redir_url(self, req): return redir_url(req, self.redir_path) + def redir_url(self, req): + scheme = 'http' if url_match(req.url,self.http_patterns) or not self.https else 'https' + return redir_url(req, self.redir_path, scheme) def login_link(self, req, scope=None, state=None): return self.cli.login_link(self.redir_url(req), scope=scope, state=state) def login(self, info, state): raise NotImplementedError() diff --git a/nbs/api/08_oauth.ipynb b/nbs/api/08_oauth.ipynb index d28d533e..7b8c3b17 100644 --- a/nbs/api/08_oauth.ipynb +++ b/nbs/api/08_oauth.ipynb @@ -247,9 +247,9 @@ "outputs": [], "source": [ "#| export\n", - "def redir_url(request, redir_path):\n", + "def redir_url(request, redir_path, scheme='https'):\n", " \"Get the redir url for the host in `request`\"\n", - " return f\"{request.url.scheme}://{request.url.netloc}{redir_path}\"" + " return f\"{scheme}://{request.url.netloc}{redir_path}\"" ] }, { @@ -394,6 +394,19 @@ "After either of these calls, you can also access the access token (used to revoke access, for example) with `client.token[\"access_token\"]`." ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "7b96e009", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "http_patterns = (r'^(localhost|127\\.0\\.0\\.1)(:\\d+)?$',)\n", + "def url_match(url, patterns=http_patterns):\n", + " return any(re.match(pattern, url.netloc.split(':')[0]) for pattern in patterns)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -403,10 +416,9 @@ "source": [ "#| export\n", "class OAuth:\n", - " def __init__(self, app, cli, skip=None, redir_path='/redirect', logout_path='/logout', login_path='/login'):\n", + " def __init__(self, app, cli, skip=None, redir_path='/redirect', logout_path='/logout', login_path='/login', https=True, http_patterns=http_patterns):\n", " if not skip: skip = [redir_path,login_path]\n", " store_attr()\n", - "\n", " def before(req, session):\n", " auth = req.scope['auth'] = session.get('auth')\n", " if not auth: return RedirectResponse(self.login_path, status_code=303)\n", @@ -418,8 +430,8 @@ " @app.get(redir_path)\n", " def redirect(code:str, req, session, state:str=None):\n", " if not code: return \"No code provided!\"\n", - " base_url = f\"{req.url.scheme}://{req.url.netloc}\"\n", - " print(base_url)\n", + " scheme = 'http' if url_match(req.url,self.http_patterns) or not self.https else 'https'\n", + " base_url = f\"{scheme}://{req.url.netloc}\"\n", " info = AttrDictDefault(cli.retr_info(code, base_url+redir_path))\n", " if not self._chk_auth(info, session): return RedirectResponse(self.login_path, status_code=303)\n", " session['auth'] = cli.token['access_token']\n", @@ -430,7 +442,9 @@ " session.pop('auth', None)\n", " return self.logout(session)\n", "\n", - " def redir_url(self, req): return redir_url(req, self.redir_path)\n", + " def redir_url(self, req): \n", + " scheme = 'http' if url_match(req.url,self.http_patterns) or not self.https else 'https'\n", + " return redir_url(req, self.redir_path, scheme)\n", " def login_link(self, req, scope=None, state=None): return self.cli.login_link(self.redir_url(req), scope=scope, state=state)\n", "\n", " def login(self, info, state): raise NotImplementedError()\n",