Skip to content

Commit

Permalink
Merge pull request #501 from AnswerDotAI/oauth_https
Browse files Browse the repository at this point in the history
bug fix
  • Loading branch information
jph00 authored Oct 8, 2024
2 parents fd0aba1 + 3cac144 commit 29d891b
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 16 deletions.
3 changes: 2 additions & 1 deletion fasthtml/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@
'fasthtml/oauth.py'),
'fasthtml.oauth._AppClient.retr_id': ('api/oauth.html#_appclient.retr_id', 'fasthtml/oauth.py'),
'fasthtml.oauth._AppClient.retr_info': ('api/oauth.html#_appclient.retr_info', 'fasthtml/oauth.py'),
'fasthtml.oauth.redir_url': ('api/oauth.html#redir_url', 'fasthtml/oauth.py')},
'fasthtml.oauth.redir_url': ('api/oauth.html#redir_url', 'fasthtml/oauth.py'),
'fasthtml.oauth.url_match': ('api/oauth.html#url_match', 'fasthtml/oauth.py')},
'fasthtml.pico': { 'fasthtml.pico.Card': ('api/pico.html#card', 'fasthtml/pico.py'),
'fasthtml.pico.Container': ('api/pico.html#container', 'fasthtml/pico.py'),
'fasthtml.pico.DialogX': ('api/pico.html#dialogx', 'fasthtml/pico.py'),
Expand Down
23 changes: 15 additions & 8 deletions fasthtml/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/api/08_oauth.ipynb.

# %% auto 0
__all__ = ['GoogleAppClient', 'GitHubAppClient', 'HuggingFaceClient', 'DiscordAppClient', 'redir_url', 'OAuth']
__all__ = ['http_patterns', 'GoogleAppClient', 'GitHubAppClient', 'HuggingFaceClient', 'DiscordAppClient', 'redir_url',
'url_match', 'OAuth']

# %% ../nbs/api/08_oauth.ipynb
from .common import *
Expand Down Expand Up @@ -95,9 +96,9 @@ def login_link(self:WebApplicationClient, redirect_uri, scope=None, state=None):
return self.prepare_request_uri(self.base_url, redirect_uri, scope, state=state)

# %% ../nbs/api/08_oauth.ipynb
def redir_url(request, redir_path):
def redir_url(request, redir_path, scheme='https'):
"Get the redir url for the host in `request`"
return f"{request.url.scheme}://{request.url.netloc}{redir_path}"
return f"{scheme}://{request.url.netloc}{redir_path}"

# %% ../nbs/api/08_oauth.ipynb
@patch
Expand Down Expand Up @@ -130,12 +131,16 @@ def retr_id(self:_AppClient, code, redirect_uri):
"Call `retr_info` and then return id/subscriber value"
return self.retr_info(code, redirect_uri)[self.id_key]

# %% ../nbs/api/08_oauth.ipynb
http_patterns = (r'^(localhost|127\.0\.0\.1)(:\d+)?$',)
def url_match(url, patterns=http_patterns):
return any(re.match(pattern, url.netloc.split(':')[0]) for pattern in patterns)

# %% ../nbs/api/08_oauth.ipynb
class OAuth:
def __init__(self, app, cli, skip=None, redir_path='/redirect', logout_path='/logout', login_path='/login'):
def __init__(self, app, cli, skip=None, redir_path='/redirect', logout_path='/logout', login_path='/login', https=True, http_patterns=http_patterns):
if not skip: skip = [redir_path,login_path]
store_attr()

def before(req, session):
auth = req.scope['auth'] = session.get('auth')
if not auth: return RedirectResponse(self.login_path, status_code=303)
Expand All @@ -147,8 +152,8 @@ def before(req, session):
@app.get(redir_path)
def redirect(code:str, req, session, state:str=None):
if not code: return "No code provided!"
base_url = f"{req.url.scheme}://{req.url.netloc}"
print(base_url)
scheme = 'http' if url_match(req.url,self.http_patterns) or not self.https else 'https'
base_url = f"{scheme}://{req.url.netloc}"
info = AttrDictDefault(cli.retr_info(code, base_url+redir_path))
if not self._chk_auth(info, session): return RedirectResponse(self.login_path, status_code=303)
session['auth'] = cli.token['access_token']
Expand All @@ -159,7 +164,9 @@ def logout(session):
session.pop('auth', None)
return self.logout(session)

def redir_url(self, req): return redir_url(req, self.redir_path)
def redir_url(self, req):
scheme = 'http' if url_match(req.url,self.http_patterns) or not self.https else 'https'
return redir_url(req, self.redir_path, scheme)
def login_link(self, req, scope=None, state=None): return self.cli.login_link(self.redir_url(req), scope=scope, state=state)

def login(self, info, state): raise NotImplementedError()
Expand Down
28 changes: 21 additions & 7 deletions nbs/api/08_oauth.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,9 @@
"outputs": [],
"source": [
"#| export\n",
"def redir_url(request, redir_path):\n",
"def redir_url(request, redir_path, scheme='https'):\n",
" \"Get the redir url for the host in `request`\"\n",
" return f\"{request.url.scheme}://{request.url.netloc}{redir_path}\""
" return f\"{scheme}://{request.url.netloc}{redir_path}\""
]
},
{
Expand Down Expand Up @@ -394,6 +394,19 @@
"After either of these calls, you can also access the access token (used to revoke access, for example) with `client.token[\"access_token\"]`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7b96e009",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"http_patterns = (r'^(localhost|127\\.0\\.0\\.1)(:\\d+)?$',)\n",
"def url_match(url, patterns=http_patterns):\n",
" return any(re.match(pattern, url.netloc.split(':')[0]) for pattern in patterns)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -403,10 +416,9 @@
"source": [
"#| export\n",
"class OAuth:\n",
" def __init__(self, app, cli, skip=None, redir_path='/redirect', logout_path='/logout', login_path='/login'):\n",
" def __init__(self, app, cli, skip=None, redir_path='/redirect', logout_path='/logout', login_path='/login', https=True, http_patterns=http_patterns):\n",
" if not skip: skip = [redir_path,login_path]\n",
" store_attr()\n",
"\n",
" def before(req, session):\n",
" auth = req.scope['auth'] = session.get('auth')\n",
" if not auth: return RedirectResponse(self.login_path, status_code=303)\n",
Expand All @@ -418,8 +430,8 @@
" @app.get(redir_path)\n",
" def redirect(code:str, req, session, state:str=None):\n",
" if not code: return \"No code provided!\"\n",
" base_url = f\"{req.url.scheme}://{req.url.netloc}\"\n",
" print(base_url)\n",
" scheme = 'http' if url_match(req.url,self.http_patterns) or not self.https else 'https'\n",
" base_url = f\"{scheme}://{req.url.netloc}\"\n",
" info = AttrDictDefault(cli.retr_info(code, base_url+redir_path))\n",
" if not self._chk_auth(info, session): return RedirectResponse(self.login_path, status_code=303)\n",
" session['auth'] = cli.token['access_token']\n",
Expand All @@ -430,7 +442,9 @@
" session.pop('auth', None)\n",
" return self.logout(session)\n",
"\n",
" def redir_url(self, req): return redir_url(req, self.redir_path)\n",
" def redir_url(self, req): \n",
" scheme = 'http' if url_match(req.url,self.http_patterns) or not self.https else 'https'\n",
" return redir_url(req, self.redir_path, scheme)\n",
" def login_link(self, req, scope=None, state=None): return self.cli.login_link(self.redir_url(req), scope=scope, state=state)\n",
"\n",
" def login(self, info, state): raise NotImplementedError()\n",
Expand Down

0 comments on commit 29d891b

Please sign in to comment.