Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
isinyaaa committed Jun 21, 2024
1 parent 7760d27 commit dddba9a
Showing 1 changed file with 27 additions and 22 deletions.
49 changes: 27 additions & 22 deletions test/robot/ModelRegistry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import model_registry as mr
from model_registry.core import ModelRegistryAPIClient
from model_registry.types import ModelArtifact, ModelVersion, RegisteredModel
from robot.libraries.BuiltIn import BuiltIn

Expand All @@ -8,35 +8,40 @@ def write_to_console(s):
BuiltIn().log_to_console(s)


class ModelRegistry(mr.core.ModelRegistryAPIClient):
def __init__(self, host: str = "localhost", port: int = 9090):
super().__init__(mr.store.MLMDStore.from_config(host, port))
class ModelRegistry:
def __init__(self, host: str = "http://localhost", port: int = 9090):
self.api = ModelRegistryAPIClient.insecure_connection(host, port)

def upsert_registered_model(self, registered_model) -> str:
p = RegisteredModel("")
for key, value in registered_model.items():
setattr(p, key, value)
return super().upsert_registered_model(p)
async def upsert_registered_model(self, registered_model: dict) -> str:
return (
await self.api.upsert_registered_model(RegisteredModel(**registered_model))
).id

def upsert_model_version(self, model_version, registered_model_id: str) -> str:
async def upsert_model_version(
self, model_version, registered_model_id: str
) -> str:
write_to_console(model_version)
p = ModelVersion("", "", "")
for key, value in model_version.items():
setattr(p, key, value)
p = ModelVersion(**model_version)
write_to_console(p)
return super().upsert_model_version(p, registered_model_id)
return (await self.api.upsert_model_version(p, registered_model_id)).id

def upsert_model_artifact(self, model_artifact, model_version_id: str) -> str:
async def upsert_model_artifact(
self, model_artifact: dict, model_version_id: str
) -> str:
write_to_console(model_artifact)
p = ModelArtifact("", "")
for key, value in model_artifact.items():
setattr(p, key, value)
p = ModelArtifact(**model_artifact)
write_to_console(p)
return super().upsert_model_artifact(p, model_version_id)
return (await self.api.upsert_model_artifact(p, model_version_id)).id


async def test():
demo_instance = ModelRegistry()
await demo_instance.upsert_registered_model({"name": "testing123"})
await demo_instance.upsert_model_version({"name": "v1"}, None)


# Used only for quick smoke tests
if __name__ == "__main__":
demo_instance = ModelRegistry()
demo_instance.upsert_registered_model({"name": "testing123"})
demo_instance.upsert_model_version({"name": "v1"}, None)
import asyncio

asyncio.run(test())

0 comments on commit dddba9a

Please sign in to comment.