Skip to content

Commit

Permalink
ENG-4971: Extending client object creation to support either a full U… (
Browse files Browse the repository at this point in the history
#78)

* ENG-4971: Extending client object creation to support either a full URL or a tenant

* Addressing review comments

* Added new unit test and handling special cases to address review comments

* Fixining static type check issues

* Addressing some more review comments. Trying to fix the codecove issue

* Trying to address the codecove issue

* Undone circleci related changes.

* Addressing code review comments

* Ignoring the user provided port in the URL

* Fixing static type errors

* Addressing Code review comments:
	1) Added port to specify port in the URL.
	2) Asserting after login returns.

* Fixed static type check error
  • Loading branch information
JMkrish authored May 14, 2024
1 parent 8648fa3 commit 5b11d26
Show file tree
Hide file tree
Showing 8 changed files with 286 additions and 45 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -254,4 +254,4 @@ jobs:
# Store MyPy type checking artifacts (e.g., HTML reports)
- store_artifacts:
name: MyPy Artifacts
path: mypy-results/mypy-html
path: mypy-results/mypy-html
5 changes: 4 additions & 1 deletion smsdk/Auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ def __init__(self, client):
# Setup session and store host
self.requests = requests
self.host = get_url(
client.config["protocol"], client.tenant, client.config["site.domain"]
client.config["protocol"],
client.tenant,
client.config["site.domain"],
client.config["port"],
)
self.session.headers = default_headers()

Expand Down
85 changes: 57 additions & 28 deletions smsdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,6 @@ def dict_to_df(data, normalize=True):
class Client(ClientV0):
"""Connection point to the Sight Machine platform to retrieve data"""

session = None
tenant = None
config = None

