diff --git a/python/fate/arch/tensor/distributed/_ops_binary.py b/python/fate/arch/tensor/distributed/_ops_binary.py index c6370fa85d..fe7b143d53 100644 --- a/python/fate/arch/tensor/distributed/_ops_binary.py +++ b/python/fate/arch/tensor/distributed/_ops_binary.py @@ -50,9 +50,10 @@ def _binary(input, other, op, swap_operad=False, dtype_promote_to=None): if swap_operad: return DTensor( input.shardings.map_shard( - lambda x: op(other, x, dtype_promote_to=dtype_promote_to), shapes=shapes.shapes, axis=shapes.axis + lambda x: op(other, x), dtype_promote_to=dtype_promote_to, shapes=shapes.shapes, axis=shapes.axis ) ) + else: return DTensor( input.shardings.map_shard(