Skip to content

Commit

Permalink
Merge branch 'dev' into shubham/notebooks-with-state
Browse files Browse the repository at this point in the history
  • Loading branch information
shubham3121 authored Sep 27, 2024
2 parents 17327d8 + 53bb398 commit e429570
Show file tree
Hide file tree
Showing 33 changed files with 60 additions and 1,796 deletions.
2 changes: 0 additions & 2 deletions packages/syft/src/syft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,6 @@
from .service.user.roles import Roles as roles
from .service.user.user_service import UserService
from .stable_version import LATEST_STABLE_SYFT
from .store.mongo_document_store import MongoStoreConfig
from .store.sqlite_document_store import SQLiteStoreConfig
from .types.errors import SyftException
from .types.errors import raises
from .types.result import as_result
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/protocol/data_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
# syft absolute
from syft.types.result import Err
from syft.types.result import Ok
from syft.util.util import get_dev_mode

# relative
from .. import __version__
Expand All @@ -29,7 +30,6 @@
from ..types.errors import SyftException
from ..types.syft_object import SyftBaseObject
from ..types.syft_object_registry import SyftObjectRegistry
from ..util.util import get_dev_mode

PROTOCOL_STATE_FILENAME = "protocol_version.json"
PROTOCOL_TYPE = str | int
Expand Down
6 changes: 1 addition & 5 deletions packages/syft/src/syft/service/code/status_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from ...serde.serializable import serializable
from ...store.db.db import DBManager
from ...store.db.stash import ObjectStash
from ...store.document_store import PartitionSettings
from ...types.syft_object import PartialSyftObject
from ...types.syft_object import SYFT_OBJECT_VERSION_1
from ...types.uid import UID
Expand All @@ -24,10 +23,7 @@

@serializable(canonical_name="StatusSQLStash", version=1)
class StatusStash(ObjectStash[UserCodeStatusCollection]):
settings: PartitionSettings = PartitionSettings(
name=UserCodeStatusCollection.__canonical_name__,
object_type=UserCodeStatusCollection,
)
pass


class CodeStatusUpdate(PartialSyftObject):
Expand Down
5 changes: 0 additions & 5 deletions packages/syft/src/syft/service/code/user_code_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
from ...store.db.stash import ObjectStash
from ...store.document_store import PartitionSettings
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...types.result import as_result
Expand All @@ -11,10 +10,6 @@

@serializable(canonical_name="UserCodeSQLStash", version=1)
class UserCodeStash(ObjectStash[UserCode]):
settings: PartitionSettings = PartitionSettings(
name=UserCode.__canonical_name__, object_type=UserCode
)

