From 4f8e4c9f5106058ac5f51df79cab671992e927b8 Mon Sep 17 00:00:00 2001 From: teo Date: Fri, 11 Oct 2024 11:08:53 +0300 Subject: [PATCH 1/6] added admin methods --- .../src/syft/service/sync/sync_service.py | 35 +++++++++- .../tests/syft/service/sync/get_set_object.py | 70 +++++++++++++++++++ 2 files changed, 103 insertions(+), 2 deletions(-) create mode 100644 packages/syft/tests/syft/service/sync/get_set_object.py diff --git a/packages/syft/src/syft/service/sync/sync_service.py b/packages/syft/src/syft/service/sync/sync_service.py index 19bed044eb4..8ad9ca7ae5e 100644 --- a/packages/syft/src/syft/service/sync/sync_service.py +++ b/packages/syft/src/syft/service/sync/sync_service.py @@ -32,6 +32,7 @@ from ..service import TYPE_TO_SERVICE from ..service import service_method from ..user.user_roles import ADMIN_ROLE_LEVEL +from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL from .sync_stash import SyncStash from .sync_state import SyncState @@ -39,10 +40,14 @@ def get_store(context: AuthedServiceContext, item: SyncableSyftObject) -> ObjectStash: - if isinstance(item, ActionObject): + return get_store_by_type(context=context, obj_type=type(item)) + + +def get_store_by_type(context: AuthedServiceContext, obj_type: type) -> ObjectStash: + if issubclass(obj_type, ActionObject): service = context.server.services.action # type: ignore return service.stash # type: ignore - service = context.server.get_service(TYPE_TO_SERVICE[type(item)]) # type: ignore + service = context.server.get_service(TYPE_TO_SERVICE[obj_type]) # type: ignore return service.stash @@ -450,3 +455,29 @@ def build_current_state( ) def _get_state(self, context: AuthedServiceContext) -> SyncState: return self.build_current_state(context).unwrap() + + @service_method( + path="sync._get_object", + name="_get_object", + roles=DATA_SCIENTIST_ROLE_LEVEL, + ) + def _get_object( + self, context: AuthedServiceContext, uid: UID, object_type: type + ) -> Any: + return ( + get_store_by_type(context, object_type) + .get_by_uid(credentials=context.credentials, uid=uid) + .unwrap() + ) + + @service_method( + path="sync._update_object", + name="_update_object", + roles=ADMIN_ROLE_LEVEL, + ) + def _update_object(self, context: AuthedServiceContext, object: Any) -> Any: + return ( + get_store(context, object) + .update(credentials=context.credentials, obj=object) + .unwrap() + ) diff --git a/packages/syft/tests/syft/service/sync/get_set_object.py b/packages/syft/tests/syft/service/sync/get_set_object.py new file mode 100644 index 00000000000..884de15dc5d --- /dev/null +++ b/packages/syft/tests/syft/service/sync/get_set_object.py @@ -0,0 +1,70 @@ +# third party +import numpy as np +import pytest + +# syft absolute +import syft +import syft as sy +from syft.client.datasite_client import DatasiteClient +from syft.client.sync_decision import SyncDecision +from syft.client.syncing import compare_clients +from syft.client.syncing import resolve +from syft.server.worker import Worker +from syft.service.action.action_object import ActionObject +from syft.service.code.user_code import ApprovalDecision +from syft.service.code.user_code import UserCodeStatus +from syft.service.dataset.dataset import Dataset +from syft.service.job.job_stash import Job +from syft.service.request.request import RequestStatus +from syft.service.response import SyftSuccess +from syft.service.sync.resolve_widget import ResolveWidget +from syft.service.user.user import User, UserView +from syft.types.errors import SyftException + + +def get_ds_client(client: DatasiteClient) -> DatasiteClient: + client.register( + name="a", + email="a@a.com", + password="asdf", + password_verify="asdf", + ) + return client.login(email="a@a.com", password="asdf") + + +def test_get_set_object(high_worker): + high_client: DatasiteClient = high_worker.root_client + _ = get_ds_client(high_client) + root_datasite_client = high_worker.root_client + dataset = sy.Dataset( + name="local_test", + asset_list=[ + sy.Asset( + name="local_test", + data=[1, 2, 3], + mock=[1, 1, 1], + ) + ], + ) + root_datasite_client.upload_dataset(dataset) + dataset = root_datasite_client.datasets[0] + + other_dataset = high_client.api.services.sync._get_object(uid=dataset.id, object_type=Dataset) + other_dataset.server_uid = dataset.server_uid + assert dataset == other_dataset + other_dataset.name = "new_name" + updated_dataset = high_client.api.services.sync._update_object( + object=other_dataset + ) + assert updated_dataset.name == "new_name" + + asset = root_datasite_client.datasets[0].assets[0] + source_ao = high_client.api.services.action.get(uid=asset.action_id) + ao = high_client.api.services.sync._get_object( + uid=asset.action_id, object_type=ActionObject + ) + ao._set_obj_location_( + high_worker.id, + root_datasite_client.credentials, + ) + assert source_ao == ao From 062bb932b344dca1c20d8d605aa4e5b27043aa4a Mon Sep 17 00:00:00 2001 From: teo Date: Fri, 11 Oct 2024 13:30:34 +0300 Subject: [PATCH 2/6] added refresh method --- .../syft/src/syft/service/request/request.py | 1 + packages/syft/src/syft/types/syft_object.py | 7 +++++ ...t_set_object.py => get_set_object_test.py} | 31 +++++-------------- 3 files changed, 16 insertions(+), 23 deletions(-) rename packages/syft/tests/syft/service/sync/{get_set_object.py => get_set_object_test.py} (63%) diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 7c5495a756e..ab07c7380f1 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -585,6 +585,7 @@ def get_status(self, context: AuthedServiceContext | None = None) -> RequestStat # which tries to send an email to the admin and ends up here pass # lets keep going + self.refresh() if len(self.history) == 0: return RequestStatus.PENDING diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py index f6a4d3233cb..2c33e15fc4c 100644 --- a/packages/syft/src/syft/types/syft_object.py +++ b/packages/syft/src/syft/types/syft_object.py @@ -429,6 +429,13 @@ def make_id(cls, values: Any) -> Any: __table_coll_widths__: ClassVar[list[str] | None] = None __table_sort_attr__: ClassVar[str | None] = None + def refresh(self) -> None: + api = self._get_api() + new_object = api.services.sync._get_object(uid=self.id, object_type=type(self)) + print(type(self), type(new_object)) + if type(new_object) == type(self): + self.__dict__.update(new_object.__dict__) + def __syft_get_funcs__(self) -> list[tuple[str, Signature]]: funcs = print_type_cache[type(self)] if len(funcs) > 0: diff --git a/packages/syft/tests/syft/service/sync/get_set_object.py b/packages/syft/tests/syft/service/sync/get_set_object_test.py similarity index 63% rename from packages/syft/tests/syft/service/sync/get_set_object.py rename to packages/syft/tests/syft/service/sync/get_set_object_test.py index 884de15dc5d..9f92d4812b5 100644 --- a/packages/syft/tests/syft/service/sync/get_set_object.py +++ b/packages/syft/tests/syft/service/sync/get_set_object_test.py @@ -1,25 +1,10 @@ # third party -import numpy as np -import pytest # syft absolute -import syft import syft as sy from syft.client.datasite_client import DatasiteClient -from syft.client.sync_decision import SyncDecision -from syft.client.syncing import compare_clients -from syft.client.syncing import resolve -from syft.server.worker import Worker from syft.service.action.action_object import ActionObject -from syft.service.code.user_code import ApprovalDecision -from syft.service.code.user_code import UserCodeStatus from syft.service.dataset.dataset import Dataset -from syft.service.job.job_stash import Job -from syft.service.request.request import RequestStatus -from syft.service.response import SyftSuccess -from syft.service.sync.resolve_widget import ResolveWidget -from syft.service.user.user import User, UserView -from syft.types.errors import SyftException def get_ds_client(client: DatasiteClient) -> DatasiteClient: @@ -48,14 +33,14 @@ def test_get_set_object(high_worker): ) root_datasite_client.upload_dataset(dataset) dataset = root_datasite_client.datasets[0] - - other_dataset = high_client.api.services.sync._get_object(uid=dataset.id, object_type=Dataset) + + other_dataset = high_client.api.services.sync._get_object( + uid=dataset.id, object_type=Dataset + ) other_dataset.server_uid = dataset.server_uid assert dataset == other_dataset other_dataset.name = "new_name" - updated_dataset = high_client.api.services.sync._update_object( - object=other_dataset - ) + updated_dataset = high_client.api.services.sync._update_object(object=other_dataset) assert updated_dataset.name == "new_name" asset = root_datasite_client.datasets[0].assets[0] @@ -64,7 +49,7 @@ def test_get_set_object(high_worker): uid=asset.action_id, object_type=ActionObject ) ao._set_obj_location_( - high_worker.id, - root_datasite_client.credentials, - ) + high_worker.id, + root_datasite_client.credentials, + ) assert source_ao == ao From 8f92951608f42e834e3621029cbac57d80d1d983 Mon Sep 17 00:00:00 2001 From: teo Date: Mon, 14 Oct 2024 09:42:14 +0300 Subject: [PATCH 3/6] fix refresh --- packages/syft/src/syft/types/syft_object.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py index 2c33e15fc4c..a1a1c3236b7 100644 --- a/packages/syft/src/syft/types/syft_object.py +++ b/packages/syft/src/syft/types/syft_object.py @@ -430,9 +430,11 @@ def make_id(cls, values: Any) -> Any: __table_sort_attr__: ClassVar[str | None] = None def refresh(self) -> None: - api = self._get_api() + try: + api = self._get_api() + except Exception as _: + return new_object = api.services.sync._get_object(uid=self.id, object_type=type(self)) - print(type(self), type(new_object)) if type(new_object) == type(self): self.__dict__.update(new_object.__dict__) From 1fef3eda3cdab56e5ea6e2e001b2faca2a22c599 Mon Sep 17 00:00:00 2001 From: teo Date: Mon, 14 Oct 2024 09:50:36 +0300 Subject: [PATCH 4/6] moved methods to migration service --- .../service/migration/migration_service.py | 31 +++++++++++++++++++ .../src/syft/service/sync/sync_service.py | 27 ---------------- .../syft/service/sync/get_set_object_test.py | 8 +++-- 3 files changed, 36 insertions(+), 30 deletions(-) diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index 62788762acf..d4f217b9035 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -1,6 +1,7 @@ # stdlib from collections import defaultdict import logging +from typing import Any # syft absolute import syft @@ -16,6 +17,7 @@ from ...types.syft_object import SyftObject from ...types.syft_object_registry import SyftObjectRegistry from ...types.twin_object import TwinObject +from ...types.uid import UID from ..action.action_object import Action from ..action.action_object import ActionObject from ..action.action_permissions import ActionObjectPermission @@ -26,7 +28,10 @@ from ..response import SyftSuccess from ..service import AbstractService from ..service import service_method +from ..sync.sync_service import get_store +from ..sync.sync_service import get_store_by_type from ..user.user_roles import ADMIN_ROLE_LEVEL +from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL from ..worker.utils import DEFAULT_WORKER_POOL_NAME from .object_migration_state import MigrationData from .object_migration_state import StoreMetadata @@ -493,3 +498,29 @@ def reset_and_restore( ) return SyftSuccess(message="Database reset successfully.") + + @service_method( + path="sync._get_object", + name="_get_object", + roles=DATA_SCIENTIST_ROLE_LEVEL, + ) + def _get_object( + self, context: AuthedServiceContext, uid: UID, object_type: type + ) -> Any: + return ( + get_store_by_type(context, object_type) + .get_by_uid(credentials=context.credentials, uid=uid) + .unwrap() + ) + + @service_method( + path="sync._update_object", + name="_update_object", + roles=ADMIN_ROLE_LEVEL, + ) + def _update_object(self, context: AuthedServiceContext, object: Any) -> Any: + return ( + get_store(context, object) + .update(credentials=context.credentials, obj=object) + .unwrap() + ) diff --git a/packages/syft/src/syft/service/sync/sync_service.py b/packages/syft/src/syft/service/sync/sync_service.py index 8ad9ca7ae5e..b6cc955ac4f 100644 --- a/packages/syft/src/syft/service/sync/sync_service.py +++ b/packages/syft/src/syft/service/sync/sync_service.py @@ -32,7 +32,6 @@ from ..service import TYPE_TO_SERVICE from ..service import service_method from ..user.user_roles import ADMIN_ROLE_LEVEL -from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL from .sync_stash import SyncStash from .sync_state import SyncState @@ -455,29 +454,3 @@ def build_current_state( ) def _get_state(self, context: AuthedServiceContext) -> SyncState: return self.build_current_state(context).unwrap() - - @service_method( - path="sync._get_object", - name="_get_object", - roles=DATA_SCIENTIST_ROLE_LEVEL, - ) - def _get_object( - self, context: AuthedServiceContext, uid: UID, object_type: type - ) -> Any: - return ( - get_store_by_type(context, object_type) - .get_by_uid(credentials=context.credentials, uid=uid) - .unwrap() - ) - - @service_method( - path="sync._update_object", - name="_update_object", - roles=ADMIN_ROLE_LEVEL, - ) - def _update_object(self, context: AuthedServiceContext, object: Any) -> Any: - return ( - get_store(context, object) - .update(credentials=context.credentials, obj=object) - .unwrap() - ) diff --git a/packages/syft/tests/syft/service/sync/get_set_object_test.py b/packages/syft/tests/syft/service/sync/get_set_object_test.py index 9f92d4812b5..e6681dc621f 100644 --- a/packages/syft/tests/syft/service/sync/get_set_object_test.py +++ b/packages/syft/tests/syft/service/sync/get_set_object_test.py @@ -34,18 +34,20 @@ def test_get_set_object(high_worker): root_datasite_client.upload_dataset(dataset) dataset = root_datasite_client.datasets[0] - other_dataset = high_client.api.services.sync._get_object( + other_dataset = high_client.api.services.migration._get_object( uid=dataset.id, object_type=Dataset ) other_dataset.server_uid = dataset.server_uid assert dataset == other_dataset other_dataset.name = "new_name" - updated_dataset = high_client.api.services.sync._update_object(object=other_dataset) + updated_dataset = high_client.api.services.migration._update_object( + object=other_dataset + ) assert updated_dataset.name == "new_name" asset = root_datasite_client.datasets[0].assets[0] source_ao = high_client.api.services.action.get(uid=asset.action_id) - ao = high_client.api.services.sync._get_object( + ao = high_client.api.services.migration._get_object( uid=asset.action_id, object_type=ActionObject ) ao._set_obj_location_( From 331881ecf181662025fed22b5b8841cd184b644c Mon Sep 17 00:00:00 2001 From: teo Date: Mon, 14 Oct 2024 10:06:50 +0300 Subject: [PATCH 5/6] fix test --- packages/syft/src/syft/service/migration/migration_service.py | 4 ++-- packages/syft/src/syft/types/syft_object.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index d4f217b9035..4a346ebde9d 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -500,7 +500,7 @@ def reset_and_restore( return SyftSuccess(message="Database reset successfully.") @service_method( - path="sync._get_object", + path="migration._get_object", name="_get_object", roles=DATA_SCIENTIST_ROLE_LEVEL, ) @@ -514,7 +514,7 @@ def _get_object( ) @service_method( - path="sync._update_object", + path="migration._update_object", name="_update_object", roles=ADMIN_ROLE_LEVEL, ) diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py index a1a1c3236b7..1e93d377460 100644 --- a/packages/syft/src/syft/types/syft_object.py +++ b/packages/syft/src/syft/types/syft_object.py @@ -434,7 +434,9 @@ def refresh(self) -> None: api = self._get_api() except Exception as _: return - new_object = api.services.sync._get_object(uid=self.id, object_type=type(self)) + new_object = api.services.migration._get_object( + uid=self.id, object_type=type(self) + ) if type(new_object) == type(self): self.__dict__.update(new_object.__dict__) From 142a6ea7e857dc3b79d9e445e71f992333197caf Mon Sep 17 00:00:00 2001 From: teo Date: Mon, 14 Oct 2024 11:46:51 +0300 Subject: [PATCH 6/6] fix test --- packages/syft/src/syft/types/syft_object.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py index 1e93d377460..7b30ffaa562 100644 --- a/packages/syft/src/syft/types/syft_object.py +++ b/packages/syft/src/syft/types/syft_object.py @@ -432,13 +432,13 @@ def make_id(cls, values: Any) -> Any: def refresh(self) -> None: try: api = self._get_api() + new_object = api.services.migration._get_object( + uid=self.id, object_type=type(self) + ) + if type(new_object) == type(self): + self.__dict__.update(new_object.__dict__) except Exception as _: return - new_object = api.services.migration._get_object( - uid=self.id, object_type=type(self) - ) - if type(new_object) == type(self): - self.__dict__.update(new_object.__dict__) def __syft_get_funcs__(self) -> list[tuple[str, Signature]]: funcs = print_type_cache[type(self)]