Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add organization ID to session when appropriate #294

Merged
merged 2 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion geti_sdk/data_models/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,9 @@ def cancel(self, session: GetiSession) -> "Job":
:return: Job with updated status
"""
try:
session.get_rest_response(url=self.relative_url, method="DELETE")
session.get_rest_response(
url=self.relative_url, method="DELETE", allow_text_response=True
)
self.status.state = JobState.CANCELLED
except GetiRequestException as error:
if error.status_code == 404:
Expand Down
88 changes: 72 additions & 16 deletions geti_sdk/http_session/geti_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,16 @@
from requests.structures import CaseInsensitiveDict
from urllib3.exceptions import InsecureRequestWarning

from geti_sdk.platform_versions import GetiVersion
from geti_sdk.platform_versions import GETI_18_VERSION, GetiVersion

from .exception import GetiRequestException
from .server_config import LEGACY_API_VERSION, ServerCredentialConfig, ServerTokenConfig

CSRF_COOKIE_NAME = "_oauth2_proxy_csrf"
PROXY_COOKIE_NAME = "_oauth2_proxy"

INITIAL_HEADERS = {"Connection": "keep-alive", "Upgrade-Insecure-Requests": "1"}
# INITIAL_HEADERS = {"Connection": "keep-alive", "Upgrade-Insecure-Requests": "1"}
INITIAL_HEADERS = {"Upgrade-Insecure-Requests": "1"}
SUCCESS_STATUS_CODES = [200, 201]


Expand Down Expand Up @@ -118,6 +119,7 @@ def _acquire_access_token(self) -> str:
data={"service_id": self.config.token},
contenttype="json",
allow_reauthentication=False,
include_organization_id=False,
)
except GetiRequestException as error:
if error.status_code == 401:
Expand All @@ -141,7 +143,9 @@ def _follow_login_redirects(self, response: Response) -> str:
redirect_url = response.next.url
redirected = self.get(redirect_url, allow_redirects=False, **self._proxies)
proxy_csrf = redirected.cookies.get(CSRF_COOKIE_NAME, None)
if proxy_csrf:
if proxy_csrf is None:
proxy_csrf = response.cookies.get(CSRF_COOKIE_NAME, None)
if proxy_csrf is not None:
self._cookies[CSRF_COOKIE_NAME] = proxy_csrf
return self._follow_login_redirects(redirected)
else:
Expand All @@ -164,6 +168,10 @@ def authenticate(self, verbose: bool = True):

:param verbose: True to print progress output, False to suppress output
"""
if self.logged_in:
logging.info("Already logged in, authentication is skipped")
return
self.cookies.clear()
try:
login_path = self._get_initial_login_url()
except requests.exceptions.SSLError as error:
Expand Down Expand Up @@ -202,8 +210,16 @@ def authenticate(self, verbose: bool = True):
"The cluster responded to the request, but authentication failed. "
"Please verify that you have provided correct credentials."
)
cookie = {PROXY_COOKIE_NAME: previous_response.cookies.get(PROXY_COOKIE_NAME)}
self._cookies.update(cookie)
proxy_cookie = previous_response.cookies.get(PROXY_COOKIE_NAME)
if proxy_cookie is not None:
cookie = {PROXY_COOKIE_NAME: proxy_cookie}
self._cookies.update(cookie)
else:
logging.warning(
f"Authentication appears to have failed! No valid oauth cookie "
f"obtained. Invalid response received from server. Status code: "
f"{response.status_code}"
)
if verbose:
logging.info("Authentication successful. Cookie received.")
self.logged_in = True
Expand All @@ -215,6 +231,8 @@ def get_rest_response(
contenttype: str = "json",
data=None,
allow_reauthentication: bool = True,
include_organization_id: bool = True,
allow_text_response: bool = False,
) -> Union[Response, dict, list]:
"""
Return the REST response from a request to `url` with `method`.
Expand All @@ -227,13 +245,24 @@ def get_rest_response(
:param allow_reauthentication: True to handle authentication errors
by attempting to re-authenticate. If set to False, such errors
will be raised instead.
:param include_organization_id: True to include the organization ID in the base
URL. Can be set to False for accessing certain internal endpoints that do
not require an organization ID, but do require error handling.
:param allow_text_response: False to trigger error handling when the server
returns a response with text/html content. This can happen in some cases
when authentication has expired. However, some endpoints are designed to
return text responses, for those endpoints this parameter should be set to
True
"""
if url.startswith(self.config.api_pattern):
url = url[len(self.config.api_pattern) :]

