Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
Signed-off-by: weiwee <wbwmat@gmail.com>
  • Loading branch information
sagewe committed Aug 9, 2023
1 parent 70a2158 commit 14aa21b
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
4 changes: 2 additions & 2 deletions python/fate/arch/tensor/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ def decrypt_f(tensor, decryptor):
# torch tensor-like
if hasattr(tensor, "__torch_function__"):
return tensor.__torch_function__(decrypt_f, (type(tensor),), (tensor, decryptor), None)
raise NotImplementedError("")
raise NotImplementedError(f"{type(tensor)}")


def decode_f(tensor, coder):
    """Decode *tensor* using *coder* via the torch dispatch protocol.

    Delegates to the tensor's ``__torch_function__`` hook when present so
    tensor-like wrapper types can supply their own decode implementation.

    Raises:
        NotImplementedError: if *tensor* does not implement
            ``__torch_function__``; the message names the offending type
            so the failure is diagnosable.
    """
    if hasattr(tensor, "__torch_function__"):
        return tensor.__torch_function__(decode_f, (type(tensor),), (tensor, coder), None)
    # Diff residue removed here: the old bare `NotImplementedError("")` line
    # made the informative raise below unreachable — keep only the fixed one.
    raise NotImplementedError(f"{type(tensor)}")


def rmatmul_f(input, other):
Expand Down
2 changes: 1 addition & 1 deletion python/fate/arch/tensor/phe/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def encrypt_encoded(input, encryptor):
return encryptor.encrypt_encoded(input)


@implements(_custom_ops.decrypt_encoded_f)
def decrypt_encoded(input, decryptor):
    """Register the PHE implementation of ``decrypt_encoded_f``.

    Delegates to *decryptor*'s own ``decrypt_encoded`` for the given
    encoded *input*. Only the corrected ``@implements`` decorator is kept;
    the stale ``@implements_encoded`` line was diff residue from the old
    revision and would raise a NameError if applied alongside it.
    """
    return decryptor.decrypt_encoded(input)

Expand Down
7 changes: 7 additions & 0 deletions python/fate/test/test_dtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ def test_unary(ctx, t1_f32, t1_f32_sharding, op):
assert op(t1_f32) == DTensor.from_sharding_list(ctx, [op(s) for s in t1_f32_sharding], num_partitions=3)


def test_cipher(ctx, t1_f32):
    """Round-trip a tensor through PHE encrypt/decrypt via the context's cipher kit."""
    # NOTE(review): this test only prints the decrypted result — it would be
    # stronger with an assertion that the round trip recovers t1_f32; left
    # as-is here because PHE encoding precision semantics aren't visible.
    phe_kit = ctx.cipher.phe.setup({})
    tensor_encryptor = phe_kit.get_tensor_encryptor()
    tensor_decryptor = phe_kit.get_tensor_decryptor()
    ciphertext = tensor_encryptor.encrypt_tensor(t1_f32)
    recovered = tensor_decryptor.decrypt_tensor(ciphertext)
    print(torch.to_local_f(recovered))


@pytest.mark.parametrize(
"op",
[torch.add, torch.sub, torch.mul, torch.div, torch.rsub],
Expand Down

0 comments on commit 14aa21b

Please sign in to comment.