Commit
fix(ops): fix matmul
Signed-off-by: weiwee <wbwmat@gmail.com>
sagewe committed Jan 5, 2023
1 parent e275ce9 commit 7ab709c
Showing 3 changed files with 80 additions and 14 deletions.
31 changes: 18 additions & 13 deletions python/fate/arch/tensor/ops/_matmul_ops.py
@@ -1,4 +1,4 @@
-from .._tensor import Tensor
+from .._tensor import DStorage, Tensor
 from ..types import Shape
 from ._ops import _get_dispatch_info, dispatch_signature2

@@ -37,19 +37,24 @@ def matmul(a: Tensor, b: Tensor) -> Tensor:
     if mul_shape_a.size[-1] != mul_shape_b.size[0]:
         raise ValueError("matmul: dimension mismatch: should be (..., n) x (...,n,?)")

-    if mul_shape_a.is_d_axis(-2):
-        raise ValueError(f"not supported distributed axis position (...,d,?) for left tensor {a}")
-    if mul_shape_b.is_d_axis(-1):
-        raise ValueError("not supported distributed axis position (...,?,d) for right tensor {b}")
+    if mul_shape_a.is_d_axis(-2) and mul_shape_b.is_d_axis(-1):
+        raise ValueError(
+            f"not supported distributed axis position (...,d,?) for left tensor {a} and distributed axis position (...,?,d) for right tensor {b}"
+        )

-    out_storage = a.storage.blocks.join(
-        b.storage.blocks,
-        apply_transpose(
-            local_ops.matmul,
-            a.storage.transposed,
-            b.storage.transposed,
-        ),
-    ).reduce(local_ops.add)
+    if mul_shape_a.is_d_axis(-2) and mul_shape_b.d_axis is None:
+        out_storage = DStorage.elemwise_bc_op(a.storage, b.storage, lambda l, r: local_ops.matmul(l, r))
+    elif mul_shape_b.is_d_axis(-1) and mul_shape_a.d_axis is None:
+        out_storage = DStorage.elemwise_bc_op(a.storage, b.storage, lambda l, r: local_ops.matmul(l, r))
+    else:
+        out_storage = a.storage.blocks.join(
+            b.storage.blocks,
+            apply_transpose(
+                local_ops.matmul,
+                a.storage.transposed,
+                b.storage.transposed,
+            ),
+        ).reduce(local_ops.add)
     return Tensor(out_storage)

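For context on why the new branches are correct, here is a minimal local sketch in plain torch; the block counts and shapes are illustrative and not taken from the diff. When only the left operand is partitioned along its row axis, or only the right operand along its column axis, each block can be multiplied against the other, whole operand and the results concatenated. When both operands are partitioned along the contracted dimension, the per-block partial products must be summed, which is what the join(...).reduce(local_ops.add) branch expresses.

```python
# Illustrative only: plain-torch sketch of the three partition cases.
import torch

a = torch.randn(4, 6)
b = torch.randn(6, 5)
expected = a @ b

# Case 1: left tensor split along its row axis (d-axis at -2), right tensor
# local -> multiply each row block by the full right operand.
row_blocks = torch.chunk(a, 2, dim=0)
out1 = torch.cat([blk @ b for blk in row_blocks], dim=0)

# Case 2: right tensor split along its column axis (d-axis at -1), left tensor
# local -> multiply the full left operand by each column block.
col_blocks = torch.chunk(b, 2, dim=1)
out2 = torch.cat([a @ blk for blk in col_blocks], dim=1)

# Case 3 (join/reduce branch): both tensors split along the contracted
# dimension -> sum the per-block partial products.
a_parts = torch.chunk(a, 3, dim=1)
b_parts = torch.chunk(b, 3, dim=0)
out3 = sum(ai @ bi for ai, bi in zip(a_parts, b_parts))

assert torch.allclose(out1, expected)
assert torch.allclose(out2, expected)
assert torch.allclose(out3, expected, atol=1e-5)
```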
2 changes: 1 addition & 1 deletion python/fate/arch/tensor/types/_shape.py
@@ -3,7 +3,7 @@


 class DAxis:
-    def __init__(self, axis, partitions) -> None:
+    def __init__(self, axis: int, partitions) -> None:
         self.axis = axis
         self.partitions = partitions

61 changes: 61 additions & 0 deletions python/test/test_matmul.py
@@ -0,0 +1,61 @@
+import torch
+from fate.arch import Context, tensor
+from fate.arch.computing.standalone import CSession
+from fate.arch.context import Context
+from fate.arch.federation.standalone import StandaloneFederation
+from pytest import fixture
+
+
+@fixture
+def ctx():
+    computing = CSession()
+    return Context(
+        "guest",
+        computing=computing,
+        federation=StandaloneFederation(computing, "fed", ("guest", 10000), [("host", 9999)]),
+    )
+
+
+@fixture
+def t1(ctx):
+    return tensor.distributed_tensor(
+        ctx,
+        [
+            torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
+            torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
+            torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
+        ],
+    )
+
+
+@fixture
+def t3():
+    return tensor.tensor(
+        torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]),
+    )
+
+
+@fixture
+def t2(ctx):
+    return tensor.distributed_tensor(
+        ctx,
+        [
+            torch.tensor([[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]),
+            torch.tensor([[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]),
+            torch.tensor([[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]),
+        ],
+        d_axis=1,
+    )
+
+
+@fixture
+def t4():
+    return torch.tensor(
+        [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
+    )
+
+
+def test_1(t1, t3):
+    print(t1.to_local())
+    print(t3)
+    print(tensor.matmul(t1, t3).to_local())
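As a sanity check for test_1, a local reference computation with plain torch, assuming distributed_tensor concatenates the given blocks along the default distributed axis (axis 0), so t1 corresponds to a 6x3 tensor:

```python
# Local reference for test_1 (assumption: blocks are stacked along axis 0).
import torch

t1_local = torch.cat(
    [torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])] * 3, dim=0
)  # shape (6, 3)
t3_local = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]] * 3)  # shape (3, 6)

# The distributed matmul in test_1 should agree with this local product.
print(t1_local @ t3_local)  # shape (6, 6)
```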