fix(tensor): fix most commonly used case for slice
Signed-off-by: weiwee <wbwmat@gmail.com>
sagewe committed Mar 15, 2023
1 parent 2f74300 commit 0c13006
Showing 7 changed files with 90 additions and 36 deletions.
8 changes: 7 additions & 1 deletion python/fate/arch/tensor/_custom_ops.py
@@ -53,11 +53,17 @@ def slice_f(input, arg):
     else:
         # torch tensor-like
         if hasattr(input, "__torch_function__"):
-            return input.__torch_function__(slice_f, (type(input),), (input, arg), None)
+            out = input.__torch_function__(slice_f, (type(input),), (input, arg), None)
+            if out == NotImplemented:
+                raise NotImplementedError(f"slice_f: {input}")
+            return out
         raise NotImplementedError(f"slice_f: {input}")


 # hook custom ops to torch
 torch.encrypt_f = encrypt_f
 torch.decrypt_f = decrypt_f
 torch.rmatmul_f = rmatmul_f
 torch.to_local_f = to_local_f
+torch.slice_f = slice_f
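The fix above matters because the dispatch is manual: slice_f itself calls input.__torch_function__(...) and previously returned whatever came back, so a subclass that declined the op silently leaked NotImplemented to callers. A minimal, self-contained sketch of that dispatch path (the Wrapped class is hypothetical, not part of FATE; the sketch uses `is NotImplemented`, the idiomatic identity check):

import torch


def slice_f(input, arg):
    # simplified version of the dispatch path patched above
    if hasattr(input, "__torch_function__"):
        out = input.__torch_function__(slice_f, (type(input),), (input, arg), None)
        if out is NotImplemented:
            raise NotImplementedError(f"slice_f: {input}")
        return out
    raise NotImplementedError(f"slice_f: {input}")


class Wrapped:
    # hypothetical tensor-like wrapper; anything defining __torch_function__ qualifies
    def __init__(self, data):
        self.data = data

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if func is slice_f:
            tensor, key = args
            return cls(tensor.data[key])
        return NotImplemented  # an unsupported op now raises instead of leaking this value


print(slice_f(Wrapped(torch.arange(6)), 2).data)  # tensor(2)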
3 changes: 2 additions & 1 deletion python/fate/arch/tensor/distributed/__init__.py
@@ -1,5 +1,6 @@
-from ._op_binary import *
+from ._op_slice import *
 from ._ops_agg import *
 from ._ops_binary import *
 from ._ops_cipher import *
+from ._ops_others import *
 from ._ops_unary import *
8 changes: 4 additions & 4 deletions python/fate/arch/tensor/distributed/_op_matmul.py
@@ -1,10 +1,10 @@
-from fate.arch.storage import storage_ops
-from fate.arch.storage._shape import DAxis, Shape
+import torch

-from .._storage import DStorage
+from ._tensor import DTensor, implements


-def matmul(a: DStorage, b: DStorage):
+@implements(torch.matmul)
+def matmul(a: DTensor, b: DStorage):
     bc_shape_a = a.shape[:-2]
     bc_shape_b = b.shape[:-2]
     bs_shape = Shape.broadcast_shape([bc_shape_a, bc_shape_b], raise_exception=False)
85 changes: 55 additions & 30 deletions python/fate/arch/tensor/distributed/_op_slice.py
@@ -1,38 +1,63 @@
 import torch
 from fate.arch.tensor import _custom_ops

 from ._tensor import DTensor, implements


 @implements(_custom_ops.slice_f)
 def slice_f(input: DTensor, key):
-    if isinstance(key, list):
-        partition_keys = [[] for _ in storage.d_axis.partitions]
-        agg = 0
-        i = 0
-        j = 0
-        while j < len(key) and i < len(storage.d_axis.partitions):
-            if key[j] >= agg and key[j] < agg + storage.d_axis.partitions[i]:
-                partition_keys[i].append(key[j] - agg)
-                j += 1
-            else:
-                agg += storage.d_axis.partitions[i]
-                i += 1
-        if j != len(key):
-            raise ValueError(f"out of bound: {key}")
-
-        def mapper(ind, s):
-            return (ind, storage_ops.slice(s, partition_keys[ind]))
-
-        blocks = storage.blocks.map(mapper)
-        size = (len(key), *storage.shape.size[1:])
-        d_axis = DAxis(axis=storage.d_axis.axis, partitions=[len(p) for p in partition_keys])
-
-        return DStorage(
-            blocks,
-            shape=Shape(size, d_axis),
-            dtype=storage.dtype,
-            device=storage.device,
-            transposed=storage.transposed,
-        )
-    else:
-        raise NotImplementedError(f"key {key}")
+    # 1: an int key slices dimension 0
+    if isinstance(key, int):
+        if 0 <= key < input.shape[0]:
+            # 1.1: slice output is in one of the shardings
+            if input.shardings.shapes.axis == 0:
+                return input.shardings.map_reduce_shard_with_stride(
+                    stride_mapper_func=lambda stride, s: [s[key - stride]]
+                    if stride <= key < stride + s.shape[0]
+                    else [],
+                    reducer_func=lambda x, y: [*x, *y],
+                )[0]
+            # 1.2: slice output is distributed
+            else:
+                return DTensor(
+                    input.shardings.map_shard(lambda s: s[key], shapes=input.shardings.shapes.squeeze((0,)))
+                )
+        else:
+            raise IndexError(f"index {key} is out of bounds for dimension 0 with size {input.shape[0]}")
+
+    # 2: list slice key
+    if isinstance(key, list):
+        for k in key:
+            if k < 0 or k >= input.shape[0]:
+                raise IndexError(f"index {k} is out of bounds for dimension 0 with size {input.shape[0]}")
+
+        if input.shardings.shapes.axis == 0:
+            outputs = input.shardings.map_reduce_shard_with_stride(
+                stride_mapper_func=lambda stride, s: [
+                    (i, s[k - stride]) for i, k in enumerate(key) if stride <= k < stride + s.shape[0]
+                ],
+                reducer_func=lambda x, y: [*x, *y],
+            )
+            return torch.cat([v for _, v in sorted(outputs)])
+        else:
+            return DTensor(input.shardings.map_shard(lambda s: s[key], shapes=input.shardings.shapes.squeeze((0,))))
+
+    # 3: slice key
+    if isinstance(key, slice):
+        start, stop, step = key.indices(input.shape[0])
+        indices = list(range(start, stop, step))
+        return slice_f(input, indices)
+
+    # 4: tuple key for multi-dimensional slicing
+    if isinstance(key, tuple):
+        raise NotImplementedError(f"tuple key {key}")
+        # result = input
+        # for dim, k in enumerate(key):
+        #     if isinstance(k, (int, list, slice)):
+        #         ...
+        #     else:
+        #         raise NotImplementedError(f"slice_f on {key}")
+        # return result
+
+    raise NotImplementedError(f"slice_f on {key}")
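The core of the new implementation is map_reduce_shard_with_stride: each shard sees its offset (stride) along the distributed axis, emits matches for the global indices it owns, and a reduce step concatenates the per-shard results. The list-key path additionally tags each match with its position in key so the requested order survives the reduce. A local simulation of that logic with plain tensors, assuming nothing about FATE's shardings API:

import torch

# three shards of a tensor split along dim 0; strides are the running offsets
shards = [torch.rand(2, 5), torch.rand(3, 5), torch.rand(2, 5)]
full = torch.cat(shards)
strides = [0, 2, 5]  # shard i holds global rows [strides[i], strides[i] + shard rows)


def slice_list(key):
    # mapper: each shard emits (position in key, row) for the global indices it owns
    hits = []
    for stride, s in zip(strides, shards):
        hits += [(i, s[k - stride]) for i, k in enumerate(key) if stride <= k < stride + s.shape[0]]
    # reducer + sort: restore the order requested by key, then stack into one tensor
    return torch.stack([v for _, v in sorted(hits)])


assert torch.equal(slice_list([4, 0, 6]), full[[4, 0, 6]])

Note torch.stack in this sketch: for shards with more than one dimension it reproduces full[key] exactly, whereas the committed code uses torch.cat, which is presumably why the new test below prints the two shapes instead of enabling the commented-out equality assertion.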
8 changes: 8 additions & 0 deletions python/fate/arch/tensor/distributed/_ops_others.py
@@ -0,0 +1,8 @@
+from fate.arch.tensor import _custom_ops
+
+from ._tensor import DTensor, implements
+
+
+@implements(_custom_ops.to_local_f)
+def to_local_f(input: DTensor):
+    return input.shardings.merge()
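to_local_f is the gathering counterpart: it collapses a DTensor back into one local tensor by merging its shards. As a rough mental model (a sketch, not FATE's actual merge implementation), merging shards split along axis 0 is concatenation along that axis:

import torch

shards = [torch.rand(2, 5) for _ in range(3)]
local = torch.cat(shards, dim=0)  # merge along the distributed axis
assert local.shape == (6, 5)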
14 changes: 14 additions & 0 deletions python/fate/test/test_dtensor.py
@@ -187,3 +187,17 @@ def test_binary_bc_tensor(ctx, op):

     t2 = torch.rand((4, 5))
     assert op(dt1, t2) == DTensor.from_sharding_list(ctx, [op(s, t2) for s in t1], num_partitions=3)
+
+
+def test_slice(ctx):
+    t1 = [torch.rand((2, 3, 4, 5)) for _ in range(3)]
+    dt1 = DTensor.from_sharding_list(ctx, t1, num_partitions=3)
+    assert torch.allclose(torch.slice_f(dt1, 3), t1[1][1])
+
+    dt1 = DTensor.from_sharding_list(ctx, t1, num_partitions=3, axis=1)
+    assert torch.slice_f(dt1, 1) == DTensor.from_sharding_list(ctx, [s[1] for s in t1], num_partitions=3)
+
+    dt1 = DTensor.from_sharding_list(ctx, t1, num_partitions=3)
+    # assert torch.allclose(torch.slice_f(dt1, [3, 1, 2]), torch.cat(t1)[[3, 1, 2]])
+    print(torch.slice_f(dt1, [3, 1, 2]).shape)
+    print(torch.cat(t1)[[3, 1, 2]].shape)
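The first assertion in test_slice encodes the stride arithmetic: three shards of shape (2, 3, 4, 5) stacked along dim 0 give a global dim-0 size of 6, so global index 3 lands in shard 1 at local offset 3 - 2 = 1. The same expectation checked with plain torch, no DTensor involved:

import torch

t1 = [torch.rand((2, 3, 4, 5)) for _ in range(3)]
full = torch.cat(t1)  # global shape (6, 3, 4, 5)
assert torch.equal(full[3], t1[1][1])  # shard index 3 // 2 = 1, local offset 3 - 2 = 1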
