From 19d1629ead78b5036f4c869038758d2f6e6fdd5b Mon Sep 17 00:00:00 2001
From: weiwee
Date: Sun, 12 Mar 2023 18:55:55 -0800
Subject: [PATCH] fix(tensor): fix unary and binary ops for dtensor

Signed-off-by: weiwee
---
 .../fate/arch/tensor/distributed/__init__.py  |   1 +
 .../arch/tensor/distributed/_op_binary.py     |  27 +--
 .../arch/tensor/distributed/_ops_cipher.py    |  16 +-
 .../arch/tensor/distributed/_ops_unary.py     |  18 ++
 .../fate/arch/tensor/distributed/_tensor.py   | 198 +++++++++++-------
 python/fate/test/test_dtensor.py              |  44 ++++
 6 files changed, 203 insertions(+), 101 deletions(-)
 create mode 100644 python/fate/arch/tensor/distributed/_ops_unary.py
 create mode 100644 python/fate/test/test_dtensor.py

diff --git a/python/fate/arch/tensor/distributed/__init__.py b/python/fate/arch/tensor/distributed/__init__.py
index 430a3a2d57..a8cfce6e0f 100644
--- a/python/fate/arch/tensor/distributed/__init__.py
+++ b/python/fate/arch/tensor/distributed/__init__.py
@@ -1,6 +1,7 @@
 from ._op_binary import *
 from ._ops_agg import *
 from ._ops_cipher import *
+from ._ops_unary import *
 from ._tensor import DTensor
 
 __all__ = ["DTensor"]
diff --git a/python/fate/arch/tensor/distributed/_op_binary.py b/python/fate/arch/tensor/distributed/_op_binary.py
index c79116f8cd..7825b618c9 100644
--- a/python/fate/arch/tensor/distributed/_op_binary.py
+++ b/python/fate/arch/tensor/distributed/_op_binary.py
@@ -13,6 +13,11 @@ def sub(input, other):
     return _binary(input, other, torch.sub)
 
 
+@implements(torch.rsub)
+def rsub(input, other):
+    return _binary(input, other, torch.rsub)
+
+
 @implements(torch.mul)
 def mul(input, other):
     return _binary(input, other, torch.mul)
@@ -20,30 +25,26 @@ def mul(input, other):
 
 @implements(torch.div)
 def div(input, other):
-    return _binary(input, other, torch.div)
+    return _binary(input, other, torch.div, dtype_promote_to=torch.float32)
 
 
-def _binary(input, other, op, swap=False):
+def _binary(input, other, op, swap_operad=False, dtype_promote_to=None):
     # swap input and output if input is not DStroage
     if not isinstance(input, DTensor):
-        return _binary(op, other, input, swap=not swap)
+        return _binary(other, input, op, swap_operad=not swap_operad, dtype_promote_to=dtype_promote_to)
 
     # input and other both DStorage
     # TODO: validate
     if isinstance(other, DTensor):
-        if swap:
-            output_blocks = other.blocks.join(input.blocks, op)
+        if swap_operad:
+            return DTensor(other.shardings.join_shard(input.shardings, op, dtype_promote_to=dtype_promote_to))
         else:
-            output_blocks = input.blocks.join(other.blocks, op)
-        output_dtype = torch.promote_types(input.dtype, other.dtype)
-        output_shape = torch.broadcast_shapes(input.shape, other.shape)
-        return DTensor(output_blocks, output_shape, input._d_axis, output_dtype, input._device)
+            return DTensor(input.shardings.join_shard(other.shardings, op, dtype_promote_to=dtype_promote_to))
 
     # other is local tensor, broadcast to partitions
     # TODO: validate broadcast
     else:
-        if swap:
-            output_blocks = input.blocks.mapValues(lambda x: op(other, x))
+        if swap_operad:
+            return DTensor(input.shardings.map_shard(lambda x: op(other, x), dtype_promote_to=dtype_promote_to))
         else:
-            output_blocks = input.blocks.mapValues(lambda x: op(x, other))
-        return DTensor(output_blocks, input.shape, input._d_axis, input._dtype, input.device)
+            return DTensor(input.shardings.map_shard(lambda x: op(x, other), dtype_promote_to=dtype_promote_to))
diff --git a/python/fate/arch/tensor/distributed/_ops_cipher.py b/python/fate/arch/tensor/distributed/_ops_cipher.py
index 51cef479ea..a5d1673fed 100644
--- a/python/fate/arch/tensor/distributed/_ops_cipher.py
+++ b/python/fate/arch/tensor/distributed/_ops_cipher.py
@@ -5,21 +5,9 @@
 
 @implements(_custom_ops.encrypt)
 def encrypt(input: DTensor, encryptor):
-    return DTensor(
-        input.blocks.mapValues(lambda x: _custom_ops.encrypt(x, encryptor)),
-        input.shape,
-        input.d_axis,
-        input.dtype,
-        input.device,
-    )
+    return DTensor(input.shardings.map_shard(lambda x: _custom_ops.encrypt(x, encryptor), input.dtype))
 
 
 @implements(_custom_ops.decrypt)
 def decrypt(input: DTensor, decryptor):
-    return DTensor(
-        input.blocks.mapValues(lambda x: _custom_ops.decrypt(x, decryptor)),
-        input.shape,
-        input.d_axis,
-        input.dtype,
-        input.device,
-    )
+    return DTensor(input.shardings.map_shard(lambda x: _custom_ops.decrypt(x, decryptor), input.dtype))
diff --git a/python/fate/arch/tensor/distributed/_ops_unary.py b/python/fate/arch/tensor/distributed/_ops_unary.py
new file mode 100644
index 0000000000..bf191fdef0
--- /dev/null
+++ b/python/fate/arch/tensor/distributed/_ops_unary.py
@@ -0,0 +1,18 @@
+import torch
+
+from ._tensor import DTensor, implements
+
+
+@implements(torch.exp)
+def exp(input: DTensor):
+    return DTensor(input.shardings.map_shard(torch.exp, dtype_promote_to=torch.float32))
+
+
+@implements(torch.log)
+def log(input: DTensor):
+    return DTensor(input.shardings.map_shard(torch.log, dtype_promote_to=torch.float32))
+
+
+@implements(torch.square)
+def square(input: DTensor):
+    return DTensor(input.shardings.map_shard(torch.square))
diff --git a/python/fate/arch/tensor/distributed/_tensor.py b/python/fate/arch/tensor/distributed/_tensor.py
index 35df3c9e13..8f238d6ac9 100644
--- a/python/fate/arch/tensor/distributed/_tensor.py
+++ b/python/fate/arch/tensor/distributed/_tensor.py
@@ -12,10 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import functools
-from typing import List
+import typing
+from typing import List, Optional, cast
 
 import torch
+from fate.arch.computing import CTableABC
+from fate.arch.context import Context
 
 _HANDLED_FUNCTIONS = {}
 
@@ -31,27 +35,7 @@ def decorator(func):
     return decorator
 
 
-class DAxis:
-    def __init__(self, axis: int, partitions) -> None:
-        self.axis = axis
-        self.partitions = partitions
-
-    def __str__(self) -> str:
-        return f"DAxis"
-
-
 class DTensor:
-    def __init__(
-        self, blocks, shape: torch.Size, d_axis: DAxis, dtype: torch.dtype, device: torch.device, transposed=False
-    ) -> None:
-        self.blocks = blocks
-        self._shape = shape
-        self._dtype = dtype
-        self._device = device
-        self._d_axis = d_axis
-        # TODO: fix me
-        self.transposed = transposed
-
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         if kwargs is None:
@@ -60,41 +44,52 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
             return NotImplemented
         return _HANDLED_FUNCTIONS[func](*args, **kwargs)
 
-    @property
-    def shape(self):
-        return self._shape
+    def __init__(self, shardings: "Shardings") -> None:
+        self.shardings = shardings
 
     @property
-    def d_axis(self) -> DAxis:
-        return self._d_axis
+    def shape(self):
+        return self.shardings.shape
 
     @property
     def dtype(self):
-        return self._dtype
+        return self.shardings.dtype
 
     @property
     def device(self):