def __init__(
self, tenant: str, site_domain: str = "sightmachine.io", protocol: str = "https"
):
Expand All @@ -143,16 +139,7 @@ def __init__(
:type site_domain: :class:`string`
"""

self.tenant = tenant

# Handle internal configuration
self.config = {}
self.config["protocol"] = protocol
self.config["site.domain"] = site_domain

# Setup Authenticator
self.auth = Authenticator(self)
self.session = self.auth.session
super().__init__(tenant, site_domain=site_domain, protocol=protocol)

@version_check_decorator
def select_db_schema(self, schema_name):
Expand All @@ -177,7 +164,10 @@ def get_data_v1(self, ename, util_name, normalize=True, *args, **kwargs):
:return: pandas dataframe
"""
base_url = get_url(
self.config["protocol"], self.tenant, self.config["site.domain"]
self.config["protocol"],
self.tenant,
self.config["site.domain"],
self.config["port"],
)

df = pd.DataFrame()
Expand Down Expand Up @@ -290,7 +280,10 @@ def get_parts(
def get_kpis(self, **kwargs):
kpis = smsdkentities.get("kpi")
base_url = get_url(
self.config["protocol"], self.tenant, self.config["site.domain"]
self.config["protocol"],
self.tenant,
self.config["site.domain"],
self.config["port"],
)
return kpis(self.session, base_url).get_kpis(**kwargs)

Expand Down Expand Up @@ -340,7 +333,10 @@ def get_machine_source_from_clean_name(self, kwargs):
def get_kpis_for_asset(self, **kwargs):
kpis = smsdkentities.get("kpi")
base_url = get_url(
self.config["protocol"], self.tenant, self.config["site.domain"]
self.config["protocol"],
self.tenant,
self.config["site.domain"],
self.config["port"],
)
if "machine_type" in kwargs["asset_selection"]:
# updating kwargs with machine_type's system name in case of user provides display name.
Expand Down Expand Up @@ -386,7 +382,10 @@ def get_kpi_data_viz(
if time_selection:
kwargs["time_selection"] = time_selection
base_url = get_url(
self.config["protocol"], self.tenant, self.config["site.domain"]
self.config["protocol"],
self.tenant,
self.config["site.domain"],
self.config["port"],
)

if "asset_selection" in kwargs and "machine_type" in kwargs["asset_selection"]:
Expand All @@ -408,7 +407,10 @@ def get_kpi_data_viz(
def get_type_from_machine(self, machine_source=None, **kwargs):
machine = smsdkentities.get("machine")
base_url = get_url(
self.config["protocol"], self.tenant, self.config["site.domain"]
self.config["protocol"],
self.tenant,
self.config["site.domain"],
self.config["port"],
)
return machine(self.session, base_url).get_type_from_machine_name(
machine_source, **kwargs
Expand All @@ -426,7 +428,10 @@ def get_machine_schema(
machineType = smsdkentities.get("machine_type")
machine_type = self.get_type_from_machine(machine_source)
base_url = get_url(
self.config["protocol"], self.tenant, self.config["site.domain"]
self.config["protocol"],
self.tenant,
self.config["site.domain"],
self.config["port"],
)
fields = machineType(self.session, base_url).get_fields(machine_type, **kwargs)
fields = [
Expand Down Expand Up @@ -462,7 +467,10 @@ def get_fields_of_machine_type(
):
machineType = smsdkentities.get("machine_type")
base_url = get_url(
self.config["protocol"], self.tenant, self.config["site.domain"]
self.config["protocol"],
self.tenant,
self.config["site.domain"],
self.config["port"],
)
fields = machineType(self.session, base_url).get_fields(machine_type, **kwargs)
fields = [
Expand All @@ -485,7 +493,10 @@ def get_cookbooks(self, **kwargs):
"""
cookbook = smsdkentities.get("cookbook")
base_url = get_url(
self.config["protocol"], self.tenant, self.config["site.domain"]
self.config["protocol"],
self.tenant,
self.config["site.domain"],
self.config["port"],
)
return cookbook(self.session, base_url).get_cookbooks(**kwargs)

Expand All @@ -499,7 +510,10 @@ def get_cookbook_top_results(self, recipe_group_id=None, limit=10, **kwargs):
"""
cookbook = smsdkentities.get("cookbook")
base_url = get_url(
self.config["protocol"], self.tenant, self.config["site.domain"]
self.config["protocol"],
self.tenant,
self.config["site.domain"],
self.config["port"],
)
return cookbook(self.session, base_url).get_top_results(
recipe_group_id, limit, **kwargs
Expand All @@ -515,7 +529,10 @@ def get_cookbook_current_value(self, variables=[], minutes=1440, **kwargs):
"""
cookbook = smsdkentities.get("cookbook")
base_url = get_url(
self.config["protocol"], self.tenant, self.config["site.domain"]
self.config["protocol"],
self.tenant,
self.config["site.domain"],
self.config["port"],
)
return cookbook(self.session, base_url).get_current_value(
variables, minutes, **kwargs
Expand Down Expand Up @@ -553,7 +570,10 @@ def get_lines(self, **kwargs):
"""
lines = smsdkentities.get("line")
base_url = get_url(
self.config["protocol"], self.tenant, self.config["site.domain"]
self.config["protocol"],
self.tenant,
self.config["site.domain"],
self.config["port"],
)
return lines(self.session, base_url).get_lines(**kwargs)

Expand Down Expand Up @@ -581,7 +601,10 @@ def get_line_data(
"""
lines = smsdkentities.get("line")
base_url = get_url(
self.config["protocol"], self.tenant, self.config["site.domain"]
self.config["protocol"],
self.tenant,
self.config["site.domain"],
self.config["port"],
)

asset_selection = []
Expand Down Expand Up @@ -621,7 +644,10 @@ def create_share_link(
):
dataViz = smsdkentities.get("dataViz")
base_url = get_url(
self.config["protocol"], self.tenant, self.config["site.domain"]
self.config["protocol"],
self.tenant,
self.config["site.domain"],
self.config["port"],
)
if assets and model == "cycle" or assets and model == "kpi":
machine_types = []
Expand Down Expand Up @@ -762,7 +788,10 @@ def get_raw_data(
):
raw_data = smsdkentities.get("raw_data")
base_url = get_url(
self.config["protocol"], self.tenant, self.config["site.domain"]
self.config["protocol"],
self.tenant,
self.config["site.domain"],
self.config["port"],
)
select = [{"name": field} for field in fields]
kwargs["asset_selection"] = {"raw_data_table": raw_data_table}
Expand Down
98 changes: 87 additions & 11 deletions smsdk/client_v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from datetime import datetime
import logging
import functools
from urllib.parse import urlparse

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -74,6 +75,57 @@ def dict_to_df(data, normalize=True):
return df


def convert_to_valid_url(
input_url: str,
default_domain: str = "sightmachine.io",
default_protocol: str = "https",
):
port = ""
path = ""

# Check if the input URL has a protocol specified
if "://" in input_url:
protocol, rest_url = input_url.split("://", 1)
else:
protocol = default_protocol
rest_url = input_url

if not protocol:
protocol = default_protocol

# Split the remaining URL to check if domain is specified
parts = rest_url.split("/", 1)

if len(parts) == 1:
domain = parts[0]
path = ""
else:
domain, path = parts

# Check if the domain has a port specified
splits = domain.split(":", 1)

if len(splits) == 2:
domain, port = splits

# Check if the domain has a TLD or not
if "." not in domain:
domain = f"{domain}.{default_domain}"

# Construct the valid URL
valid_url = f"{protocol}://{domain}"

if port:
valid_url = f"{valid_url}:{port}"
# log.warning(f"Ignored the user specified port.")

if path:
# valid_url = f"{valid_url}/{path}"
log.warning(f"Ignored the user specified path.")

return valid_url


# We don't have a downtime schema, so hard code one
downmap = {
"machine__source": "Machine",
Expand Down Expand Up @@ -103,9 +155,11 @@ class ClientV0(object):

session = None
tenant = None
config = None
config = {}

def __init__(self, tenant, site_domain="sightmachine.io"):
def __init__(
self, tenant: str, site_domain: str = "sightmachine.io", protocol: str = "https"
):
"""
Initialize the client.
Expand All @@ -117,18 +171,31 @@ def __init__(self, tenant, site_domain="sightmachine.io"):
:type site_domain: :class:`string`
"""

self.tenant = tenant
port = None
if tenant:
# Convert the input tenant into a valid url
url = convert_to_valid_url(
tenant, default_domain=site_domain, default_protocol=protocol
)

# Parse the input string
parsed_uri = urlparse(url)

tenant = parsed_uri.netloc.split(".", 1)[0]
protocol = parsed_uri.scheme
site_domain = parsed_uri.netloc.split(":")[0].replace(f"{tenant}.", "")

# Handle internal configuration
self.config = {}
self.config["protocol"] = "https"
self.config["site.domain"] = site_domain
# Extract port
port = parsed_uri.port

self.tenant = tenant
self.config = {"protocol": protocol, "site.domain": site_domain, "port": port}

# Setup Authenticator
self.auth = Authenticator(self)
self.session = self.auth.session

def login(self, method: str, **kwargs: t_.Any) -> None:
def login(self, method: str, **kwargs: t_.Any) -> bool:
"""
Authenticate with the configured tenant and user credentials.
Expand Down Expand Up @@ -203,7 +270,10 @@ def get_data(self, ename, util_name, normalize=True, *args, **kwargs):
:return: pandas dataframe
"""
base_url = get_url(
self.config["protocol"], self.tenant, self.config["site.domain"]
self.config["protocol"],
self.tenant,
self.config["site.domain"],
self.config["port"],
)

df = pd.DataFrame()
Expand Down Expand Up @@ -1412,7 +1482,10 @@ def get_cycle_count(
}

base_url = get_url(
self.config["protocol"], self.tenant, self.config["site.domain"]
self.config["protocol"],
self.tenant,
self.config["site.domain"],
self.config["port"],
)
cls = smsdkentities.get("dataviz_cycle")(self.session, base_url)

Expand Down Expand Up @@ -1523,7 +1596,10 @@ def get_part_count(self, start_time="", end_time="", part_type=None, **kwargs):
}

base_url = get_url(
self.config["protocol"], self.tenant, self.config["site.domain"]
self.config["protocol"],
self.tenant,
self.config["site.domain"],
self.config["port"],
)
cls = smsdkentities.get("dataviz_part")(self.session, base_url)

Expand Down
15 changes: 13 additions & 2 deletions smsdk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ def all(self) -> t_.Dict[str, t_.Callable[..., t_.Any]]:
module_utility = ModuleUtility


def get_url(protocol: str, tenant: str, site_domain: str) -> str:
def get_url(
protocol: str, tenant: str, site_domain: str, port: t_.Optional[int] = None
) -> str:
"""
Get the URL of the web address.
Expand All @@ -30,6 +32,15 @@ def get_url(protocol: str, tenant: str, site_domain: str) -> str:
:type tenant: :class:`string`
:param site_domain: The domain name of the URL.
:type site_domain: :class:`string`
:param port: The port number (defaults to None).
:type port: :Int
"""

return f"{protocol}://{tenant}.{site_domain}"
url = ""

if port is not None:
url = f"{protocol}://{tenant}.{site_domain}:{port}"
else:
url = f"{protocol}://{tenant}.{site_domain}"

return url
Loading

0 comments on commit 5b11d26

Please sign in to comment.