Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize the oauth_api_connector to accept authorization kwargs #918

Merged
merged 10 commits into from
Feb 20, 2024
46 changes: 15 additions & 31 deletions parsons/catalist/catalist.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
from typing import Optional, Union, Dict, List
from zipfile import ZipFile

import requests
from parsons.etl import Table
from parsons.sftp import SFTP
from parsons.utilities.api_connector import APIConnector
from parsons.utilities.oauth_api_connector import OAuth2APIConnector

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -62,36 +61,19 @@ def __init__(
client_secret: str,
sftp_username: str,
sftp_password: str,
client_audience: Optional[str] = None,
) -> None:
self.client_id = client_id
self.client_secret = client_secret
self.fetch_token()
self.connection = APIConnector("http://api.catalist.us/mapi/")
self.sftp = SFTP("t.catalist.us", sftp_username, sftp_password)

@property
def token(self) -> str:
"""If token is not yet fetched or has expired, fetch new token."""
if not (self._token and time.time() < self._token_expired_at):
self.fetch_token()
return self._token

def fetch_token(self) -> None:
"""Fetch auth0 token to be used with Catalist API."""
url = "https://auth.catalist.us/oauth/token"
payload = {
"grant_type": "client_credentials",
"audience": "catalist_api_m_prod",
}
response = requests.post(
url, json=payload, auth=(self.client_id, self.client_secret)
self.connection = OAuth2APIConnector(
"https://api.catalist.us/mapi/",
client_id=client_id,
client_secret=client_secret,
authorization_kwargs={"audience": client_audience or "catalist_api_m_prod"},
token_url="https://auth.catalist.us/oauth/token",
auto_refresh_url="https://auth.catalist.us/oauth/token",
)
data = response.json()

self._token = data["access_token"]
self._token_expired_at = time.time() + data["expires_in"]

logger.info("Token refreshed.")
self.sftp = SFTP("t.catalist.us", sftp_username, sftp_password)

def load_table_to_sftp(
self, table: Table, input_subfolder: Optional[str] = None
Expand Down Expand Up @@ -241,7 +223,9 @@ def upload(
endpoint = "/".join(endpoint_params)

# Assemble query parameters
query_params: Dict[str, Union[str, int]] = {"token": self.token}
query_params: Dict[str, Union[str, int]] = {
"token": self.connection.token["access_token"]
}
if copy_to_sandbox:
query_params["copyToSandbox"] = "true"
if static_values:
Expand Down Expand Up @@ -308,7 +292,7 @@ def action(

logger.debug(f"Executing request to endpoint {self.connection.uri + endpoint}")

query_params = {"token": self.token}
query_params = {"token": self.connection.token["access_token"]}
if copy_to_sandbox:
query_params["copyToSandbox"] = "true"
if export_filename_suffix:
Expand All @@ -323,7 +307,7 @@ def action(
def status(self, id: str) -> dict:
"""Check status of a match job."""
endpoint = "/".join(["status", "id", id])
query_params = {"token": self.token}
query_params = {"token": self.connection.token["access_token"]}
result = self.connection.get_request(endpoint, params=query_params)
return result

Expand Down
9 changes: 8 additions & 1 deletion parsons/controlshift/controlshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Controlshift(object):
`Args:`
hostname: str
The URL for the homepage/login page of the organization's Controlshift
instance (e.g. demo.controlshift.app). Not required if
instance (e.g. https://demo.controlshift.app). Not required if
``CONTROLSHIFT_HOSTNAME`` env variable is set.
client_id: str
The Client ID for your REST API Application. Not required if
Expand All @@ -27,6 +27,13 @@ class Controlshift(object):
def __init__(self, hostname=None, client_id=None, client_secret=None):

self.hostname = check_env.check("CONTROLSHIFT_HOSTNAME", hostname)

# Hostname must start with 'https://'
if self.hostname.startswith("http://"):
self.hostname = self.hostname.replace("http://", "https://")
if not self.hostname.startswith("https://"):
self.hostname = "https://" + self.hostname

token_url = f"{self.hostname}/oauth/token"
self.client = OAuth2APIConnector(
self.hostname,
Expand Down
54 changes: 31 additions & 23 deletions parsons/utilities/oauth_api_connector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import urllib.parse
from typing import Dict, Optional

from oauthlib.oauth2 import BackendApplicationClient
from requests_oauthlib import OAuth2Session
from parsons.utilities.api_connector import APIConnector
import urllib.parse
from requests_oauthlib import OAuth2Session


class OAuth2APIConnector(APIConnector):
Expand All @@ -13,16 +15,6 @@ class OAuth2APIConnector(APIConnector):
`Args:`
uri: str
The base uri for the api. Must include a trailing '/' (e.g. ``http://myapi.com/v1/``)
headers: dict
The request headers
auth: dict
The request authorization parameters
pagination_key: str
The name of the key in the response json where the pagination url is
located. Required for pagination.
data_key: str
The name of the key in the response json where the data is contained. Required
if the data is nested in the response json
client_id: str
The client id for acquiring and exchanging tokens from the OAuth2 application
client_secret: str
Expand All @@ -31,40 +23,56 @@ class OAuth2APIConnector(APIConnector):
The URL for acquiring new tokens from the OAuth2 Application
auto_refresh_url: str
If provided, the URL for refreshing tokens from the OAuth2 Application
headers: dict
The request headers
pagination_key: str
The name of the key in the response json where the pagination url is
located. Required for pagination.
data_key: str
The name of the key in the response json where the data is contained. Required
if the data is nested in the response json
`Returns`:
OAuthAPIConnector class
"""

def __init__(
self,
uri,
headers=None,
auth=None,
pagination_key=None,
data_key=None,
client_id=None,
client_secret=None,
token_url=None,
auto_refresh_url=None,
uri: str,
client_id: str,
client_secret: str,
token_url: str,
auto_refresh_url: Optional[str],
headers: Optional[Dict[str, str]] = None,
pagination_key: Optional[str] = None,
data_key: Optional[str] = None,
grant_type: str = "client_credentials",
authorization_kwargs: Optional[Dict[str, str]] = None,
):
super().__init__(
uri,
headers=headers,
auth=auth,
pagination_key=pagination_key,
data_key=data_key,
)

if not authorization_kwargs:
authorization_kwargs = {}

client = BackendApplicationClient(client_id=client_id)
client.grant_type = grant_type
oauth = OAuth2Session(client=client)
self.token = oauth.fetch_token(
token_url=token_url, client_id=client_id, client_secret=client_secret
token_url=token_url,
client_id=client_id,
client_secret=client_secret,
**authorization_kwargs
)
self.client = OAuth2Session(
client_id,
token=self.token,
auto_refresh_url=auto_refresh_url,
token_updater=self.token_saver,
auto_refresh_kwargs=authorization_kwargs,
)

def request(self, url, req_type, json=None, data=None, params=None):
Expand Down
75 changes: 17 additions & 58 deletions parsons/zoom/zoom.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from parsons.utilities import check_env
from parsons.utilities.api_connector import APIConnector
from parsons.utilities.oauth_api_connector import OAuth2APIConnector
from parsons import Table
import logging
import jwt
import datetime
import uuid

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -32,55 +30,16 @@ def __init__(self, account_id=None, client_id=None, client_secret=None):
self.client_id = check_env.check("ZOOM_CLIENT_ID", client_id)
self.__client_secret = check_env.check("ZOOM_CLIENT_SECRET", client_secret)

self.client = APIConnector(uri=ZOOM_URI)

access_token = self.__generate_access_token()

self.client.headers = {
"Authorization": f"Bearer {access_token}",
"Content-type": "application/json",
}

def __generate_access_token(self) -> str:
"""
Uses Zoom's OAuth callback URL to generate an access token to query the Zoom API

`Returns`:
String representation of access token
"""

temp_client = APIConnector(
uri=ZOOM_URI, auth=(self.client_id, self.__client_secret)
self.client = OAuth2APIConnector(
uri=ZOOM_URI,
client_id=self.client_id,
client_secret=self.__client_secret,
token_url=ZOOM_AUTH_CALLBACK,
auto_refresh_url=ZOOM_AUTH_CALLBACK,
grant_type="account_credentials",
authorization_kwargs={"account_id": self.account_id},
)

resp = temp_client.post_request(
ZOOM_AUTH_CALLBACK,
data={
"grant_type": "account_credentials",
"account_id": self.account_id,
},
)

return resp["access_token"]

def __refresh_header_token(self):
"""
NOTE: This function is deprecated as Zoom's API moves to an OAuth strategy on 9/1

Generate a token that is valid for 30 seconds and update header. Full documentation
on JWT generation using Zoom API: https://marketplace.zoom.us/docs/guides/auth/jwt
"""

payload = {
"iss": self.api_key,
"exp": int(datetime.datetime.now().timestamp() + 30),
}
token = jwt.encode(payload, self.api_secret, algorithm="HS256")
self.client.headers = {
"authorization": f"Bearer {token}",
"content-type": "application/json",
}

def _get_request(self, endpoint, data_key, params=None, **kwargs):
"""
TODO: Consider increasing default page size.
Expand Down Expand Up @@ -353,7 +312,7 @@ def get_meeting_poll_metadata(self, meeting_id, poll_id) -> Table:
endpoint = f"meetings/{meeting_id}/polls/{poll_id}"
tbl = self._get_request(endpoint=endpoint, data_key="questions")

if type(tbl) == dict:
if isinstance(tbl, dict):
logger.debug(f"No poll data returned for poll ID {poll_id}")
return Table(tbl)

Expand All @@ -380,7 +339,7 @@ def get_meeting_all_polls_metadata(self, meeting_id) -> Table:
endpoint = f"meetings/{meeting_id}/polls"
tbl = self._get_request(endpoint=endpoint, data_key="polls")

if type(tbl) == dict:
if isinstance(tbl, dict):
logger.debug(f"No poll data returned for meeting ID {meeting_id}")
return Table(tbl)

Expand All @@ -405,7 +364,7 @@ def get_past_meeting_poll_metadata(self, meeting_id) -> Table:
endpoint = f"past_meetings/{meeting_id}/polls"
tbl = self._get_request(endpoint=endpoint, data_key="questions")

if type(tbl) == dict:
if isinstance(tbl, dict):
logger.debug(f"No poll data returned for meeting ID {meeting_id}")
return Table(tbl)

Expand All @@ -432,7 +391,7 @@ def get_webinar_poll_metadata(self, webinar_id, poll_id) -> Table:
endpoint = f"webinars/{webinar_id}/polls/{poll_id}"
tbl = self._get_request(endpoint=endpoint, data_key="questions")

if type(tbl) == dict:
if isinstance(tbl, dict):
logger.debug(f"No poll data returned for poll ID {poll_id}")
return Table(tbl)

Expand All @@ -459,7 +418,7 @@ def get_webinar_all_polls_metadata(self, webinar_id) -> Table:
endpoint = f"webinars/{webinar_id}/polls"
tbl = self._get_request(endpoint=endpoint, data_key="polls")

if type(tbl) == dict:
if isinstance(tbl, dict):
logger.debug(f"No poll data returned for webinar ID {webinar_id}")
return Table(tbl)

Expand All @@ -484,7 +443,7 @@ def get_past_webinar_poll_metadata(self, webinar_id) -> Table:
endpoint = f"past_webinars/{webinar_id}/polls"
tbl = self._get_request(endpoint=endpoint, data_key="questions")

if type(tbl) == dict:
if isinstance(tbl, dict):
logger.debug(f"No poll data returned for webinar ID {webinar_id}")
return Table(tbl)

Expand All @@ -502,7 +461,7 @@ def get_meeting_poll_results(self, meeting_id) -> Table:
endpoint = f"report/meetings/{meeting_id}/polls"
tbl = self._get_request(endpoint=endpoint, data_key="questions")

if type(tbl) == dict:
if isinstance(tbl, dict):
logger.debug(f"No poll data returned for meeting ID {meeting_id}")
return Table(tbl)

Expand All @@ -520,7 +479,7 @@ def get_webinar_poll_results(self, webinar_id) -> Table:
endpoint = f"report/webinars/{webinar_id}/polls"
tbl = self._get_request(endpoint=endpoint, data_key="questions")

if type(tbl) == dict:
if isinstance(tbl, dict):
logger.debug(f"No poll data returned for webinar ID {webinar_id}")
return Table(tbl)

Expand Down
Loading