self._update_headers_for_content_type(content_type=contenttype)

requesturl = f"{self.config.base_url}{url}"
if not include_organization_id:
requesturl = f"{self.config.base_url}{url}"
else:
requesturl = f"{self.base_url}{url}"

if method == "POST" or method == "PUT":
if contenttype == "json":
Expand Down Expand Up @@ -269,17 +298,19 @@ def get_rest_response(
f"be verified. \n Full error description: {error.args[-1]}"
)

response_content_type = response.headers.get("Content-Type", [])
if (
response.status_code not in SUCCESS_STATUS_CODES
or "text/html" in response.headers.get("Content-Type", [])
or "text/html" in response_content_type
):
response = self._handle_error_response(
response=response,
request_params=request_params,
request_data=kw_data_arg,
allow_reauthentication=allow_reauthentication,
content_type=contenttype,
)
if not ("text/html" in response_content_type and allow_text_response):
response = self._handle_error_response(
response=response,
request_params=request_params,
request_data=kw_data_arg,
allow_reauthentication=allow_reauthentication,
content_type=contenttype,
)

if response.headers.get("Content-Type", None) == "application/json":
result = response.json()
Expand Down Expand Up @@ -353,10 +384,14 @@ def _get_product_info_and_set_api_version(self) -> Dict[str, str]:
:return: Dictionary containing the product info.
"""
try:
product_info = self.get_rest_response("product_info", "GET")
product_info = self.get_rest_response(
"product_info", "GET", include_organization_id=False
)
except GetiRequestException:
self.config.api_version = LEGACY_API_VERSION
product_info = self.get_rest_response("product_info", "GET")
product_info = self.get_rest_response(
"product_info", "GET", include_organization_id=False
)
return product_info

def __exit__(self, exc_type, exc_value, traceback):
Expand Down Expand Up @@ -409,6 +444,7 @@ def _handle_error_response(
if response.status_code in [200, 401, 403] and allow_reauthentication:
# Authentication has likely expired, re-authenticate
logging.info("Authentication may have expired, re-authenticating...")
self.logged_in = False
if not self.use_token:
self.authenticate(verbose=False)
logging.info("Authentication complete.")
Expand Down Expand Up @@ -471,3 +507,23 @@ def _update_headers_for_content_type(self, content_type: str) -> None:
self.headers.pop("Content-Type", None)
elif content_type == "zip":
self.headers.update({"Content-Type": "application/zip"})

@property
def base_url(self) -> str:
"""
Return the base URL to the Intel Geti server. If the server is running
Geti v1.9 or later, the organization ID will be included in the URL
"""
if self.version <= GETI_18_VERSION:
return self.config.base_url
else:
org_id = self.get_organization_id()
return f"{self.config.base_url}organizations/{org_id}/"

def get_organization_id(self) -> str:
"""
Return the organization ID associated with the user and host information configured
in this Session
"""
default_org_id = "000000000000000000000001"
return default_org_id
1 change: 1 addition & 0 deletions geti_sdk/platform_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,4 @@ def is_geti(self) -> bool:
GETI_11_VERSION = GetiVersion("1.1.0-release-20221125121144")
GETI_12_VERSION = GetiVersion("1.2.0-release-20230101120000")
GETI_15_VERSION = GetiVersion("1.5.0-release-20230504111017")
GETI_18_VERSION = GetiVersion("1.8.0-release-20231018022911")
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
4 changes: 2 additions & 2 deletions tests/fixtures/cassettes/TestGetiSession.test_logout.cassette
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Loading
Loading