diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index 06dbab66fc7d..6fccfe328004 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -714,7 +714,7 @@ def test_xla_sharded_hlo_dump(self):
                               partition_spec)
     xst2 = xst1 + 5
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xst2.global_tensor])
-    self.assertIn('%p1.4 = f32[1,8]{1,0} parameter(1), sharding', hlo)
+    self.assertIn('%p1.3 = f32[1,8]{1,0} parameter(1), sharding', hlo)
     # scalar 5 should be implicitly replicated, so the pre-optimization HLO
     # shouldn't mark it with sharding.
     self.assertNotIn('%p0.2 = f32[] parameter(0), sharding={replicated}', hlo)
@@ -831,7 +831,7 @@ def test_mark_sharding_ir(self):
     actual += 0
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor])
     self.assertIn(
-        '%add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.10, f32[1,128]{1,0} %broadcast.11)',
+        '%add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.9, f32[1,128]{1,0} %broadcast.11)',
         hlo)
     self.assertTrue(torch.allclose(expected, actual.cpu()))
 
diff --git a/torch_xla/csrc/elementwise.cpp b/torch_xla/csrc/elementwise.cpp
index 24beecc2d2c8..f3320eef49f7 100644
--- a/torch_xla/csrc/elementwise.cpp
+++ b/torch_xla/csrc/elementwise.cpp
@@ -500,4 +500,28 @@ xla::XlaOp BuildSub(xla::XlaOp input, xla::XlaOp other, xla::XlaOp alpha) {
   return sub_result;
 }
 
+xla::XlaOp BuildAdd(xla::XlaOp input, xla::XlaOp other, xla::XlaOp alpha) {
+  // Three-way shape and value promotion
+  std::tie(input, other) = XlaHelpers::Promote(input, other);
+  std::tie(input, alpha) = XlaHelpers::Promote(input, alpha);
+  std::tie(input, other) = XlaHelpers::Promote(input, other);
+
+  xla::XlaOp multiplied =
+      xla::Mul(other, alpha, XlaHelpers::getBroadcastDimensions(other, alpha));
+  xla::XlaOp add_result = xla::Add(
+      input, multiplied, XlaHelpers::getBroadcastDimensions(input, multiplied));
+
+  return add_result;
+}
+
+xla::XlaOp BuildMul(xla::XlaOp input, xla::XlaOp other) {
+  // Shape and value promotion
+  std::tie(input, other) = XlaHelpers::Promote(input, other);
+
+  xla::XlaOp mul_result =
+      xla::Mul(input, other, XlaHelpers::getBroadcastDimensions(input, other));
+
+  return mul_result;
+}
+
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/elementwise.h b/torch_xla/csrc/elementwise.h
index 9262ea248817..0db3ffd79e9e 100644
--- a/torch_xla/csrc/elementwise.h
+++ b/torch_xla/csrc/elementwise.h
@@ -117,14 +117,22 @@ xla::XlaOp BuildEluBackward(xla::XlaOp grad_output, xla::XlaOp output,
 // based on a scalar or tensor weight and returns the resulting out tensor.
 xla::XlaOp BuildLerp(xla::XlaOp start, xla::XlaOp end, xla::XlaOp weight);
 
-// Compuate the rsub function. Subtracts input, scaled by alpha, from other.
+// Computes the rsub function. Subtracts input, scaled by alpha, from other.
 // out = other − alpha * input
 xla::XlaOp BuildRsub(xla::XlaOp input, xla::XlaOp other, xla::XlaOp alpha);
 
-// Compuate the sub function. Subtracts other, scaled by alpha, from input.
+// Computes the sub function. Subtracts other, scaled by alpha, from input.
 // out = input − alpha * other
 xla::XlaOp BuildSub(xla::XlaOp input, xla::XlaOp other, xla::XlaOp alpha);
 
+// Computes the add function. Adds other, scaled by alpha, to input.
+// out = input + alpha * other
+xla::XlaOp BuildAdd(xla::XlaOp input, xla::XlaOp other, xla::XlaOp alpha);
+
+// Computes the mul function. Multiplies input by other.
+// out = input * other
+xla::XlaOp BuildMul(xla::XlaOp input, xla::XlaOp other);
+
 }  // namespace torch_xla
 
 #endif  // XLA_TORCH_XLA_CSRC_ELEMENTWISE_H_
\ No newline at end of file
diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp
index d9b9085542a9..aef9b825e087 100644
--- a/torch_xla/csrc/ops/ops.cpp
+++ b/torch_xla/csrc/ops/ops.cpp
@@ -1017,4 +1017,55 @@ torch::lazy::NodePtr Sub(const torch::lazy::Value& input,
                          std::move(lower_fn));
 }
 
+torch::lazy::NodePtr Add(const torch::lazy::Value& input,
+                         const torch::lazy::Value& other,
+                         const torch::lazy::Value& alpha) {
+  torch::lazy::ScopePusher ir_scope(at::aten::add.toQualString());
+  auto lower_fn = [](const XlaNode& node,
+                     LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
+    xla::XlaOp xla_other = loctx->GetOutputOp(node.operand(1));
+    xla::XlaOp xla_alpha = loctx->GetOutputOp(node.operand(2));
+    xla::XlaOp xla_output = BuildAdd(xla_input, xla_other, xla_alpha);
+    return node.ReturnOp(xla_output, loctx);
+  };
+  auto lower_for_shape_fn =
+      [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    XLA_CHECK_EQ(operands.size(), 3) << "Unexpected number of operands";
+    return BuildAdd(operands[0], operands[1], operands[2]);
+  };
+  return GenericOp(
+      torch::lazy::OpKind(at::aten::add), {input, other, alpha},
+      [&]() {
+        return InferOutputShape(
+            {GetXlaShape(input), GetXlaShape(other), GetXlaShape(alpha)},
+            lower_for_shape_fn);
+      },
+      std::move(lower_fn));
+}
+
+torch::lazy::NodePtr Mul(const torch::lazy::Value& input,
+                         const torch::lazy::Value& other) {
+  torch::lazy::ScopePusher ir_scope(at::aten::mul.toQualString());
+  auto lower_fn = [](const XlaNode& node,
+                     LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
+    xla::XlaOp xla_other = loctx->GetOutputOp(node.operand(1));
+    xla::XlaOp xla_output = BuildMul(xla_input, xla_other);
+    return node.ReturnOp(xla_output, loctx);
+  };
+  auto lower_for_shape_fn =
+      [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    XLA_CHECK_EQ(operands.size(), 2) << "Unexpected number of operands";
+    return BuildMul(operands[0], operands[1]);
+  };
+  return GenericOp(
+      torch::lazy::OpKind(at::aten::mul), {input, other},
+      [&]() {
+        return InferOutputShape({GetXlaShape(input), GetXlaShape(other)},
+                                lower_for_shape_fn);
+      },
+      std::move(lower_fn));
+}
+
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h
index 99efd7f87e1d..adef3edc117d 100644
--- a/torch_xla/csrc/ops/ops.h
+++ b/torch_xla/csrc/ops/ops.h
@@ -253,6 +253,13 @@ torch::lazy::NodePtr Sub(const torch::lazy::Value& input,
                          const torch::lazy::Value& other,
                          const torch::lazy::Value& alpha);
 
+torch::lazy::NodePtr Add(const torch::lazy::Value& input,
+                         const torch::lazy::Value& other,
+                         const torch::lazy::Value& alpha);
+
+torch::lazy::NodePtr Mul(const torch::lazy::Value& input,
+                         const torch::lazy::Value& other);
+
 }  // namespace torch_xla
 
 #endif  // XLA_TORCH_XLA_CSRC_OPS_OPS_H_
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 4a22afefe03b..fb6d0c17e33e 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -767,8 +767,9 @@ XLATensorPtr add(const XLATensorPtr& input, const XLATensorPtr& other,
                                      sym_int_elements, logical_element_type,
                                      device);
   }
-  return input->CreateFrom(input->GetIrValue() + other->GetIrValue() * constant,
-                           logical_element_type);
+  return input->CreateFrom(
+      Add(input->GetIrValue(), other->GetIrValue(), constant),
+      logical_element_type);
 }
 
 XLATensorPtr add(const XLATensorPtr& input, const at::Scalar& other,
@@ -787,8 +788,9 @@ XLATensorPtr add(const XLATensorPtr& input, const at::Scalar& other,
       xla::ShapeUtil::MakeScalarShape(
           MakeXlaPrimitiveType(input->dtype(), &device)),
       logical_element_type, device);
+
   return input->CreateFrom(
-      input->GetIrValue() + other_constant * alpha_constant,
+      Add(input->GetIrValue(), other_constant, alpha_constant),
       logical_element_type);
 }
 
@@ -1878,7 +1880,7 @@ XLATensorPtr mse_loss_backward(const XLATensorPtr& grad_output,
 
 XLATensorPtr mul(const XLATensorPtr& input, const XLATensorPtr& other,
                  c10::optional<at::ScalarType> logical_element_type) {
-  return input->CreateFrom(input->GetIrValue() * other->GetIrValue(),
+  return input->CreateFrom(Mul(input->GetIrValue(), other->GetIrValue()),
                            logical_element_type);
 }
 
@@ -1890,7 +1892,7 @@ XLATensorPtr mul(const XLATensorPtr& input, const at::Scalar& other,
       xla::ShapeUtil::MakeScalarShape(
          MakeXlaPrimitiveType(input->dtype(), &device)),
      logical_element_type, device);
-  return input->CreateFrom(input->GetIrValue() * constant,
+  return input->CreateFrom(Mul(input->GetIrValue(), constant),
                            logical_element_type);
 }
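
For reviewers, a minimal sketch (not part of the patch) of the path these changes exercise: torch.add with an alpha scalar now lowers through the dedicated Add IR node (and thus BuildAdd) instead of the composed input + other * constant, consistent with the renumbered HLO assertions in the updated tests above. The device setup and shapes below are illustrative assumptions.

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()  # assumes an XLA device is available
    t1 = torch.randn(1, 128, device=device)
    t2 = torch.randn(1, 128, device=device)

    # add-with-alpha should now route through the new Add IR node (BuildAdd)
    out = torch.add(t1, t2, alpha=2.0)

    # Dump the pre-optimization HLO, as the updated tests do.
    print(torch_xla._XLAC._get_xla_tensors_hlo([out]))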