-        return self._device
-
-    def transpose(self) -> "DTensor":
-        return DTensor(self.blocks, self.shape.transpose(), self.dtype, self.device, not self.transposed)
+        return self.shardings.device
 
     def __eq__(self, __o: object) -> bool:
-        if isinstance(__o, DTensor) and self._dtype == __o.dtype and self._device == __o.device:
-            return torch.allclose(self.to_local(), __o.to_local())
-        else:
-            return False
+        return isinstance(__o, DTensor) and self.shardings == __o.shardings
 
     def __str__(self) -> str:
-        return f"DStorage({self.device}, {self.dtype}, {self.shape})"
-
-    def collect(self) -> List[torch.Tensor]:
-        return [pair[1] for pair in sorted(self.blocks.collect())]
+        return f"<DTensor(shardings={self.shardings})>"
 
     def to_local(self) -> torch.Tensor:
-        storages = self.collect()
-        return torch.cat(storages, self._d_axis.axis)
+        storages = [pair[1] for pair in sorted(self.shardings._data.collect())]
+        return torch.cat(storages, self.shardings._axis)
 
+    @classmethod
+    def from_sharding_table(
+        cls,
+        data: CTableABC,
+        shapes: Optional[List[torch.Size]],
+        axis=0,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+    ):
+        return DTensor(Shardings(data, shapes, axis, dtype, device))
+
+    @classmethod
+    def from_sharding_list(cls, ctx: Context, data: List[torch.Tensor], num_partitions=16, axis=0):
+        shapes = [t.shape for t in data]
+        # TODO: validate according to axis
+        dtype = data[0].dtype
+        device = data[0].device
+        return cls.from_sharding_table(
+            ctx.computing.parallelize(data, partition=num_partitions, include_key=False), shapes, axis, dtype, device
+        )
+
     @classmethod
     def from_storages(cls, ctx, storages: List[torch.Tensor], d_axis=0, partitions=4):
         d_type = storages[0].dtype