@as_result(StashException, NotFoundException)
def get_by_code_hash(self, credentials: SyftVerifyKey, code_hash: str) -> UserCode:
return self.get_one(
Expand Down
3 changes: 0 additions & 3 deletions packages/syft/src/syft/service/data_subject/data_subject.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

# relative
from ...serde.serializable import serializable
from ...store.document_store import PartitionKey
from ...types.syft_object import SYFT_OBJECT_VERSION_1
from ...types.syft_object import SyftObject
from ...types.transforms import TransformContext
Expand All @@ -17,8 +16,6 @@
from ...types.uid import UID
from ...util.markdown import as_markdown_python_code

NamePartitionKey = PartitionKey(key="name", type_=str)


@serializable()
class DataSubject(SyftObject):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,9 @@

# relative
from ...serde.serializable import serializable
from ...store.document_store import PartitionKey
from ...types.syft_object import SYFT_OBJECT_VERSION_1
from ...types.syft_object import SyftObject

ParentPartitionKey = PartitionKey(key="parent", type_=str)
ChildPartitionKey = PartitionKey(key="child", type_=str)


@serializable()
class DataSubjectMemberRelationship(SyftObject):
Expand Down
2 changes: 0 additions & 2 deletions packages/syft/src/syft/service/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

# relative
from ...serde.serializable import serializable
from ...store.document_store import PartitionKey
from ...types.datetime import DateTime
from ...types.dicttuple import DictTuple
from ...types.errors import SyftException
Expand Down Expand Up @@ -45,7 +44,6 @@
from ..response import SyftSuccess
from ..response import SyftWarning

NamePartitionKey = PartitionKey(key="name", type_=str)
logger = logging.getLogger(__name__)


Expand Down
5 changes: 0 additions & 5 deletions packages/syft/src/syft/service/job/job_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from ...service.context import AuthedServiceContext
from ...service.worker.worker_pool import SyftWorker
from ...store.db.stash import ObjectStash
from ...store.document_store import PartitionSettings
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...types.datetime import DateTime
Expand Down Expand Up @@ -736,10 +735,6 @@ def from_job(

@serializable(canonical_name="JobStashSQL", version=1)
class JobStash(ObjectStash[Job]):
settings: PartitionSettings = PartitionSettings(
name=Job.__canonical_name__, object_type=Job
)

@as_result(StashException)
def set_result(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from ...server.credentials import SyftSigningKey
from ...server.credentials import SyftVerifyKey
from ...store.db.stash import ObjectStash
from ...store.document_store import PartitionKey
from ...store.document_store_errors import NotFoundException
from ...types.blob_storage import BlobStorageEntry
from ...types.blob_storage import CreateBlobStorageEntry
Expand Down Expand Up @@ -64,9 +63,6 @@ def supported_versions(self) -> list:
return SyftObjectRegistry.get_versions(self.canonical_name)


KlassNamePartitionKey = PartitionKey(key="canonical_name", type_=str)


@serializable(canonical_name="SyftMigrationStateSQLStash", version=1)
class SyftMigrationStateStash(ObjectStash[SyftObjectMigrationState]):
@as_result(SyftException, NotFoundException)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,8 @@ def email_body(notification: "Notification", context: AuthedServiceContext) -> s
deny_reason_or_err = request_obj.get_deny_reason(context=context)
if deny_reason_or_err.is_err():
deny_reason = None
deny_reason = deny_reason_or_err.unwrap()
else:
deny_reason = deny_reason_or_err.unwrap()

if not isinstance(deny_reason, str) or not len(deny_reason):
deny_reason = (
Expand Down
5 changes: 0 additions & 5 deletions packages/syft/src/syft/service/notifier/notifier_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
from ...store.db.stash import ObjectStash
from ...store.document_store import PartitionSettings
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...types.result import as_result
Expand All @@ -17,10 +16,6 @@
@instrument
@serializable(canonical_name="NotifierSQLStash", version=1)
class NotifierStash(ObjectStash[NotifierSettings]):
settings: PartitionSettings = PartitionSettings(
name=NotifierSettings.__canonical_name__, object_type=NotifierSettings
)

@as_result(StashException, NotFoundException)
def get(self, credentials: SyftVerifyKey) -> NotifierSettings:
"""Get Settings"""
Expand Down
6 changes: 0 additions & 6 deletions packages/syft/src/syft/service/output/output_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from ...server.credentials import SyftVerifyKey
from ...store.db.db import DBManager
from ...store.db.stash import ObjectStash
from ...store.document_store import PartitionKey
from ...store.document_store_errors import StashException
from ...store.linked_obj import LinkedObject
from ...types.datetime import DateTime
Expand All @@ -26,11 +25,6 @@
from ..user.user_roles import ADMIN_ROLE_LEVEL
from ..user.user_roles import GUEST_ROLE_LEVEL

CreatedAtPartitionKey = PartitionKey(key="created_at", type_=DateTime)
UserCodeIdPartitionKey = PartitionKey(key="user_code_id", type_=UID)
JobIdPartitionKey = PartitionKey(key="job_id", type_=UID)
OutputPolicyIdPartitionKey = PartitionKey(key="output_policy_id", type_=UID)


@serializable()
class ExecutionOutput(SyncableSyftObject):
Expand Down
5 changes: 0 additions & 5 deletions packages/syft/src/syft/service/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from ...serde.recursive_primitives import recursive_serde_register_type
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
from ...store.document_store import PartitionKey
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...types.datetime import DateTime
Expand Down Expand Up @@ -84,10 +83,6 @@ class OutputPolicyValidEnum(Enum):

DEFAULT_USER_POLICY_VERSION = 1

PolicyUserVerifyKeyPartitionKey = PartitionKey(
key="user_verify_key", type_=SyftVerifyKey
)

PyCodeObject = Any


Expand Down
8 changes: 6 additions & 2 deletions packages/syft/src/syft/service/queue/base_queue.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
# stdlib
from typing import Any
from typing import ClassVar
from typing import TYPE_CHECKING

# relative
from ...serde.serializable import serializable
from ...service.context import AuthedServiceContext
from ...store.document_store import NewBaseStash
from ...types.uid import UID
from ..response import SyftSuccess
from ..worker.worker_stash import WorkerStash

if TYPE_CHECKING:
# relative
from .queue_stash import QueueStash


@serializable(canonical_name="QueueClientConfig", version=1)
class QueueClientConfig:
Expand Down Expand Up @@ -105,7 +109,7 @@ def create_consumer(
def create_producer(
self,
queue_name: str,
queue_stash: type[NewBaseStash],
queue_stash: "QueueStash",
context: AuthedServiceContext,
worker_stash: WorkerStash,
) -> QueueProducer:
Expand Down
8 changes: 6 additions & 2 deletions packages/syft/src/syft/service/queue/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from threading import Thread
import time
from typing import Any
from typing import TYPE_CHECKING
from typing import cast

# third party
Expand All @@ -17,7 +18,6 @@
from ...server.credentials import SyftVerifyKey
from ...server.worker_settings import WorkerSettings
from ...service.context import AuthedServiceContext
from ...store.document_store import NewBaseStash
from ...store.linked_obj import LinkedObject
from ...types.datetime import DateTime
from ...types.errors import SyftException
Expand All @@ -39,6 +39,10 @@
from .queue_stash import QueueItem
from .queue_stash import Status

if TYPE_CHECKING:
# relative
from .queue_stash import QueueStash

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -129,7 +133,7 @@ def create_consumer(
def create_producer(
self,
queue_name: str,
queue_stash: type[NewBaseStash],
queue_stash: "QueueStash",
context: AuthedServiceContext,
worker_stash: WorkerStash,
) -> QueueProducer:
Expand Down
5 changes: 0 additions & 5 deletions packages/syft/src/syft/service/queue/queue_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from ...server.worker_settings import WorkerSettings
from ...server.worker_settings import WorkerSettingsV1
from ...store.db.stash import ObjectStash
from ...store.document_store import PartitionKey
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...store.linked_obj import LinkedObject
Expand All @@ -35,10 +34,6 @@ class Status(str, Enum):
INTERRUPTED = "interrupted"


StatusPartitionKey = PartitionKey(key="status", type_=Status)
_WorkerPoolPartitionKey = PartitionKey(key="worker_pool", type_=LinkedObject)


@serializable()
class QueueItemV1(SyftObject):
__canonical_name__ = "QueueItem"
Expand Down
2 changes: 0 additions & 2 deletions packages/syft/src/syft/service/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from ..serde.signature import signature_remove_self
from ..server.credentials import SyftVerifyKey
from ..store.db.stash import ObjectStash
from ..store.document_store import DocumentStore
from ..store.linked_obj import LinkedObject
from ..types.errors import SyftException
from ..types.result import as_result
Expand Down Expand Up @@ -71,7 +70,6 @@
class AbstractService:
server: AbstractServer
server_uid: UID
store_type: type = DocumentStore
stash: ObjectStash

@as_result(SyftException)
Expand Down
8 changes: 5 additions & 3 deletions packages/syft/src/syft/service/sync/sync_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from ...serde.serializable import serializable
from ...store.db.db import DBManager
from ...store.db.stash import ObjectStash
from ...store.document_store import NewBaseStash
from ...store.document_store_errors import NotFoundException
from ...store.linked_obj import LinkedObject
from ...types.datetime import DateTime
Expand Down Expand Up @@ -108,9 +107,10 @@ def transform_item(
self.set_obj_ids(context, item)
return item

@as_result(ValueError)
def get_stash_for_item(
self, context: AuthedServiceContext, item: SyftObject
) -> NewBaseStash:
) -> ObjectStash:
services = list(context.server.service_path_map.values()) # type: ignore

all_stashes = {}
Expand All @@ -119,6 +119,8 @@ def get_stash_for_item(
all_stashes[_stash.object_type] = _stash

stash = all_stashes.get(type(item), None)
if stash is None:
raise ValueError(f"Could not find stash for {type(item)}")
return stash

def add_permissions_for_item(
Expand Down Expand Up @@ -148,7 +150,7 @@ def add_storage_permissions_for_item(
def set_object(
self, context: AuthedServiceContext, item: SyncableSyftObject
) -> SyftObject:
stash = self.get_stash_for_item(context, item)
stash = self.get_stash_for_item(context, item).unwrap()
creds = context.credentials

obj = None
Expand Down
6 changes: 0 additions & 6 deletions packages/syft/src/syft/service/sync/sync_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,13 @@
from ...server.credentials import SyftVerifyKey
from ...store.db.db import DBManager
from ...store.db.stash import ObjectStash
from ...store.document_store import PartitionSettings
from ...store.document_store_errors import StashException
from ...types.result import as_result
from .sync_state import SyncState


@serializable(canonical_name="SyncStash", version=1)
class SyncStash(ObjectStash[SyncState]):
settings: PartitionSettings = PartitionSettings(
name=SyncState.__canonical_name__,
object_type=SyncState,
)

def __init__(self, store: DBManager) -> None:
super().__init__(store)
self.last_state: SyncState | None = None
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/service/worker/worker_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
from ...store.db.db import DBManager
from ...store.document_store import SyftSuccess
from ...store.document_store_errors import StashException
from ...types.errors import SyftException
from ...types.result import as_result
from ...types.uid import UID
from ..response import SyftSuccess
from ..service import AbstractService
from ..service import AuthedServiceContext
from ..service import service_method
Expand Down
Loading

0 comments on commit e429570

Please sign in to comment.