Skip to content

Commit

Permalink
Use oauth_api_connector in Catalist connector
Browse files Browse the repository at this point in the history
  • Loading branch information
austinweisgrau committed Nov 15, 2023
1 parent 39ad845 commit 857d179
Showing 1 changed file with 15 additions and 31 deletions.
46 changes: 15 additions & 31 deletions parsons/catalist/catalist.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
from typing import Optional, Union, Dict, List
from zipfile import ZipFile

import requests
from parsons.etl import Table
from parsons.sftp import SFTP
from parsons.utilities.api_connector import APIConnector
from parsons.utilities.oauth_api_connector import OAuth2APIConnector

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -62,36 +61,19 @@ def __init__(
client_secret: str,
sftp_username: str,
sftp_password: str,
client_audience: Optional[str] = None,
) -> None:
self.client_id = client_id
self.client_secret = client_secret
self.fetch_token()
self.connection = APIConnector("http://api.catalist.us/mapi/")
self.sftp = SFTP("t.catalist.us", sftp_username, sftp_password)

@property
def token(self) -> str:
"""If token is not yet fetched or has expired, fetch new token."""
if not (self._token and time.time() < self._token_expired_at):
self.fetch_token()
return self._token

def fetch_token(self) -> None:
"""Fetch auth0 token to be used with Catalist API."""
url = "https://auth.catalist.us/oauth/token"
payload = {
"grant_type": "client_credentials",
"audience": "catalist_api_m_prod",
}
response = requests.post(
url, json=payload, auth=(self.client_id, self.client_secret)
self.connection = OAuth2APIConnector(
"https://api.catalist.us/mapi/",
client_id=client_id,
client_secret=client_secret,
authorization_kwargs={"audience": client_audience or "catalist_api_m_prod"},
token_url="https://auth.catalist.us/oauth/token",
auto_refresh_url="https://auth.catalist.us/oauth/token",
)
data = response.json()

self._token = data["access_token"]
self._token_expired_at = time.time() + data["expires_in"]

logger.info("Token refreshed.")
self.sftp = SFTP("t.catalist.us", sftp_username, sftp_password)

def load_table_to_sftp(
self, table: Table, input_subfolder: Optional[str] = None
Expand Down Expand Up @@ -241,7 +223,9 @@ def upload(
endpoint = "/".join(endpoint_params)

# Assemble query parameters
query_params: Dict[str, Union[str, int]] = {"token": self.token}
query_params: Dict[str, Union[str, int]] = {
"token": self.connection.token["access_token"]
}
if copy_to_sandbox:
query_params["copyToSandbox"] = "true"
if static_values:
Expand Down Expand Up @@ -308,7 +292,7 @@ def action(

logger.debug(f"Executing request to endpoint {self.connection.uri + endpoint}")

query_params = {"token": self.token}
query_params = {"token": self.connection.token["access_token"]}
if copy_to_sandbox:
query_params["copyToSandbox"] = "true"
if export_filename_suffix:
Expand All @@ -323,7 +307,7 @@ def action(
def status(self, id: str) -> dict:
"""Check status of a match job."""
endpoint = "/".join(["status", "id", id])
query_params = {"token": self.token}
query_params = {"token": self.connection.token["access_token"]}
result = self.connection.get_request(endpoint, params=query_params)
return result

Expand Down

0 comments on commit 857d179

Please sign in to comment.