@@ -108,37 +103,92 @@ def from_storages(cls, ctx, storages: List[torch.Tensor], d_axis=0, partitions=4
         blocks = ctx.computing.parallelize(enumerate(storages), partition=partitions, include_key=True)
         return DTensor(blocks, shape_size, DAxis(d_axis, partitions), d_type, device)
 
-    # @classmethod
-    # def elemwise_bc_op(
-    #     cls,
-    #     a: "DStorage",
-    #     b: "DStorage",
-    #     func: Callable[[LStorage, LStorage], LStorage],
-    #     output_dtype=None,
-    #     shape=None,
-    #     **kwargs,
-    # ):
-    #     # TODO: remove this
-    #     def _apply_transpose(func, lf, rf):
-    #         def _wrap(lblk, rblk):
-    #             if lf:
-    #                 lblk = lblk.transpose()
-    #             if rf:
-    #                 rblk = rblk.transpose()
-    #             return func(lblk, rblk)
-
-    #         return _wrap
-
-    #     if isinstance(a, DStorage) and not isinstance(b, DStorage):
-    #         func = _apply_transpose(func, a.transposed, False)
-    #         output_blocks = a.blocks.mapValues(lambda x: func(x, b, **kwargs))
-    #     elif isinstance(b, DStorage) and not isinstance(a, DStorage):
-    #         func = _apply_transpose(func, False, b.transposed)
-    #         output_blocks = b.blocks.mapValues(lambda x: func(a, x, **kwargs))
-    #     else:
-    #         raise RuntimeError("exactly one DStorage required")
-    #     if output_dtype is None:
-    #         output_dtype = a._dtype
-    #     if shape is None:
-    #         shape = a.shape
-    #     return DStorage(output_blocks, shape, output_dtype, a._device)
+
+class Shardings:
+    def __init__(
+        self,
+        data: CTableABC,
+        shapes: Optional[List[torch.Size]] = None,
+        axis: int = 0,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+    ):
+        self._data = data
+        self._axis = axis
+
+        if shapes is None:
+            shards_shape = sorted(self._data.map(lambda k, s: (k, s.shape)).collect())
+            self._shapes = []
+            for i, (k, s) in enumerate(shards_shape):
+                assert i == k
+                self._shapes.append(s)
+        else:
+            self._shapes = shapes
+
+        if dtype is None or device is None:
+            first_shard = self._data.first()
+            shard_dtype = cast(torch.dtype, first_shard.dtype)
+            shard_device = cast(torch.device, first_shard.device)
+            if dtype is not None:
+                assert dtype == shard_dtype
+            if device is not None:
+                assert device == shard_device
+            self._dtype = shard_dtype
+            self._device = shard_device
+        else:
+            self._dtype = dtype
+            self._device = device
+
+    @property
+    def shape(self):
+        return self._shapes[0]
+
+    @property
+    def dtype(self):
+        return self._dtype
+
+    def with_dtype(self, dtype: torch.dtype):
+        self._dtype = dtype
+        return self
+
+    @property
+    def device(self):
+        return self._device
+
+    def __eq__(self, __o: object) -> bool:
+        if (
+            isinstance(__o, Shardings)
+            and self.device == __o.device
+            and self.dtype == __o.dtype
+            and len(self._shapes) == len(__o._shapes)
+        ):
+            for s1, s2 in zip(self._shapes, __o._shapes):
+                if s1 != s2:
+                    return False
+            return all(v for _, v in self._data.join(__o._data, lambda s1, s2: torch.allclose(s1, s2)).collect())
+        return False
+
+    def __str__(self) -> str:
+        return f"Sharding"
+
+    def map_shard(
+        self, func: typing.Callable[[torch.Tensor], torch.Tensor], dtype_promote_to: Optional[torch.dtype] = None
+    ):
+        if dtype_promote_to is not None:
+            dtype = torch.promote_types(self.dtype, dtype_promote_to)
+        else:
+            dtype = self._dtype
+        return Shardings(self._data.mapValues(func), self._shapes, self._axis, dtype, self._device)
+
+    def join_shard(
+        self,
+        other: "Shardings",
+        func: typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+        out_dtype: typing.Optional[torch.dtype] = None,
+        dtype_promote_to: Optional[torch.dtype] = None,
+    ):
+        if out_dtype is None:
+            out_dtype = torch.promote_types(self._dtype, other._dtype)
+        if dtype_promote_to is not None:
+            out_dtype = torch.promote_types(out_dtype, dtype_promote_to)
+        return Shardings(self._data.join(other._data, func), self._shapes, self._axis, out_dtype, self._device)
diff --git a/python/fate/test/test_dtensor.py b/python/fate/test/test_dtensor.py
new file mode 100644
index 0000000000..2dcc9fb48d
--- /dev/null
+++ b/python/fate/test/test_dtensor.py
@@ -0,0 +1,44 @@
+import pytest
+import torch
+from fate.arch import Context
+from fate.arch.computing.standalone import CSession
+from fate.arch.context import Context
+from fate.arch.federation.standalone import StandaloneFederation
+from fate.arch.tensor import DTensor
+from pytest import fixture
+
+
+@fixture
+def ctx():
+    computing = CSession()
+    return Context(
+        "guest",
+        computing=computing,
+        federation=StandaloneFederation(computing, "fed", ("guest", 10000), [("host", 9999)]),
+    )
+
+
+@fixture
+def t1_sharding():
+    return [
+        torch.tensor([[1, 2, 3], [4, 5, 6]]),
+        torch.tensor([[1, 2, 3], [4, 5, 6]]),
+        torch.tensor([[1, 2, 3], [4, 5, 6]]),
+    ]
+
+
+@fixture
+def t1(ctx, t1_sharding):
+    return DTensor.from_sharding_list(
+        ctx,
+        t1_sharding,
+        num_partitions=3,
+    )
+
+
+@pytest.mark.parametrize(
+    "op",
+    [torch.exp, torch.log, torch.square],
+)
+def test_unary(ctx, t1, t1_sharding, op):
+    assert op(t1) == DTensor.from_sharding_list(ctx, [op(s) for s in t1_sharding], num_partitions=3)