-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix(tensor): fix most commonly used case for slice
Signed-off-by: weiwee <wbwmat@gmail.com>
- Loading branch information
Showing
7 changed files
with
90 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,38 +1,63 @@ | ||
import torch | ||
from fate.arch.tensor import _custom_ops | ||
|
||
from ._tensor import DTensor, implements | ||
|
||
|
||
@implements(_custom_ops.slice_f) | ||
def slice_f(input: DTensor, key): | ||
if isinstance(key, list): | ||
partition_keys = [[] for _ in storage.d_axis.partitions] | ||
agg = 0 | ||
i = 0 | ||
j = 0 | ||
while j < len(key) and i < len(storage.d_axis.partitions): | ||
if key[j] >= agg and key[j] < agg + storage.d_axis.partitions[i]: | ||
partition_keys[i].append(key[j] - agg) | ||
j += 1 | ||
# 1: int slice key means slice 0 dimention | ||
if isinstance(key, int): | ||
if 0 <= key < input.shape[0]: | ||
# 1.1: slice output in one of shardings | ||
if input.shardings.shapes.axis == 0: | ||
return input.shardings.map_reduce_shard_with_stride( | ||
stride_mapper_func=lambda stride, s: [s[key - stride]] | ||
if stride <= key < stride + s.shape[0] | ||
else [], | ||
reducer_func=lambda x, y: [*x, *y], | ||
)[0] | ||
# 1.2: slice output is distributed | ||
else: | ||
agg += storage.d_axis.partitions[i] | ||
i += 1 | ||
if j != len(key): | ||
raise ValueError(f"out of bound: {key}") | ||
|
||
def mapper(ind, s): | ||
return (ind, storage_ops.slice(s, partition_keys[ind])) | ||
|
||
blocks = storage.blocks.map(mapper) | ||
size = (len(key), *storage.shape.size[1:]) | ||
d_axis = DAxis(axis=storage.d_axis.axis, partitions=[len(p) for p in partition_keys]) | ||
|
||
return DStorage( | ||
blocks, | ||
shape=Shape(size, d_axis), | ||
dtype=storage.dtype, | ||
device=storage.device, | ||
transposed=storage.transposed, | ||
) | ||
else: | ||
raise NotImplementedError(f"key {key}") | ||
return DTensor( | ||
input.shardings.map_shard(lambda s: s[key], shapes=input.shardings.shapes.squeeze((0,))) | ||
) | ||
|
||
else: | ||
raise IndexError(f"index {key} is out of bounds for dimension 0 with size {input.shape[0]}") | ||
|
||
# 2: list slice key | ||
if isinstance(key, list): | ||
for k in key: | ||
if k < 0 or k >= input.shape[0]: | ||
raise IndexError(f"index {k} is out of bounds for dimension 0 with size {input.shape[0]}") | ||
|
||
if input.shardings.shapes.axis == 0: | ||
outputs = input.shardings.map_reduce_shard_with_stride( | ||
stride_mapper_func=lambda stride, s: [ | ||
(i, s[k - stride]) for i, k in enumerate(key) if stride <= k < stride + s.shape[0] | ||
], | ||
reducer_func=lambda x, y: [*x, *y], | ||
) | ||
return torch.cat([v for _, v in sorted(outputs)]) | ||
else: | ||
return DTensor(input.shardings.map_shard(lambda s: s[key], shapes=input.shardings.shapes.squeeze((0,)))) | ||
|
||
# 3: slice key | ||
if isinstance(key, slice): | ||
start, stop, step = key.indices(input.shape[0]) | ||
indices = list(range(start, stop, step)) | ||
return slice_f(input, indices) | ||
|
||
# 4: tuple key for multi-dimensional slicing | ||
if isinstance(key, tuple): | ||
raise NotImplementedError("tuple key {key}") | ||
# result = input | ||
# for dim, k in enumerate(key): | ||
# if isinstance(k, (int, list, slice)): | ||
# ... | ||
# else: | ||
# raise NotImplementedError(f"slice_f on {key}") | ||
# return result | ||
|
||
raise NotImplementedError(f"slice_f on {key}") |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from fate.arch.tensor import _custom_ops | ||
|
||
from ._tensor import DTensor, implements | ||
|
||
|
||
@implements(_custom_ops.to_local_f) | ||
def to_local_f(input: DTensor): | ||
return input.shardings.merge() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters