Skip to content

Commit

Permalink
fix(tensor): enable true_divice
Browse files Browse the repository at this point in the history
Signed-off-by: weiwee <wbwmat@gmail.com>
  • Loading branch information
sagewe committed Jan 10, 2023
1 parent ffeadd3 commit 95bc033
Showing 1 changed file with 30 additions and 16 deletions.
46 changes: 30 additions & 16 deletions python/fate/arch/tensor/storage/local/device/cpu/plain.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,21 @@ def min(self, *args, **kwargs):


def _ops_cpu_plain_unary_buildin(method, args, kwargs) -> Callable[[_TorchStorage], _TorchStorage]:
if method in {"exp", "log", "neg", "reciprocal", "square", "abs", "sum", "sqrt", "var", "std", "mean"}:
func = getattr(torch, method)
if (
func := {
"exp": torch.exp,
"log": torch.log,
"neg": torch.neg,
"reciprocal": torch.reciprocal,
"square": torch.square,
"abs": torch.abs,
"sum": torch.sum,
"sqrt": torch.sqrt,
"var": torch.var,
"std": torch.std,
"mean": torch.mean,
}.get(method)
) is not None:

def _wrap(storage: _TorchStorage) -> _TorchStorage:
output = func(storage.data, *args, **kwargs)
Expand Down Expand Up @@ -167,20 +180,21 @@ def _min(storage: _TorchStorage):


def _ops_cpu_plain_binary_buildin(method, args, kwargs) -> Callable[[Any, Any], _TorchStorage]:
if method in {
"add",
"sub",
"mul",
"div",
"pow",
"remainder",
"matmul",
"true_divide",
"maximum",
"minimum",
"truediv",
}:
func = getattr(torch, method)
if (
func := {
"add": torch.add,
"sub": torch.sub,
"mul": torch.mul,
"div": torch.div,
"pow": torch.pow,
"remainder": torch.remainder,
"matmul": torch.matmul,
"true_divide": torch.true_divide,
"truediv": torch.true_divide,
"maximum": torch.maximum,
"minimum": torch.minimum,
}.get(method)
) is not None:

def _wrap(a, b) -> _TorchStorage:
output = func(_maybe_unwrap_storage(a), _maybe_unwrap_storage(b), *args, **kwargs)
Expand Down

0 comments on commit 95bc033

Please sign in to comment.