diff --git a/README.md b/README.md index 09a4c5e0..12641a73 100644 --- a/README.md +++ b/README.md @@ -167,11 +167,14 @@ the [`JWT` authentication type](https://trino.io/docs/current/security/jwt.html) ### OAuth2 Authentication -- `OAuth2Authentication` class can be used to connect to a Trino cluster configured with +The `OAuth2Authentication` class can be used to connect to a Trino cluster configured with the [OAuth2 authentication type](https://trino.io/docs/current/security/oauth2.html). -- A callback to handle the redirect url can be provided via param `redirect_auth_url_handler`, by default it just outputs the redirect url to stdout. -* DBAPI +A callback to handle the redirect url can be provided via param `redirect_auth_url_handler` of the `trino.auth.OAuth2Authentication` class. By default, it will try to launch a web browser (`trino.auth.WebBrowserRedirectHandler`) to go through the authentication flow and output the redirect url to stdout (`trino.auth.ConsoleRedirectHandler`). Multiple redirect handlers are combined using the `trino.auth.CompositeRedirectHandler` class. + +The OAuth2 token will be cached either per `trino.auth.OAuth2Authentication` instance or, when keyring is installed, it will be cached within a secure backend (MacOS keychain, Windows credential locker, etc) under a key including host of the Trino connection. Keyring can be installed using `pip install 'trino[external-authentication-token-cache]'`. + +- DBAPI ```python from trino.dbapi import connect @@ -185,7 +188,7 @@ the [OAuth2 authentication type](https://trino.io/docs/current/security/oauth2.h ) ``` -* SQLAlchemy +- SQLAlchemy ```python from sqlalchemy import create_engine diff --git a/setup.py b/setup.py index c5f4ef20..2695a530 100755 --- a/setup.py +++ b/setup.py @@ -27,7 +27,9 @@ kerberos_require = ["requests_kerberos"] sqlalchemy_require = ["sqlalchemy~=1.3"] +external_authentication_token_cache_require = ["keyring"] +# We don't add localstorage_require to all_require as users must explicitly opt in to use keyring. all_require = kerberos_require + sqlalchemy_require tests_require = all_require + [ @@ -80,6 +82,7 @@ "kerberos": kerberos_require, "sqlalchemy": sqlalchemy_require, "tests": tests_require, + "external-authentication-token-cache": external_authentication_token_cache_require, }, entry_points={ "sqlalchemy.dialects": [ diff --git a/tests/unit/oauth_test_utils.py b/tests/unit/oauth_test_utils.py new file mode 100644 index 00000000..77b891ce --- /dev/null +++ b/tests/unit/oauth_test_utils.py @@ -0,0 +1,128 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re +import uuid +from collections import namedtuple + +import httpretty + +from trino import constants + +SERVER_ADDRESS = "https://coordinator" +REDIRECT_PATH = "oauth2/initiate" +TOKEN_PATH = "oauth2/token" +REDIRECT_RESOURCE = f"{SERVER_ADDRESS}/{REDIRECT_PATH}" +TOKEN_RESOURCE = f"{SERVER_ADDRESS}/{TOKEN_PATH}" + + +class RedirectHandler: + def __init__(self): + self.redirect_server = "" + + def __call__(self, url): + self.redirect_server += url + + +class PostStatementCallback: + def __init__(self, redirect_server, token_server, tokens, sample_post_response_data): + self.redirect_server = redirect_server + self.token_server = token_server + self.tokens = tokens + self.sample_post_response_data = sample_post_response_data + + def __call__(self, request, uri, response_headers): + authorization = request.headers.get("Authorization") + if authorization and authorization.replace("Bearer ", "") in self.tokens: + return [200, response_headers, json.dumps(self.sample_post_response_data)] + return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{self.redirect_server}", ' + f'x_token_server="{self.token_server}"', + 'Basic realm': '"Trino"'}, ""] + + +class GetTokenCallback: + def __init__(self, token_server, token, attempts=1): + self.token_server = token_server + self.token = token + self.attempts = attempts + + def __call__(self, request, uri, response_headers): + self.attempts -= 1 + if self.attempts < 0: + return [404, response_headers, "{}"] + if self.attempts == 0: + return [200, response_headers, f'{{"token": "{self.token}"}}'] + return [200, response_headers, f'{{"nextUri": "{self.token_server}"}}'] + + +def _get_token_requests(challenge_id): + return list(filter( + lambda r: r.method == "GET" and r.path == f"/{TOKEN_PATH}/{challenge_id}", + httpretty.latest_requests())) + + +def _post_statement_requests(): + return list(filter( + lambda r: r.method == "POST" and r.path == constants.URL_STATEMENT_PATH, + httpretty.latest_requests())) + + +class MultithreadedTokenServer: + Challenge = namedtuple('Challenge', ['token', 'attempts']) + + def __init__(self, sample_post_response_data, attempts=1): + self.tokens = set() + self.challenges = {} + self.sample_post_response_data = sample_post_response_data + self.attempts = attempts + + # bind post statement + httpretty.register_uri( + method=httpretty.POST, + uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", + body=self.post_statement_callback) + + # bind get token + httpretty.register_uri( + method=httpretty.GET, + uri=re.compile(rf"{TOKEN_RESOURCE}/.*"), + body=self.get_token_callback) + + # noinspection PyUnusedLocal + def post_statement_callback(self, request, uri, response_headers): + authorization = request.headers.get("Authorization") + + if authorization and authorization.replace("Bearer ", "") in self.tokens: + return [200, response_headers, json.dumps(self.sample_post_response_data)] + + challenge_id = str(uuid.uuid4()) + token = str(uuid.uuid4()) + self.tokens.add(token) + self.challenges[challenge_id] = MultithreadedTokenServer.Challenge(token, self.attempts) + redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}" + token_server = f"{TOKEN_RESOURCE}/{challenge_id}" + return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{redirect_server}", ' + f'x_token_server="{token_server}"', + 'Basic realm': '"Trino"'}, ""] + + # noinspection PyUnusedLocal + def get_token_callback(self, request, uri, response_headers): + challenge_id = uri.replace(f"{TOKEN_RESOURCE}/", "") + challenge = self.challenges[challenge_id] + challenge = challenge._replace(attempts=challenge.attempts - 1) + self.challenges[challenge_id] = challenge + if challenge.attempts < 0: + return [404, response_headers, "{}"] + if challenge.attempts == 0: + return [200, response_headers, f'{{"token": "{challenge.token}"}}'] + return [200, response_headers, f'{{"nextUri": "{uri}"}}'] diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 3c26795d..f03b5f79 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -10,11 +10,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import json -import re import threading import time import uuid -from collections import namedtuple from unittest import mock from urllib.parse import urlparse @@ -25,6 +23,9 @@ from requests_kerberos.exceptions import KerberosExchangeError import trino.exceptions +from tests.unit.oauth_test_utils import RedirectHandler, GetTokenCallback, PostStatementCallback, \ + MultithreadedTokenServer, _post_statement_requests, _get_token_requests, REDIRECT_RESOURCE, TOKEN_RESOURCE, \ + SERVER_ADDRESS from trino import constants from trino.auth import KerberosAuthentication, _OAuth2TokenBearer from trino.client import TrinoQuery, TrinoRequest, TrinoResult @@ -259,52 +260,6 @@ def long_call(request, uri, headers): httpretty.reset() -SERVER_ADDRESS = "https://coordinator" -REDIRECT_PATH = "oauth2/initiate" -TOKEN_PATH = "oauth2/token" -REDIRECT_RESOURCE = f"{SERVER_ADDRESS}/{REDIRECT_PATH}" -TOKEN_RESOURCE = f"{SERVER_ADDRESS}/{TOKEN_PATH}" - - -class RedirectHandler: - def __init__(self): - self.redirect_server = "" - - def __call__(self, url): - self.redirect_server += url - - -class PostStatementCallback: - def __init__(self, redirect_server, token_server, tokens, sample_post_response_data): - self.redirect_server = redirect_server - self.token_server = token_server - self.tokens = tokens - self.sample_post_response_data = sample_post_response_data - - def __call__(self, request, uri, response_headers): - authorization = request.headers.get("Authorization") - if authorization and authorization.replace("Bearer ", "") in self.tokens: - return [200, response_headers, json.dumps(self.sample_post_response_data)] - return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{self.redirect_server}", ' - f'x_token_server="{self.token_server}"', - 'Basic realm': '"Trino"'}, ""] - - -class GetTokenCallback: - def __init__(self, token_server, token, attempts=1): - self.token_server = token_server - self.token = token - self.attempts = attempts - - def __call__(self, request, uri, response_headers): - self.attempts -= 1 - if self.attempts < 0: - return [404, response_headers, "{}"] - if self.attempts == 0: - return [200, response_headers, f'{{"token": "{self.token}"}}'] - return [200, response_headers, f'{{"nextUri": "{self.token_server}"}}'] - - @pytest.mark.parametrize("attempts", [1, 3, 5]) @httprettified def test_oauth2_authentication_flow(attempts, sample_post_response_data): @@ -511,57 +466,6 @@ def test_oauth2_authentication_fail_token_server(http_status, sample_post_respon assert len(_get_token_requests(challenge_id)) == 1 -class MultithreadedTokenServer: - Challenge = namedtuple('Challenge', ['token', 'attempts']) - - def __init__(self, sample_post_response_data, attempts=1): - self.tokens = set() - self.challenges = {} - self.sample_post_response_data = sample_post_response_data - self.attempts = attempts - - # bind post statement - httpretty.register_uri( - method=httpretty.POST, - uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}", - body=self.post_statement_callback) - - # bind get token - httpretty.register_uri( - method=httpretty.GET, - uri=re.compile(rf"{TOKEN_RESOURCE}/.*"), - body=self.get_token_callback) - - # noinspection PyUnusedLocal - def post_statement_callback(self, request, uri, response_headers): - authorization = request.headers.get("Authorization") - - if authorization and authorization.replace("Bearer ", "") in self.tokens: - return [200, response_headers, json.dumps(self.sample_post_response_data)] - - challenge_id = str(uuid.uuid4()) - token = str(uuid.uuid4()) - self.tokens.add(token) - self.challenges[challenge_id] = MultithreadedTokenServer.Challenge(token, self.attempts) - redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}" - token_server = f"{TOKEN_RESOURCE}/{challenge_id}" - return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{redirect_server}", ' - f'x_token_server="{token_server}"', - 'Basic realm': '"Trino"'}, ""] - - # noinspection PyUnusedLocal - def get_token_callback(self, request, uri, response_headers): - challenge_id = uri.replace(f"{TOKEN_RESOURCE}/", "") - challenge = self.challenges[challenge_id] - challenge = challenge._replace(attempts=challenge.attempts - 1) - self.challenges[challenge_id] = challenge - if challenge.attempts < 0: - return [404, response_headers, "{}"] - if challenge.attempts == 0: - return [200, response_headers, f'{{"token": "{challenge.token}"}}'] - return [200, response_headers, f'{{"nextUri": "{uri}"}}'] - - @httprettified def test_multithreaded_oauth2_authentication_flow(sample_post_response_data): redirect_handler = RedirectHandler() @@ -598,31 +502,19 @@ def run(self) -> None: for thread in threads: thread.join() - # should issue only 3 tokens and each thread should get one - assert len(token_server.tokens) == 3 + # should issue only 1 token and each thread should reuse it + assert len(token_server.tokens) == 1 for thread in threads: assert thread.token in token_server.tokens - # should start only 3 challenges and every token should be obtained - assert len(token_server.challenges.keys()) == 3 + # should start only 1 challenge + assert len(token_server.challenges.keys()) == 1 for challenge_id, challenge in token_server.challenges.items(): assert f"{REDIRECT_RESOURCE}/{challenge_id}" in redirect_handler.redirect_server assert challenge.attempts == 0 assert len(_get_token_requests(challenge_id)) == 1 # 3 threads * (10 POST /statement each + 1 replied request by authentication) - assert len(_post_statement_requests()) == 33 - - -def _get_token_requests(challenge_id): - return list(filter( - lambda r: r.method == "GET" and r.path == f"/{TOKEN_PATH}/{challenge_id}", - httpretty.latest_requests())) - - -def _post_statement_requests(): - return list(filter( - lambda r: r.method == "POST" and r.path == constants.URL_STATEMENT_PATH, - httpretty.latest_requests())) + assert len(_post_statement_requests()) == 31 @mock.patch("trino.client.TrinoRequest.http") diff --git a/tests/unit/test_dbapi.py b/tests/unit/test_dbapi.py index 4e6dbbb5..690aa50d 100644 --- a/tests/unit/test_dbapi.py +++ b/tests/unit/test_dbapi.py @@ -9,9 +9,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import threading +import uuid +from unittest.mock import patch +import httpretty +from httpretty import httprettified from requests import Session -from unittest.mock import patch + +from tests.unit.oauth_test_utils import _post_statement_requests, _get_token_requests, RedirectHandler, \ + GetTokenCallback, REDIRECT_RESOURCE, TOKEN_RESOURCE, PostStatementCallback, SERVER_ADDRESS +from trino import constants +from trino.auth import OAuth2Authentication from trino.dbapi import connect @@ -39,3 +48,177 @@ def test_http_session_is_defaulted_when_not_specified(mock_client): # THEN request_args, _ = mock_client.TrinoRequest.call_args assert mock_client.TrinoRequest.http.Session.return_value in request_args + + +@httprettified +def test_token_retrieved_once_per_auth_instance(sample_post_response_data): + token = str(uuid.uuid4()) + challenge_id = str(uuid.uuid4()) + + redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}" + token_server = f"{TOKEN_RESOURCE}/{challenge_id}" + + post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data) + + # bind post statement + httpretty.register_uri( + method=httpretty.POST, + uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", + body=post_statement_callback) + + # bind get token + get_token_callback = GetTokenCallback(token_server, token) + httpretty.register_uri( + method=httpretty.GET, + uri=token_server, + body=get_token_callback) + + redirect_handler = RedirectHandler() + + with connect( + "coordinator", + user="test", + auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler), + http_scheme=constants.HTTPS + ) as conn: + conn.cursor().execute("SELECT 1") + conn.cursor().execute("SELECT 2") + conn.cursor().execute("SELECT 3") + + # bind get token + get_token_callback = GetTokenCallback(token_server, token) + httpretty.register_uri( + method=httpretty.GET, + uri=token_server, + body=get_token_callback) + + redirect_handler = RedirectHandler() + + with connect( + "coordinator", + user="test", + auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler), + http_scheme=constants.HTTPS + ) as conn2: + conn2.cursor().execute("SELECT 1") + conn2.cursor().execute("SELECT 2") + conn2.cursor().execute("SELECT 3") + + assert len(_get_token_requests(challenge_id)) == 2 + + +@httprettified +def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post_response_data): + token = str(uuid.uuid4()) + challenge_id = str(uuid.uuid4()) + + redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}" + token_server = f"{TOKEN_RESOURCE}/{challenge_id}" + + post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data) + + # bind post statement + httpretty.register_uri( + method=httpretty.POST, + uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", + body=post_statement_callback) + + # bind get token + get_token_callback = GetTokenCallback(token_server, token) + httpretty.register_uri( + method=httpretty.GET, + uri=token_server, + body=get_token_callback) + + redirect_handler = RedirectHandler() + + authentication = OAuth2Authentication(redirect_auth_url_handler=redirect_handler) + + with connect( + "coordinator", + user="test", + auth=authentication, + http_scheme=constants.HTTPS + ) as conn: + conn.cursor().execute("SELECT 1") + conn.cursor().execute("SELECT 2") + conn.cursor().execute("SELECT 3") + + # bind get token + get_token_callback = GetTokenCallback(token_server, token) + httpretty.register_uri( + method=httpretty.GET, + uri=token_server, + body=get_token_callback) + + with connect( + "coordinator", + user="test", + auth=authentication, + http_scheme=constants.HTTPS + ) as conn2: + conn2.cursor().execute("SELECT 1") + conn2.cursor().execute("SELECT 2") + conn2.cursor().execute("SELECT 3") + + assert len(_post_statement_requests()) == 7 + assert len(_get_token_requests(challenge_id)) == 1 + + +@httprettified +def test_token_retrieved_once_when_multithreaded(sample_post_response_data): + token = str(uuid.uuid4()) + challenge_id = str(uuid.uuid4()) + + redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}" + token_server = f"{TOKEN_RESOURCE}/{challenge_id}" + + post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data) + + # bind post statement + httpretty.register_uri( + method=httpretty.POST, + uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", + body=post_statement_callback) + + # bind get token + get_token_callback = GetTokenCallback(token_server, token) + httpretty.register_uri( + method=httpretty.GET, + uri=token_server, + body=get_token_callback) + + redirect_handler = RedirectHandler() + + authentication = OAuth2Authentication(redirect_auth_url_handler=redirect_handler) + + conn = connect( + "coordinator", + user="test", + auth=authentication, + http_scheme=constants.HTTPS + ) + + class RunningThread(threading.Thread): + lock = threading.Lock() + + def __init__(self): + super().__init__() + + def run(self) -> None: + with RunningThread.lock: + conn.cursor().execute("SELECT 1") + + threads = [ + RunningThread(), + RunningThread(), + RunningThread() + ] + + # run and join all threads + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert len(_get_token_requests(challenge_id)) == 1 diff --git a/trino/auth.py b/trino/auth.py index 236794e0..e6b4f04c 100644 --- a/trino/auth.py +++ b/trino/auth.py @@ -15,11 +15,14 @@ import os import re import threading -from typing import Optional +import webbrowser +from typing import Optional, List, Callable +from urllib.parse import urlparse from requests import Request from requests.auth import AuthBase, extract_cookies_to_jar from requests.utils import parse_dict_header +import importlib import trino.logging from trino.client import exceptions @@ -130,6 +133,7 @@ class _BearerAuth(AuthBase): """ Custom implementation of Authentication class for bearer token """ + def __init__(self, token): self.token = token @@ -156,9 +160,109 @@ def __eq__(self, other): return self.token == other.token -def handle_redirect_auth_url(auth_url): - print("Open the following URL in browser for the external authentication:") - print(auth_url) +class RedirectHandler(metaclass=abc.ABCMeta): + """ + Abstract class for OAuth redirect handlers, inherit from this class to implement your own redirect handler. + """ + + @abc.abstractmethod + def __call__(self, url: str) -> None: + raise NotImplementedError() + + +class ConsoleRedirectHandler(RedirectHandler): + """ + Handler for OAuth redirections to log to console. + """ + + def __call__(self, url: str) -> None: + print("Open the following URL in browser for the external authentication:") + print(url) + + +class WebBrowserRedirectHandler(RedirectHandler): + """ + Handler for OAuth redirections to open in web browser. + """ + + def __call__(self, url: str) -> None: + webbrowser.open_new(url) + + +class CompositeRedirectHandler(RedirectHandler): + """ + Composite handler for OAuth redirect handlers. + """ + + def __init__(self, handlers: List[Callable[[str], None]]): + self.handlers = handlers + + def __call__(self, url: str): + for handler in self.handlers: + handler(url) + + +class _OAuth2TokenCache(metaclass=abc.ABCMeta): + """ + Abstract class for OAuth token cache, inherit from this class to implement your own token cache. + """ + + @abc.abstractmethod + def get_token_from_cache(self, host: str) -> Optional[str]: + pass + + @abc.abstractmethod + def store_token_to_cache(self, host: str, token: str) -> None: + pass + + +class _OAuth2TokenInMemoryCache(_OAuth2TokenCache): + """ + In-memory token cache implementation. The token is stored per host, so multiple clients can share the same cache. + """ + + def __init__(self): + self._cache = {} + + def get_token_from_cache(self, host: str) -> Optional[str]: + return self._cache.get(host) + + def store_token_to_cache(self, host: str, token: str) -> None: + self._cache[host] = token + + +class _OAuth2KeyRingTokenCache(_OAuth2TokenCache): + """ + Keyring Token Cache implementation + """ + + def __init__(self): + super().__init__() + try: + self._keyring = importlib.import_module("keyring") + except ImportError: + self._keyring = None + logger.info("keyring module not found. OAuth2 token will not be stored in keyring.") + + def is_keyring_available(self) -> bool: + return self._keyring is not None + + def get_token_from_cache(self, host: str) -> Optional[str]: + try: + return self._keyring.get_password(host, "token") + except self._keyring.errors.NoKeyringError as e: + raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been " + "detected, check https://pypi.org/project/keyring/ for more " + "information.") from e + + def store_token_to_cache(self, host: str, token: str) -> None: + try: + # keyring is installed, so we can store the token for reuse within multiple threads + self._keyring.set_password(host, "token", token) + except self._keyring.errors.NoKeyringError as e: + raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been " + "detected, check https://pypi.org/project/keyring/ for more " + "information.") from e class _OAuth2TokenBearer(AuthBase): @@ -168,14 +272,20 @@ class _OAuth2TokenBearer(AuthBase): MAX_OAUTH_ATTEMPTS = 5 _BEARER_PREFIX = re.compile(r"bearer", flags=re.IGNORECASE) - def __init__(self, http_session, redirect_auth_url_handler=handle_redirect_auth_url): + def __init__(self, redirect_auth_url_handler: Callable[[str], None]): self._redirect_auth_url = redirect_auth_url_handler - self._thread_local = threading.local() - http_session.hooks['response'].append(self._authenticate) + keyring_cache = _OAuth2KeyRingTokenCache() + self._token_cache = keyring_cache if keyring_cache.is_keyring_available() else _OAuth2TokenInMemoryCache() + self._token_lock = threading.Lock() + self._inside_oauth_attempt_lock = threading.Lock() + self._inside_oauth_attempt_blocker = threading.Event() def __call__(self, r): - if hasattr(self._thread_local, 'token') and self._thread_local.token: - r.headers['Authorization'] = "Bearer " + self._thread_local.token + host = self._determine_host(r.url) + token = self._get_token_from_cache(host) + + if token is not None: + r.headers['Authorization'] = "Bearer " + token r.register_hook('response', self._authenticate) @@ -185,7 +295,23 @@ def _authenticate(self, response, **kwargs): if not 400 <= response.status_code < 500: return response - # we have to handle the authentication, may be token the token expired or it wasn't there at all + acquired = self._inside_oauth_attempt_lock.acquire(blocking=False) + if acquired: + try: + # Lock is acquired, attempt the OAuth2 flow + self._attempt_oauth(response, **kwargs) + self._inside_oauth_attempt_blocker.set() + finally: + self._inside_oauth_attempt_lock.release() + self._inside_oauth_attempt_blocker.clear() + else: + # Lock is not acquired, we are already in the OAuth2 flow, so we block until OAuth2 flow is finished. + self._inside_oauth_attempt_blocker.wait() + + return self._retry_request(response, **kwargs) + + def _attempt_oauth(self, response, **kwargs): + # we have to handle the authentication, may be token the token expired, or it wasn't there at all auth_info = response.headers.get('WWW-Authenticate') if not auth_info: raise exceptions.TrinoAuthError("Error: header WWW-Authenticate not available in the response.") @@ -203,8 +329,6 @@ def _authenticate(self, response, **kwargs): if token_server is None: raise exceptions.TrinoAuthError("Error: header info didn't have x_token_server") - self._thread_local.token_server = token_server - # tell app that use this url to proceed with the authentication self._redirect_auth_url(auth_server) @@ -213,15 +337,19 @@ def _authenticate(self, response, **kwargs): response.content response.close() - self._thread_local.token = self._get_token(token_server, response, **kwargs) - return self._retry_request(response, **kwargs) + token = self._get_token(token_server, response, **kwargs) + + request = response.request + host = self._determine_host(request.url) + self._store_token_to_cache(host, token) def _retry_request(self, response, **kwargs): request = response.request.copy() extract_cookies_to_jar(request._cookies, response.request, response.raw) request.prepare_cookies(request._cookies) - request.headers['Authorization'] = "Bearer " + self._thread_local.token + host = self._determine_host(response.request.url) + request.headers['Authorization'] = "Bearer " + self._get_token_from_cache(host) retry_response = response.connection.send(request, **kwargs) retry_response.history.append(response) retry_response.request = request @@ -251,13 +379,29 @@ def _get_token(self, token_server, response, **kwargs): raise exceptions.TrinoAuthError("Exceeded max attempts while getting the token") + def _get_token_from_cache(self, host: str) -> Optional[str]: + with self._token_lock: + return self._token_cache.get_token_from_cache(host) + + def _store_token_to_cache(self, host: str, token: str) -> None: + with self._token_lock: + self._token_cache.store_token_to_cache(host, token) + + @staticmethod + def _determine_host(url) -> Optional[str]: + return urlparse(url).hostname + class OAuth2Authentication(Authentication): - def __init__(self, redirect_auth_url_handler=handle_redirect_auth_url): + def __init__(self, redirect_auth_url_handler=CompositeRedirectHandler([ + WebBrowserRedirectHandler(), + ConsoleRedirectHandler() + ])): self._redirect_auth_url = redirect_auth_url_handler + self._bearer = _OAuth2TokenBearer(self._redirect_auth_url) def set_http_session(self, http_session): - http_session.auth = _OAuth2TokenBearer(http_session, self._redirect_auth_url) + http_session.auth = self._bearer return http_session def get_exceptions(self):