diff --git a/parsons/catalist/catalist.py b/parsons/catalist/catalist.py index 4779f70648..b090c91793 100644 --- a/parsons/catalist/catalist.py +++ b/parsons/catalist/catalist.py @@ -12,10 +12,9 @@ from typing import Optional, Union, Dict, List from zipfile import ZipFile -import requests from parsons.etl import Table from parsons.sftp import SFTP -from parsons.utilities.api_connector import APIConnector +from parsons.utilities.oauth_api_connector import OAuth2APIConnector logger = logging.getLogger(__name__) @@ -62,36 +61,19 @@ def __init__( client_secret: str, sftp_username: str, sftp_password: str, + client_audience: Optional[str] = None, ) -> None: self.client_id = client_id self.client_secret = client_secret - self.fetch_token() - self.connection = APIConnector("http://api.catalist.us/mapi/") - self.sftp = SFTP("t.catalist.us", sftp_username, sftp_password) - - @property - def token(self) -> str: - """If token is not yet fetched or has expired, fetch new token.""" - if not (self._token and time.time() < self._token_expired_at): - self.fetch_token() - return self._token - - def fetch_token(self) -> None: - """Fetch auth0 token to be used with Catalist API.""" - url = "https://auth.catalist.us/oauth/token" - payload = { - "grant_type": "client_credentials", - "audience": "catalist_api_m_prod", - } - response = requests.post( - url, json=payload, auth=(self.client_id, self.client_secret) + self.connection = OAuth2APIConnector( + "https://api.catalist.us/mapi/", + client_id=client_id, + client_secret=client_secret, + authorization_kwargs={"audience": client_audience or "catalist_api_m_prod"}, + token_url="https://auth.catalist.us/oauth/token", + auto_refresh_url="https://auth.catalist.us/oauth/token", ) - data = response.json() - - self._token = data["access_token"] - self._token_expired_at = time.time() + data["expires_in"] - - logger.info("Token refreshed.") + self.sftp = SFTP("t.catalist.us", sftp_username, sftp_password) def load_table_to_sftp( self, table: Table, input_subfolder: Optional[str] = None @@ -241,7 +223,9 @@ def upload( endpoint = "/".join(endpoint_params) # Assemble query parameters - query_params: Dict[str, Union[str, int]] = {"token": self.token} + query_params: Dict[str, Union[str, int]] = { + "token": self.connection.token["access_token"] + } if copy_to_sandbox: query_params["copyToSandbox"] = "true" if static_values: @@ -308,7 +292,7 @@ def action( logger.debug(f"Executing request to endpoint {self.connection.uri + endpoint}") - query_params = {"token": self.token} + query_params = {"token": self.connection.token["access_token"]} if copy_to_sandbox: query_params["copyToSandbox"] = "true" if export_filename_suffix: @@ -323,7 +307,7 @@ def action( def status(self, id: str) -> dict: """Check status of a match job.""" endpoint = "/".join(["status", "id", id]) - query_params = {"token": self.token} + query_params = {"token": self.connection.token["access_token"]} result = self.connection.get_request(endpoint, params=query_params) return result