Skip to content

Commit

Permalink
py: provide API builders for secure and insecure connections
Browse files Browse the repository at this point in the history
Signed-off-by: Isabella do Amaral <idoamara@redhat.com>
  • Loading branch information
isinyaaa committed May 7, 2024
1 parent dae1626 commit 1f74032
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 100 deletions.
2 changes: 1 addition & 1 deletion clients/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ from model_registry import ModelRegistry

registry = ModelRegistry("server-address", author="Ada Lovelace") # Defaults to a secure connection via port 443

# registry = ModelRegistry("server-address", author="Ada Lovelace", port=1234) # To use MR without TLS
# registry = ModelRegistry("server-address", 1234, author="Ada Lovelace", is_secure=False) # To use MR without TLS

model = registry.register_model(
"my-model", # model name
Expand Down
35 changes: 33 additions & 2 deletions clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

import os
from pathlib import Path
from typing import get_args
from warnings import warn

Expand All @@ -20,6 +22,7 @@ def __init__(
port: int = 443,
*,
author: str,
is_secure: bool = True,
user_token: bytes | None = None,
custom_ca: bytes | None = None,
):
Expand All @@ -31,12 +34,40 @@ def __init__(
Keyword Args:
author: Name of the author.
is_secure: Whether to use a secure connection. Defaults to True.
user_token: The PEM-encoded user token as a byte string. Defaults to content of path on envvar KF_PIPELINES_SA_TOKEN_PATH.
custom_ca: The PEM-encoded root certificates as a byte string. Defaults to contents of path on envvar CERT.
"""
# TODO: get args from env
# TODO: get remaining args from env
self._author = author
self._api = ModelRegistryAPIClient(server_address, port, user_token, custom_ca)

if not user_token:
# /var/run/secrets/kubernetes.io/serviceaccount/token
sa_token = os.environ.get("KF_PIPELINES_SA_TOKEN_PATH")
if sa_token:
user_token = Path(sa_token).read_bytes()
else:
warn("User access token is missing", stacklevel=2)

if is_secure:
root_ca = None
if not custom_ca:
if ca_path := os.getenv("CERT"):
root_ca = Path(ca_path).read_bytes()
# client might have a default CA setup
else:
root_ca = custom_ca

self._api = ModelRegistryAPIClient.secure_connection(
server_address, port, user_token, root_ca
)
elif custom_ca:
msg = "Custom CA provided without secure connection"
raise StoreException(msg)
else:
self._api = ModelRegistryAPIClient.insecure_connection(
server_address, port, user_token
)

def _register_model(self, name: str) -> RegisteredModel:
if rm := self._api.get_registered_model_by_params(name):
Expand Down
123 changes: 61 additions & 62 deletions clients/python/src/model_registry/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

from __future__ import annotations

import os
from pathlib import Path
from warnings import warn
from dataclasses import dataclass

import grpc

Expand All @@ -16,55 +14,58 @@
from .utils import header_adder_interceptor


@dataclass
class ModelRegistryAPIClient:
"""Model registry API."""

def __init__(
self,
store: MLMDStore

@classmethod
def secure_connection(
cls,
server_address: str,
port: int = 443,
user_token: bytes | None = None,
custom_ca: bytes | None = None,
):
) -> ModelRegistryAPIClient:
"""Constructor.
Args:
server_address: Server address.
port: Server port. Defaults to 443.
user_token: The PEM-encoded user token as a byte string. Defaults to envvar KF_PIPELINES_SA_TOKEN_PATH.
custom_ca: The PEM-encoded root certificates as a byte string. Defaults to envvar CERT.
user_token: The PEM-encoded user token as a byte string.
custom_ca: The PEM-encoded root certificates as a byte string. Defaults to GRPC_DEFAULT_SSL_ROOTS_FILE_PATH, then system default.
"""
if not user_token:
# /var/run/secrets/kubernetes.io/serviceaccount/token
sa_token = os.environ.get("KF_PIPELINES_SA_TOKEN_PATH")
if sa_token:
user_token = Path(sa_token).read_bytes()
else:
warn("User access token is missing", stacklevel=2)

if port == 443:
if not custom_ca:
ca_cert = os.environ.get("CERT")
if not ca_cert:
msg = "CA certificate must be provided"
raise StoreException(msg)
root_certs = Path(ca_cert).read_bytes()
else:
root_certs = custom_ca
chan_creds = grpc.ssl_channel_credentials(root_certs)

if user_token:
call_creds = grpc.access_token_call_credentials(user_token)
chan_creds = grpc.composite_channel_credentials(
chan_creds,
call_creds,
)

chan = grpc.secure_channel(
f"{server_address}:443",
chan_creds,
)
elif user_token:
msg = "user token must be provided for secure connection"
raise StoreException(msg)

chan = grpc.secure_channel(
f"{server_address}:{port}",
grpc.composite_channel_credentials(
# custom_ca = None will get the default root certificates
grpc.ssl_channel_credentials(custom_ca),
grpc.access_token_call_credentials(user_token),
),
)

return cls(MLMDStore.from_channel(chan))

@classmethod
def insecure_connection(
cls,
server_address: str,
port: int,
user_token: bytes | None = None,
) -> ModelRegistryAPIClient:
"""Constructor.
Args:
server_address: Server address.
port: Server port.
user_token: The PEM-encoded user token as a byte string.
"""
if user_token:
chan = grpc.intercept_channel(
grpc.insecure_channel(f"{server_address}:{port}"),
# header key has to be lowercase
Expand All @@ -73,7 +74,7 @@ def __init__(
else:
chan = grpc.insecure_channel(f"{server_address}:{port}")

self._store = MLMDStore.from_channel(chan)
return cls(MLMDStore.from_channel(chan))

def _map(self, py_obj: ProtoBase) -> ProtoType:
"""Map a Python object to a proto object.
Expand All @@ -86,7 +87,7 @@ def _map(self, py_obj: ProtoBase) -> ProtoType:
Returns:
Proto object.
"""
type_id = self._store.get_type_id(
type_id = self.store.get_type_id(
py_obj.get_proto_type(), py_obj.get_proto_type_name()
)
return py_obj.map(type_id)
Expand All @@ -103,9 +104,9 @@ def upsert_registered_model(self, registered_model: RegisteredModel) -> str:
Returns:
ID of the registered model.
"""
id = self._store.put_context(self._map(registered_model))
id = self.store.put_context(self._map(registered_model))
new_py_rm = RegisteredModel.unmap(
self._store.get_context(RegisteredModel.get_proto_type_name(), id)
self.store.get_context(RegisteredModel.get_proto_type_name(), id)
)
id = str(id)
registered_model.id = id
Expand All @@ -124,7 +125,7 @@ def get_registered_model_by_id(self, id: str) -> RegisteredModel | None:
Returns:
Registered model.
"""
proto_rm = self._store.get_context(
proto_rm = self.store.get_context(
RegisteredModel.get_proto_type_name(), id=int(id)
)
if proto_rm is not None:
Expand All @@ -150,7 +151,7 @@ def get_registered_model_by_params(
if name is None and external_id is None:
msg = "Either name or external_id must be provided"
raise StoreException(msg)
proto_rm = self._store.get_context(
proto_rm = self.store.get_context(
RegisteredModel.get_proto_type_name(),
name=name,
external_id=external_id,
Expand All @@ -172,7 +173,7 @@ def get_registered_models(
Registered models.
"""
mlmd_options = options.as_mlmd_list_options() if options else MLMDListOptions()
proto_rms = self._store.get_contexts(
proto_rms = self.store.get_contexts(
RegisteredModel.get_proto_type_name(), mlmd_options
)
return [RegisteredModel.unmap(proto_rm) for proto_rm in proto_rms]
Expand All @@ -194,10 +195,10 @@ def upsert_model_version(
"""
# this is not ideal but we need this info for the prefix
model_version._registered_model_id = registered_model_id
id = self._store.put_context(self._map(model_version))
self._store.put_context_parent(int(registered_model_id), id)
id = self.store.put_context(self._map(model_version))
self.store.put_context_parent(int(registered_model_id), id)
new_py_mv = ModelVersion.unmap(
self._store.get_context(ModelVersion.get_proto_type_name(), id)
self.store.get_context(ModelVersion.get_proto_type_name(), id)
)
id = str(id)
model_version.id = id
Expand All @@ -216,7 +217,7 @@ def get_model_version_by_id(self, model_version_id: str) -> ModelVersion | None:
Returns:
Model version.
"""
proto_mv = self._store.get_context(
proto_mv = self.store.get_context(
ModelVersion.get_proto_type_name(), id=int(model_version_id)
)
if proto_mv is not None:
Expand All @@ -240,7 +241,7 @@ def get_model_versions(
mlmd_options.filter_query = f"parent_contexts_a.id = {registered_model_id}"
return [
ModelVersion.unmap(proto_mv)
for proto_mv in self._store.get_contexts(
for proto_mv in self.store.get_contexts(
ModelVersion.get_proto_type_name(), mlmd_options
)
]
Expand All @@ -267,7 +268,7 @@ def get_model_version_by_params(
StoreException: If neither external ID nor registered model ID and version is provided.
"""
if external_id is not None:
proto_mv = self._store.get_context(
proto_mv = self.store.get_context(
ModelVersion.get_proto_type_name(), external_id=external_id
)
elif registered_model_id is None or version is None:
Expand All @@ -276,7 +277,7 @@ def get_model_version_by_params(
)
raise StoreException(msg)
else:
proto_mv = self._store.get_context(
proto_mv = self.store.get_context(
ModelVersion.get_proto_type_name(),
name=f"{registered_model_id}:{version}",
)
Expand Down Expand Up @@ -304,17 +305,17 @@ def upsert_model_artifact(
StoreException: If the model version already has a model artifact.
"""
mv_id = int(model_version_id)
if self._store.get_attributed_artifact(
if self.store.get_attributed_artifact(
ModelArtifact.get_proto_type_name(), mv_id
):
msg = f"Model version with ID {mv_id} already has a model artifact"
raise StoreException(msg)

model_artifact._model_version_id = model_version_id
id = self._store.put_artifact(self._map(model_artifact))
self._store.put_attribution(mv_id, id)
id = self.store.put_artifact(self._map(model_artifact))
self.store.put_attribution(mv_id, id)
new_py_ma = ModelArtifact.unmap(
self._store.get_artifact(ModelArtifact.get_proto_type_name(), id)
self.store.get_artifact(ModelArtifact.get_proto_type_name(), id)
)
id = str(id)
model_artifact.id = id
Expand All @@ -333,9 +334,7 @@ def get_model_artifact_by_id(self, id: str) -> ModelArtifact | None:
Returns:
Model artifact.
"""
proto_ma = self._store.get_artifact(
ModelArtifact.get_proto_type_name(), int(id)
)
proto_ma = self.store.get_artifact(ModelArtifact.get_proto_type_name(), int(id))
if proto_ma is not None:
return ModelArtifact.unmap(proto_ma)

Expand All @@ -357,14 +356,14 @@ def get_model_artifact_by_params(
StoreException: If neither external ID nor model version ID is provided.
"""
if external_id:
proto_ma = self._store.get_artifact(
proto_ma = self.store.get_artifact(
ModelArtifact.get_proto_type_name(), external_id=external_id
)
elif not model_version_id:
msg = "Either model_version_id or external_id must be provided"
raise StoreException(msg)
else:
proto_ma = self._store.get_attributed_artifact(
proto_ma = self.store.get_attributed_artifact(
ModelArtifact.get_proto_type_name(), int(model_version_id)
)
if proto_ma is not None:
Expand All @@ -390,7 +389,7 @@ def get_model_artifacts(
if model_version_id is not None:
mlmd_options.filter_query = f"contexts_a.id = {model_version_id}"

proto_mas = self._store.get_artifacts(
proto_mas = self.store.get_artifacts(
ModelArtifact.get_proto_type_name(), mlmd_options
)
return [ModelArtifact.unmap(proto_ma) for proto_ma in proto_mas]
2 changes: 1 addition & 1 deletion clients/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def store_wrapper(plain_wrapper: MLMDStore) -> MLMDStore:
@pytest.fixture()
def mr_api(store_wrapper: MLMDStore) -> ModelRegistryAPIClient:
mr = object.__new__(ModelRegistryAPIClient)
mr._store = store_wrapper
mr.store = store_wrapper
return mr


Expand Down
9 changes: 9 additions & 0 deletions clients/python/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ def mr_client(mr_api: ModelRegistryAPIClient) -> ModelRegistry:
return mr


def test_secure_client():
os.environ["CERT"] = ""
os.environ["KF_PIPELINES_SA_TOKEN_PATH"] = ""
with pytest.raises(StoreException) as e:
ModelRegistry("anything", author="test_author")

assert "user token" in str(e.value).lower()


def test_register_new(mr_client: ModelRegistry):
name = "test_model"
version = "1.0.0"
Expand Down
Loading

0 comments on commit 1f74032

Please sign in to comment.