fix(tensor): fix unary and binary ops for dtensor
Signed-off-by: weiwee <wbwmat@gmail.com>
sagewe committed Mar 13, 2023
1 parent fc12aa1 commit 19d1629
Showing 6 changed files with 203 additions and 101 deletions.
1 change: 1 addition & 0 deletions python/fate/arch/tensor/distributed/__init__.py
@@ -1,6 +1,7 @@
from ._op_binary import *
from ._ops_agg import *
from ._ops_cipher import *
from ._ops_unary import *
from ._tensor import DTensor

__all__ = ["DTensor"]
27 changes: 14 additions & 13 deletions python/fate/arch/tensor/distributed/_op_binary.py
@@ -13,37 +13,38 @@ def sub(input, other):
return _binary(input, other, torch.sub)


@implements(torch.rsub)
def rsub(input, other):
return _binary(input, other, torch.rsub)


@implements(torch.mul)
def mul(input, other):
return _binary(input, other, torch.mul)


@implements(torch.div)
def div(input, other):
return _binary(input, other, torch.div)
return _binary(input, other, torch.div, dtype_promote_to=torch.float32)


def _binary(input, other, op, swap=False):
def _binary(input, other, op, swap_operad=False, dtype_promote_to=None):
# swap input and other if input is not a DTensor
if not isinstance(input, DTensor):
return _binary(op, other, input, swap=not swap)
return _binary(other, input, op, swap_operad=not swap_operad, dtype_promote_to=dtype_promote_to)

# input and other both DStorage
# TODO: validate
if isinstance(other, DTensor):
if swap:
output_blocks = other.blocks.join(input.blocks, op)
if swap_operad:
return DTensor(other.shardings.join_shard(input.shardings, op, dtype_promote_to=dtype_promote_to))
else:
output_blocks = input.blocks.join(other.blocks, op)
output_dtype = torch.promote_types(input.dtype, other.dtype)
output_shape = torch.broadcast_shapes(input.shape, other.shape)
return DTensor(output_blocks, output_shape, input._d_axis, output_dtype, input._device)
return DTensor(input.shardings.join_shard(other.shardings, op, dtype_promote_to=dtype_promote_to))

# other is local tensor, broadcast to partitions
# TODO: validate broadcast
else:
if swap:
output_blocks = input.blocks.mapValues(lambda x: op(other, x))
if swap_operad:
return DTensor(input.shardings.map_shard(lambda x: op(other, x), dtype_promote_to=dtype_promote_to))
else:
output_blocks = input.blocks.mapValues(lambda x: op(x, other))
return DTensor(output_blocks, input.shape, input._d_axis, input._dtype, input.device)
return DTensor(input.shardings.map_shard(lambda x: op(x, other), dtype_promote_to=dtype_promote_to))
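Note on the div change: true division of integer shards produces floating-point results, so the sharding's recorded dtype has to be promoted the same way, which is what dtype_promote_to=torch.float32 accomplishes. A small FATE-free illustration with plain torch (the tensors are made up for the example):

import torch

# Two integer "shards", standing in for what one worker holds locally.
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([[2, 2, 2], [2, 2, 2]])

# torch.div performs true division, so integer inputs produce float32 output.
out = torch.div(a, b)
print(out.dtype)  # torch.float32

# join_shard mirrors this on the metadata side: promote the operand dtypes,
# then promote once more with dtype_promote_to.
meta = torch.promote_types(a.dtype, b.dtype)      # torch.int64
meta = torch.promote_types(meta, torch.float32)   # torch.float32
assert meta == out.dtype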
16 changes: 2 additions & 14 deletions python/fate/arch/tensor/distributed/_ops_cipher.py
@@ -5,21 +5,9 @@

@implements(_custom_ops.encrypt)
def encrypt(input: DTensor, encryptor):
return DTensor(
input.blocks.mapValues(lambda x: _custom_ops.encrypt(x, encryptor)),
input.shape,
input.d_axis,
input.dtype,
input.device,
)
return DTensor(input.shardings.map_shard(lambda x: _custom_ops.encrypt(x, encryptor), input.dtype))


@implements(_custom_ops.decrypt)
def decrypt(input: DTensor, decryptor):
return DTensor(
input.blocks.mapValues(lambda x: _custom_ops.decrypt(x, decryptor)),
input.shape,
input.d_axis,
input.dtype,
input.device,
)
return DTensor(input.shardings.map_shard(lambda x: _custom_ops.decrypt(x, decryptor), input.dtype))
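Both cipher ops now reduce to map_shard: one callable applied independently to every local shard, with the sharding metadata carried along (passing input.dtype as the second argument leaves the recorded dtype unchanged, since promoting a dtype with itself is a no-op). A toy, FATE-free sketch of that per-shard mapping; FakeEncryptor is hypothetical and only stands in for a real encryptor:

import torch

class FakeEncryptor:
    # Hypothetical stand-in: "encrypts" by adding a constant offset.
    def encrypt(self, t: torch.Tensor) -> torch.Tensor:
        return t + 1000

shards = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])]  # local shards
encryptor = FakeEncryptor()

# map_shard semantics: the same function runs on each shard independently.
encrypted = [encryptor.encrypt(s) for s in shards]
print(encrypted)  # [tensor([1001, 1002, 1003]), tensor([1004, 1005, 1006])]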
18 changes: 18 additions & 0 deletions python/fate/arch/tensor/distributed/_ops_unary.py
@@ -0,0 +1,18 @@
import torch

from ._tensor import DTensor, implements


@implements(torch.exp)
def exp(input: DTensor):
return DTensor(input.shardings.map_shard(torch.exp, dtype_promote_to=torch.float32))


@implements(torch.log)
def log(input: DTensor):
return DTensor(input.shardings.map_shard(torch.log, dtype_promote_to=torch.float32))


@implements(torch.square)
def square(input: DTensor):
return DTensor(input.shardings.map_shard(torch.square))
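The asymmetry between these handlers mirrors torch's local behavior: exp and log always produce floating-point values, so the recorded dtype is promoted with float32, while square preserves its input dtype and needs no promotion. A plain-torch check of the underlying dtype rules (no FATE objects involved):

import torch

x_int = torch.tensor([[1, 2, 3], [4, 5, 6]])   # int64 shard
x_float = x_int.to(torch.float64)

print(torch.square(x_int).dtype)   # torch.int64   -> no promotion needed
print(torch.exp(x_float).dtype)    # torch.float64 -> already floating point

# The metadata rule applied by map_shard(..., dtype_promote_to=torch.float32):
print(torch.promote_types(torch.int64, torch.float32))    # torch.float32
print(torch.promote_types(torch.float64, torch.float32))  # torch.float64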
198 changes: 124 additions & 74 deletions python/fate/arch/tensor/distributed/_tensor.py
@@ -12,10 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import List
import typing
from typing import List, Optional, cast

import torch
from fate.arch.computing import CTableABC
from fate.arch.context import Context

_HANDLED_FUNCTIONS = {}

@@ -31,27 +35,7 @@ def decorator(func):
return decorator


class DAxis:
def __init__(self, axis: int, partitions) -> None:
self.axis = axis
self.partitions = partitions

def __str__(self) -> str:
return f"DAxis<axis={self.axis}, partitions={self.partitions}>"


class DTensor:
def __init__(
self, blocks, shape: torch.Size, d_axis: DAxis, dtype: torch.dtype, device: torch.device, transposed=False
) -> None:
self.blocks = blocks
self._shape = shape
self._dtype = dtype
self._device = device
self._d_axis = d_axis
# TODO: fix me
self.transposed = transposed

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
@@ -60,41 +44,52 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
return NotImplemented
return _HANDLED_FUNCTIONS[func](*args, **kwargs)

@property
def shape(self):
return self._shape
def __init__(self, shardings: "Shardings") -> None:
self.shardings = shardings

@property
def d_axis(self) -> DAxis:
return self._d_axis
def shape(self):
return self.shardings.shape

@property
def dtype(self):
return self._dtype
return self.shardings.dtype

@property
def device(self):
return self._device

def transpose(self) -> "DTensor":
return DTensor(self.blocks, self.shape.transpose(), self.dtype, self.device, not self.transposed)
return self.shardings.device

def __eq__(self, __o: object) -> bool:
if isinstance(__o, DTensor) and self._dtype == __o.dtype and self._device == __o.device:
return torch.allclose(self.to_local(), __o.to_local())
else:
return False
return isinstance(__o, DTensor) and self.shardings == __o.shardings

def __str__(self) -> str:
return f"DStorage({self.device}, {self.dtype}, {self.shape})"

def collect(self) -> List[torch.Tensor]:
return [pair[1] for pair in sorted(self.blocks.collect())]
return f"<DTensor(shardings={self.shardings})>"

def to_local(self) -> torch.Tensor:
storages = self.collect()
storages = [pair[1] for pair in sorted(self.blocks.collect())]
return torch.cat(storages, self._d_axis.axis)

@classmethod
def from_sharding_table(
cls,
data: CTableABC,
shapes: Optional[List[torch.Size]],
axis=0,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
return DTensor(Shardings(data, shapes, axis, dtype, device))

@classmethod
def from_sharding_list(cls, ctx: Context, data: List[torch.Tensor], num_partitions=16, axis=0):
shapes = [t.shape for t in data]
# TODO: validate according to axis
dtype = data[0].dtype
device = data[0].device
return cls.from_sharding_table(
ctx.computing.parallelize(data, partition=num_partitions, include_key=False), shapes, axis, dtype, device
)

@classmethod
def from_storages(cls, ctx, storages: List[torch.Tensor], d_axis=0, partitions=4):
d_type = storages[0].dtype
@@ -108,37 +103,92 @@ def from_storages(cls, ctx, storages: List[torch.Tensor], d_axis=0, partitions=4
blocks = ctx.computing.parallelize(enumerate(storages), partition=partitions, include_key=True)
return DTensor(blocks, shape_size, DAxis(d_axis, partitions), d_type, device)

# @classmethod
# def elemwise_bc_op(
# cls,
# a: "DStorage",
# b: "DStorage",
# func: Callable[[LStorage, LStorage], LStorage],
# output_dtype=None,
# shape=None,
# **kwargs,
# ):
# # TODO: remove this
# def _apply_transpose(func, lf, rf):
# def _wrap(lblk, rblk):
# if lf:
# lblk = lblk.transpose()
# if rf:
# rblk = rblk.transpose()
# return func(lblk, rblk)

# return _wrap

# if isinstance(a, DStorage) and not isinstance(b, DStorage):
# func = _apply_transpose(func, a.transposed, False)
# output_blocks = a.blocks.mapValues(lambda x: func(x, b, **kwargs))
# elif isinstance(b, DStorage) and not isinstance(a, DStorage):
# func = _apply_transpose(func, False, b.transposed)
# output_blocks = b.blocks.mapValues(lambda x: func(a, x, **kwargs))
# else:
# raise RuntimeError("exactly one DStorage required")
# if output_dtype is None:
# output_dtype = a._dtype
# if shape is None:
# shape = a.shape
# return DStorage(output_blocks, shape, output_dtype, a._device)

class Shardings:
def __init__(
self,
data: CTableABC,
shapes: Optional[List[torch.Size]] = None,
axis: int = 0,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
self._data = data
self._axis = axis

if shapes is None:
shards_shape = sorted(self._data.map(lambda k, s: (k, s.shape)).collect())
self._shapes = []
for i, (k, s) in enumerate(shards_shape):
assert i == k
self._shapes.append(s)
else:
self._shapes = shapes

if dtype is None or device is None:
first_shard = self._data.first()
shard_dtype = cast(torch.dtype, first_shard.dtype)
shard_device = cast(torch.device, first_shard.device)
if dtype is not None:
assert dtype == shard_dtype
if device is not None:
assert device == shard_device
self._dtype = shard_dtype
self._device = shard_device
else:
self._dtype = dtype
self._device = device

@property
def shape(self):
return self._shapes[0]

@property
def dtype(self):
return self._dtype

def with_dtype(self, dtype: torch.dtype):
self._dtype = dtype
return self

@property
def device(self):
return self._device

def __eq__(self, __o: object) -> bool:
if (
isinstance(__o, Shardings)
and self.device == __o.device
and self.dtype == __o.dtype
and len(self._shapes) == len(__o._shapes)
):
for s1, s2 in zip(self._shapes, __o._shapes):
if s1 != s2:
return False
return all(self._data.join(__o._data, lambda s1, s2: torch.allclose(s1, s2)).collect())
return False

def __str__(self) -> str:
return f"Sharding<shapes={self._shapes}, dtype={self._dtype}, device={self._device}>"

def map_shard(
self, func: typing.Callable[[torch.Tensor], torch.Tensor], dtype_promote_to: Optional[torch.dtype] = None
):
if dtype_promote_to is not None:
dtype = torch.promote_types(self.dtype, dtype_promote_to)
else:
dtype = self._dtype
return Shardings(self._data.mapValues(func), self._shapes, self._axis, dtype, self._device)

def join_shard(
self,
other: "Shardings",
func: typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
out_dtype: typing.Optional[torch.dtype] = None,
dtype_promote_to: Optional[torch.dtype] = None,
):
if out_dtype is None:
out_dtype = torch.promote_types(self._dtype, other._dtype)
if dtype_promote_to is not None:
out_dtype = torch.promote_types(out_dtype, dtype_promote_to)
return Shardings(self._data.join(other._data, func), self._shapes, self._axis, out_dtype, self._device)
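The dispatch that lets torch.exp(dtensor) and the binary ops above reach these handlers is the _HANDLED_FUNCTIONS registry filled by @implements and consulted in __torch_function__. A self-contained toy version of that pattern (simplified names, not the FATE implementation), runnable with plain torch:

import torch

_HANDLED = {}

def implements(torch_function):
    # Register a handler for one torch API function, as _tensor.py does.
    def decorator(func):
        _HANDLED[torch_function] = func
        return func
    return decorator

class ToyDTensor:
    def __init__(self, shards):
        self.shards = shards  # list of local torch.Tensor pieces

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func not in _HANDLED:
            return NotImplemented
        return _HANDLED[func](*args, **kwargs)

@implements(torch.exp)
def _exp(t: "ToyDTensor"):
    # apply the op shard-by-shard, like Shardings.map_shard
    return ToyDTensor([torch.exp(s) for s in t.shards])

t = ToyDTensor([torch.tensor([0.0, 1.0]), torch.tensor([2.0])])
print(torch.exp(t).shards)  # [tensor([1.0000, 2.7183]), tensor([7.3891])]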
44 changes: 44 additions & 0 deletions python/fate/test/test_dtensor.py
@@ -0,0 +1,44 @@
import pytest
import torch
from fate.arch.computing.standalone import CSession
from fate.arch.context import Context
from fate.arch.federation.standalone import StandaloneFederation
from fate.arch.tensor import DTensor
from pytest import fixture


@fixture
def ctx():
computing = CSession()
return Context(
"guest",
computing=computing,
federation=StandaloneFederation(computing, "fed", ("guest", 10000), [("host", 9999)]),
)


@fixture
def t1_sharding():
return [
torch.tensor([[1, 2, 3], [4, 5, 6]]),
torch.tensor([[1, 2, 3], [4, 5, 6]]),
torch.tensor([[1, 2, 3], [4, 5, 6]]),
]


@fixture
def t1(ctx, t1_sharding):
return DTensor.from_sharding_list(
ctx,
t1_sharding,
num_partitions=3,
)


@pytest.mark.parametrize(
"op",
[torch.exp, torch.log, torch.square],
)
def test_unary(ctx, t1, t1_sharding, op):
assert op(t1) == DTensor.from_sharding_list(ctx, [op(s) for s in t1_sharding], num_partitions=3)
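A sketch of how a matching binary-op test might look, following the same fixture pattern; the shard-aligned elementwise semantics assumed here (and the test name) are illustrative assumptions, not something this commit's tests assert:

@pytest.mark.parametrize(
    "op",
    [torch.sub, torch.mul, torch.div],
)
def test_binary_sketch(ctx, t1, t1_sharding, op):
    # assumed: DTensor-DTensor ops combine shards pairwise, like join_shard
    assert op(t1, t1) == DTensor.from_sharding_list(
        ctx, [op(s, s) for s in t1_sharding], num_partitions=3
    )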
