Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support redirection through webbrowser for OAuth authentication #162

Merged
merged 2 commits into from
Apr 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,14 @@ the [`JWT` authentication type](https://trino.io/docs/current/security/jwt.html)

### OAuth2 Authentication

- `OAuth2Authentication` class can be used to connect to a Trino cluster configured with
The `OAuth2Authentication` class can be used to connect to a Trino cluster configured with
the [OAuth2 authentication type](https://trino.io/docs/current/security/oauth2.html).
- A callback to handle the redirect url can be provided via param `redirect_auth_url_handler`, by default it just outputs the redirect url to stdout.

* DBAPI
A callback to handle the redirect url can be provided via param `redirect_auth_url_handler` of the `trino.auth.OAuth2Authentication` class. By default, it will try to launch a web browser (`trino.auth.WebBrowserRedirectHandler`) to go through the authentication flow and output the redirect url to stdout (`trino.auth.ConsoleRedirectHandler`). Multiple redirect handlers are combined using the `trino.auth.CompositeRedirectHandler` class.

The OAuth2 token will be cached either per `trino.auth.OAuth2Authentication` instance or, when keyring is installed, it will be cached within a secure backend (MacOS keychain, Windows credential locker, etc) under a key including host of the Trino connection. Keyring can be installed using `pip install 'trino[external-authentication-token-cache]'`.

- DBAPI
hashhar marked this conversation as resolved.
Show resolved Hide resolved

```python
from trino.dbapi import connect
Expand All @@ -185,7 +188,7 @@ the [OAuth2 authentication type](https://trino.io/docs/current/security/oauth2.h
)
```

* SQLAlchemy
- SQLAlchemy

```python
from sqlalchemy import create_engine
Expand Down
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@

kerberos_require = ["requests_kerberos"]
sqlalchemy_require = ["sqlalchemy~=1.3"]
external_authentication_token_cache_require = ["keyring"]

# We don't add localstorage_require to all_require as users must explicitly opt in to use keyring.
all_require = kerberos_require + sqlalchemy_require
hashhar marked this conversation as resolved.
Show resolved Hide resolved

tests_require = all_require + [
Expand Down Expand Up @@ -80,6 +82,7 @@
"kerberos": kerberos_require,
"sqlalchemy": sqlalchemy_require,
"tests": tests_require,
"external-authentication-token-cache": external_authentication_token_cache_require,
},
entry_points={
"sqlalchemy.dialects": [
Expand Down
128 changes: 128 additions & 0 deletions tests/unit/oauth_test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import re
import uuid
from collections import namedtuple

import httpretty

from trino import constants

SERVER_ADDRESS = "https://coordinator"
REDIRECT_PATH = "oauth2/initiate"
TOKEN_PATH = "oauth2/token"
REDIRECT_RESOURCE = f"{SERVER_ADDRESS}/{REDIRECT_PATH}"
TOKEN_RESOURCE = f"{SERVER_ADDRESS}/{TOKEN_PATH}"


class RedirectHandler:
def __init__(self):
self.redirect_server = ""

def __call__(self, url):
self.redirect_server += url


class PostStatementCallback:
def __init__(self, redirect_server, token_server, tokens, sample_post_response_data):
self.redirect_server = redirect_server
self.token_server = token_server
self.tokens = tokens
self.sample_post_response_data = sample_post_response_data

def __call__(self, request, uri, response_headers):
authorization = request.headers.get("Authorization")
if authorization and authorization.replace("Bearer ", "") in self.tokens:
return [200, response_headers, json.dumps(self.sample_post_response_data)]
return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{self.redirect_server}", '
f'x_token_server="{self.token_server}"',
'Basic realm': '"Trino"'}, ""]


class GetTokenCallback:
def __init__(self, token_server, token, attempts=1):
self.token_server = token_server
self.token = token
self.attempts = attempts

def __call__(self, request, uri, response_headers):
self.attempts -= 1
if self.attempts < 0:
return [404, response_headers, "{}"]
if self.attempts == 0:
return [200, response_headers, f'{{"token": "{self.token}"}}']
return [200, response_headers, f'{{"nextUri": "{self.token_server}"}}']


def _get_token_requests(challenge_id):
return list(filter(
lambda r: r.method == "GET" and r.path == f"/{TOKEN_PATH}/{challenge_id}",
httpretty.latest_requests()))


def _post_statement_requests():
return list(filter(
lambda r: r.method == "POST" and r.path == constants.URL_STATEMENT_PATH,
httpretty.latest_requests()))


class MultithreadedTokenServer:
Challenge = namedtuple('Challenge', ['token', 'attempts'])

def __init__(self, sample_post_response_data, attempts=1):
self.tokens = set()
self.challenges = {}
self.sample_post_response_data = sample_post_response_data
self.attempts = attempts

# bind post statement
httpretty.register_uri(
method=httpretty.POST,
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
body=self.post_statement_callback)

# bind get token
httpretty.register_uri(
method=httpretty.GET,
uri=re.compile(rf"{TOKEN_RESOURCE}/.*"),
body=self.get_token_callback)

# noinspection PyUnusedLocal
def post_statement_callback(self, request, uri, response_headers):
authorization = request.headers.get("Authorization")

if authorization and authorization.replace("Bearer ", "") in self.tokens:
return [200, response_headers, json.dumps(self.sample_post_response_data)]

challenge_id = str(uuid.uuid4())
token = str(uuid.uuid4())
self.tokens.add(token)
self.challenges[challenge_id] = MultithreadedTokenServer.Challenge(token, self.attempts)
redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
token_server = f"{TOKEN_RESOURCE}/{challenge_id}"
return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{redirect_server}", '
f'x_token_server="{token_server}"',
'Basic realm': '"Trino"'}, ""]

# noinspection PyUnusedLocal
def get_token_callback(self, request, uri, response_headers):
challenge_id = uri.replace(f"{TOKEN_RESOURCE}/", "")
challenge = self.challenges[challenge_id]
challenge = challenge._replace(attempts=challenge.attempts - 1)
self.challenges[challenge_id] = challenge
if challenge.attempts < 0:
return [404, response_headers, "{}"]
if challenge.attempts == 0:
return [200, response_headers, f'{{"token": "{challenge.token}"}}']
return [200, response_headers, f'{{"nextUri": "{uri}"}}']
124 changes: 8 additions & 116 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
import threading
import time
import uuid
from collections import namedtuple
from unittest import mock
from urllib.parse import urlparse

Expand All @@ -25,6 +23,9 @@
from requests_kerberos.exceptions import KerberosExchangeError

import trino.exceptions
from tests.unit.oauth_test_utils import RedirectHandler, GetTokenCallback, PostStatementCallback, \
MultithreadedTokenServer, _post_statement_requests, _get_token_requests, REDIRECT_RESOURCE, TOKEN_RESOURCE, \
SERVER_ADDRESS
from trino import constants
from trino.auth import KerberosAuthentication, _OAuth2TokenBearer
from trino.client import TrinoQuery, TrinoRequest, TrinoResult
Expand Down Expand Up @@ -259,52 +260,6 @@ def long_call(request, uri, headers):
httpretty.reset()


SERVER_ADDRESS = "https://coordinator"
REDIRECT_PATH = "oauth2/initiate"
TOKEN_PATH = "oauth2/token"
REDIRECT_RESOURCE = f"{SERVER_ADDRESS}/{REDIRECT_PATH}"
TOKEN_RESOURCE = f"{SERVER_ADDRESS}/{TOKEN_PATH}"


class RedirectHandler:
def __init__(self):
self.redirect_server = ""

def __call__(self, url):
self.redirect_server += url


class PostStatementCallback:
def __init__(self, redirect_server, token_server, tokens, sample_post_response_data):
self.redirect_server = redirect_server
self.token_server = token_server
self.tokens = tokens
self.sample_post_response_data = sample_post_response_data

def __call__(self, request, uri, response_headers):
authorization = request.headers.get("Authorization")
if authorization and authorization.replace("Bearer ", "") in self.tokens:
return [200, response_headers, json.dumps(self.sample_post_response_data)]
return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{self.redirect_server}", '
f'x_token_server="{self.token_server}"',
'Basic realm': '"Trino"'}, ""]


class GetTokenCallback:
def __init__(self, token_server, token, attempts=1):
self.token_server = token_server
self.token = token
self.attempts = attempts

def __call__(self, request, uri, response_headers):
self.attempts -= 1
if self.attempts < 0:
return [404, response_headers, "{}"]
if self.attempts == 0:
return [200, response_headers, f'{{"token": "{self.token}"}}']
return [200, response_headers, f'{{"nextUri": "{self.token_server}"}}']


@pytest.mark.parametrize("attempts", [1, 3, 5])
@httprettified
def test_oauth2_authentication_flow(attempts, sample_post_response_data):
Expand Down Expand Up @@ -511,57 +466,6 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon
assert len(_get_token_requests(challenge_id)) == 1


class MultithreadedTokenServer:
Challenge = namedtuple('Challenge', ['token', 'attempts'])

def __init__(self, sample_post_response_data, attempts=1):
self.tokens = set()
self.challenges = {}
self.sample_post_response_data = sample_post_response_data
self.attempts = attempts

# bind post statement
httpretty.register_uri(
method=httpretty.POST,
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
body=self.post_statement_callback)

# bind get token
httpretty.register_uri(
method=httpretty.GET,
uri=re.compile(rf"{TOKEN_RESOURCE}/.*"),
body=self.get_token_callback)

# noinspection PyUnusedLocal
def post_statement_callback(self, request, uri, response_headers):
authorization = request.headers.get("Authorization")

if authorization and authorization.replace("Bearer ", "") in self.tokens:
return [200, response_headers, json.dumps(self.sample_post_response_data)]

challenge_id = str(uuid.uuid4())
token = str(uuid.uuid4())
self.tokens.add(token)
self.challenges[challenge_id] = MultithreadedTokenServer.Challenge(token, self.attempts)
redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
token_server = f"{TOKEN_RESOURCE}/{challenge_id}"
return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{redirect_server}", '
f'x_token_server="{token_server}"',
'Basic realm': '"Trino"'}, ""]

# noinspection PyUnusedLocal
def get_token_callback(self, request, uri, response_headers):
challenge_id = uri.replace(f"{TOKEN_RESOURCE}/", "")
challenge = self.challenges[challenge_id]
challenge = challenge._replace(attempts=challenge.attempts - 1)
self.challenges[challenge_id] = challenge
if challenge.attempts < 0:
return [404, response_headers, "{}"]
if challenge.attempts == 0:
return [200, response_headers, f'{{"token": "{challenge.token}"}}']
return [200, response_headers, f'{{"nextUri": "{uri}"}}']


@httprettified
def test_multithreaded_oauth2_authentication_flow(sample_post_response_data):
redirect_handler = RedirectHandler()
Expand Down Expand Up @@ -598,31 +502,19 @@ def run(self) -> None:
for thread in threads:
thread.join()

# should issue only 3 tokens and each thread should get one
assert len(token_server.tokens) == 3
# should issue only 1 token and each thread should reuse it
assert len(token_server.tokens) == 1
for thread in threads:
assert thread.token in token_server.tokens

# should start only 3 challenges and every token should be obtained
assert len(token_server.challenges.keys()) == 3
# should start only 1 challenge
assert len(token_server.challenges.keys()) == 1
for challenge_id, challenge in token_server.challenges.items():
assert f"{REDIRECT_RESOURCE}/{challenge_id}" in redirect_handler.redirect_server
assert challenge.attempts == 0
assert len(_get_token_requests(challenge_id)) == 1
# 3 threads * (10 POST /statement each + 1 replied request by authentication)
assert len(_post_statement_requests()) == 33


def _get_token_requests(challenge_id):
return list(filter(
lambda r: r.method == "GET" and r.path == f"/{TOKEN_PATH}/{challenge_id}",
httpretty.latest_requests()))


def _post_statement_requests():
return list(filter(
lambda r: r.method == "POST" and r.path == constants.URL_STATEMENT_PATH,
httpretty.latest_requests()))
assert len(_post_statement_requests()) == 31


@mock.patch("trino.client.TrinoRequest.http")
Expand Down
Loading