Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
teo-milea committed Oct 4, 2024
1 parent b4e3dd2 commit 8337b60
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 40 deletions.
8 changes: 4 additions & 4 deletions packages/syft/src/syft/client/syncing.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def handle_sync_batch(
obj_diff_batch: ObjectDiffBatch,
share_private_data: dict[UID, bool],
mockify: dict[UID, bool],
) -> SyftSuccess:
) -> tuple[SyftSuccess, SyftSuccess]:
# Infer SyncDecision
sync_direction = obj_diff_batch.sync_direction
if sync_direction is None:
Expand Down Expand Up @@ -226,9 +226,9 @@ def handle_sync_batch(
src_resolved_state.add_sync_instruction(sync_instruction)
# Apply empty state to source side to signal that we are done syncing
# We also add permissions for users from the low side to mark L0 request as approved
print(len(tgt_resolved_state.new_permissions), len(src_resolved_state.new_permissions))

return tgt_client.apply_state(tgt_resolved_state), src_client.apply_state(src_resolved_state)
return tgt_client.apply_state(tgt_resolved_state), src_client.apply_state(
src_resolved_state
)


def handle_ignore_batch(
Expand Down
31 changes: 19 additions & 12 deletions packages/syft/src/syft/service/output/output_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# stdlib
from typing import ClassVar, List
from typing import ClassVar

# third party
from pydantic import model_validator
Expand All @@ -17,7 +17,7 @@
from ...types.syncable_object import SyncableSyftObject
from ...types.uid import UID
from ..action.action_object import ActionObject
from ..action.action_permissions import ActionObjectPermission, ActionObjectREAD
from ..action.action_permissions import ActionObjectREAD
from ..context import AuthedServiceContext
from ..service import AbstractService
from ..service import TYPE_TO_SERVICE
Expand Down Expand Up @@ -314,16 +314,23 @@ def get_by_output_policy_id(
)
def get(self, context: AuthedServiceContext, id: UID) -> ExecutionOutput:
return self.stash.get_by_uid(context.credentials, id).unwrap()

@service_method(
path="output.set_permission",
name="set_permission",
roles=GUEST_ROLE_LEVEL,
)
def set_permission(self, context: AuthedServiceContext, uid, credentials) -> ExecutionOutput:
exec_output = self.get(context, uid)
permissions = [ActionObjectREAD(uid=_id.id, credentials=credentials) for _id in exec_output.output_id_list]
return context.server.services.action.stash.add_permissions(permissions).unwrap()

# @service_method(
# path="output.set_permission",
# name="set_permission",
# roles=GUEST_ROLE_LEVEL,
# )
# def set_permission(
# self, context: AuthedServiceContext, uid, credentials
# ) -> ExecutionOutput:
# exec_output = self.get(context, uid)
# permissions = [
# ActionObjectREAD(uid=_id.id, credentials=credentials)
# for _id in exec_output.output_id_list
# ]
# return context.server.services.action.stash.add_permissions(
# permissions
# ).unwrap()

@service_method(path="output.get_all", name="get_all", roles=GUEST_ROLE_LEVEL)
def get_all(self, context: AuthedServiceContext) -> list[ExecutionOutput]:
Expand Down
49 changes: 31 additions & 18 deletions packages/syft/src/syft/service/sync/diff_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import logging
import operator
import textwrap
from typing import Any, Dict
from typing import Any
from typing import ClassVar
from typing import Literal
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -1594,20 +1594,22 @@ def from_batch_decision(
)
else:
new_permissions_low_side = {
diff.obj_type:
[ActionObjectPermission(
uid=diff.object_id,
permission=ActionPermission.READ,
credentials=share_to_user,
)]
diff.obj_type: [
ActionObjectPermission(
uid=diff.object_id,
permission=ActionPermission.READ,
credentials=share_to_user,
)
]
}
new_permissions_high_side = {
diff.obj_type:
[ActionObjectPermission(
uid=diff.object_id,
permission=ActionPermission.READ,
credentials=share_to_user,
)]
diff.obj_type: [
ActionObjectPermission(
uid=diff.object_id,
permission=ActionPermission.READ,
credentials=share_to_user,
)
]
}

# storage permissions
Expand Down Expand Up @@ -1677,7 +1679,10 @@ def add_sync_instruction(self, sync_instruction: SyncInstruction) -> None:
if sync_instruction.unignore:
self.unignored_batches.add(sync_instruction.batch_diff.root_id)

if diff.status == "SAME" and len(sync_instruction.new_permissions_highside) == 0:
if (
diff.status == "SAME"
and len(sync_instruction.new_permissions_highside) == 0
):
return

my_obj = diff.low_obj if self.alias == "low" else diff.high_obj
Expand Down Expand Up @@ -1708,18 +1713,26 @@ def add_sync_instruction(self, sync_instruction: SyncInstruction) -> None:
if self.alias == "low":
for obj_type in sync_instruction.new_permissions_lowside.keys():
if obj_type in self.new_permissions:
self.new_permissions[obj_type].extend(sync_instruction.new_permissions_lowside[obj_type])
self.new_permissions[obj_type].extend(
sync_instruction.new_permissions_lowside[obj_type]
)
else:
self.new_permissions[obj_type] = sync_instruction.new_permissions_lowside[obj_type]
self.new_permissions[obj_type] = (
sync_instruction.new_permissions_lowside[obj_type]
)
self.new_storage_permissions.extend(
sync_instruction.new_storage_permissions_lowside
)
elif self.alias == "high":
for obj_type in sync_instruction.new_permissions_highside.keys():
if obj_type in self.new_permissions:
self.new_permissions[obj_type].extend(sync_instruction.new_permissions_highside[obj_type])
self.new_permissions[obj_type].extend(
sync_instruction.new_permissions_highside[obj_type]
)
else:
self.new_permissions[obj_type] = sync_instruction.new_permissions_highside[obj_type]
self.new_permissions[obj_type] = (
sync_instruction.new_permissions_highside[obj_type]
)
self.new_storage_permissions.extend(
sync_instruction.new_storage_permissions_highside
)
Expand Down
8 changes: 5 additions & 3 deletions packages/syft/src/syft/service/sync/resolve_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,9 @@ def deny_and_ignore(self, reason: str) -> None:
assert request is not None # nosec: B101
request.deny(reason)

def click_sync(self, *args: list, **kwargs: dict) -> SyftSuccess:
def click_sync(
self, *args: list, **kwargs: dict
) -> tuple[SyftSuccess, SyftSuccess]:
# relative
from ...client.syncing import handle_sync_batch

Expand All @@ -499,7 +501,7 @@ def click_sync(self, *args: list, **kwargs: dict) -> SyftSuccess:
public_message="The changes in this widget have already been synced."
)

res1, res2 = handle_sync_batch(
res1, res2 = handle_sync_batch(
obj_diff_batch=self.obj_diff_batch,
share_private_data=self.get_share_private_data_state(),
mockify=self.get_mockify_state(),
Expand Down Expand Up @@ -830,7 +832,7 @@ def on_paginate(self, index: int) -> None:
def build(self) -> widgets.VBox:
return widgets.VBox([self.table_output, self.paginated_widget.build()])

def click_sync(self, index: int) -> SyftSuccess:
def click_sync(self, index: int) -> tuple[SyftSuccess, SyftSuccess]:
return self.resolve_widgets[index].click_sync()

def click_share_all_private_data(self, index: int) -> None:
Expand Down
7 changes: 4 additions & 3 deletions packages/syft/src/syft/service/sync/sync_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,15 +218,16 @@ def sync_items(
self.add_storage_permissions_for_item(
context, item, new_storage_permissions
)

# If we just want to add permissions without having an object
# This should happen only for the high side when we sync results but
# This should happen only for the high side when we sync results but
# we need to add permissions for the DS to properly show the status of the requests
for obj_type, permission_list in permissions.items():
if issubclass(obj_type, ActionObject):
store = context.server.services.action.stash
else:
store = context.server.get_service(TYPE_TO_SERVICE[obj_type]).stash
service = context.server.get_service(TYPE_TO_SERVICE[obj_type])
store = service.stash # type: ignore[assignment]
for permission in permission_list:
if permission.permission == ActionPermission.READ:
store.add_permission(permission)
Expand Down

0 comments on commit 8337b60

Please sign in to comment.