Properly lower add and mul
wonjoolee95 committed Mar 13, 2024
1 parent 50ae882 commit 997c79a
Showing 5 changed files with 155 additions and 63 deletions.
24 changes: 24 additions & 0 deletions torch_xla/csrc/elementwise.cpp
@@ -500,4 +500,28 @@ xla::XlaOp BuildSub(xla::XlaOp input, xla::XlaOp other, xla::XlaOp alpha) {
return sub_result;
}

xla::XlaOp BuildAdd(xla::XlaOp input, xla::XlaOp other, xla::XlaOp alpha) {
// Three-way shape and value promotion. The second input/other promotion
// handles the case where promoting against alpha changed input's type.
std::tie(input, other) = XlaHelpers::Promote(input, other);
std::tie(input, alpha) = XlaHelpers::Promote(input, alpha);
std::tie(input, other) = XlaHelpers::Promote(input, other);

xla::XlaOp multiplied =
xla::Mul(other, alpha, XlaHelpers::getBroadcastDimensions(other, alpha));
xla::XlaOp add_result = xla::Add(
input, multiplied, XlaHelpers::getBroadcastDimensions(input, multiplied));

return add_result;
}

xla::XlaOp BuildMul(xla::XlaOp input, xla::XlaOp other) {
// Shape and value promotion
std::tie(input, other) = XlaHelpers::Promote(input, other);

xla::XlaOp mul_result =
xla::Mul(input, other, XlaHelpers::getBroadcastDimensions(input, other));

return mul_result;
}

} // namespace torch_xla
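
A minimal Python-side sketch of the semantics these two builders implement, assuming a torch_xla installation with an XLA device configured (the tensor values are illustrative only):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()  # assumes an XLA backend (CPU/GPU/TPU) is available

a = torch.tensor([1.0, 2.0], device=device)
b = torch.tensor([10.0, 20.0], device=device)

# BuildAdd computes out = input + alpha * other.
out_add = torch.add(a, b, alpha=2.0)  # -> [21.0, 42.0]

# BuildMul computes out = input * other.
out_mul = torch.mul(a, b)             # -> [10.0, 40.0]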
8 changes: 8 additions & 0 deletions torch_xla/csrc/elementwise.h
@@ -125,6 +125,14 @@ xla::XlaOp BuildRsub(xla::XlaOp input, xla::XlaOp other, xla::XlaOp alpha);
// out = input - alpha * other
xla::XlaOp BuildSub(xla::XlaOp input, xla::XlaOp other, xla::XlaOp alpha);

// Computes the add function. Adds other, scaled by alpha, to input.
// out = input + alpha * other
xla::XlaOp BuildAdd(xla::XlaOp input, xla::XlaOp other, xla::XlaOp alpha);

// Computes the mul function.
// out = input * other
xla::XlaOp BuildMul(xla::XlaOp input, xla::XlaOp other);

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_ELEMENTWISE_H_
167 changes: 109 additions & 58 deletions torch_xla/csrc/ops/ops.cpp
@@ -70,13 +70,13 @@ namespace torch_xla {
promoted.first, promoted.second)), \
loctx); \
}; \
return GenericOp( \
torch::lazy::OpKind(sym), {input0, input1}, \
[&]() { \
return InferOutputShape({GetXlaShape(input0), GetXlaShape(input1)}, \
shape_fn); \
}, \
std::move(lower_fn)); \
return GenericOp(torch::lazy::OpKind(sym), {input0, input1}, \
[&]() { \
return InferOutputShape( \
{GetXlaShape(input0), GetXlaShape(input1)}, \
shape_fn); \
}, \
std::move(lower_fn)); \
}

PTXLA_UNARY_OP(Neg, at::aten::neg, xla::Neg);
@@ -331,13 +331,13 @@ torch::lazy::NodePtr Dot(const torch::lazy::Value& input,
[](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
return BuildDot(operands[0], operands[1]);
};
return GenericOp(
torch::lazy::OpKind(at::aten::mm), {input, weight},
[&]() {
return InferOutputShape({GetXlaShape(input), GetXlaShape(weight)},
lower_for_shape_fn);
},
std::move(lower_fn));
return GenericOp(torch::lazy::OpKind(at::aten::mm), {input, weight},
[&]() {
return InferOutputShape(
{GetXlaShape(input), GetXlaShape(weight)},
lower_for_shape_fn);
},
std::move(lower_fn));
}

torch::lazy::NodePtr MatMul(const torch::lazy::Value& lhs,
@@ -354,13 +354,13 @@ torch::lazy::NodePtr MatMul(const torch::lazy::Value& lhs,
[](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
return CreateMatMul(operands[0], operands[1]);
};
return GenericOp(
torch::lazy::OpKind(at::aten::matmul), {lhs, rhs},
[&]() {
return InferOutputShape({GetXlaShape(lhs), GetXlaShape(rhs)},
lower_for_shape_fn);
},
std::move(lower_fn));
return GenericOp(torch::lazy::OpKind(at::aten::matmul), {lhs, rhs},
[&]() {
return InferOutputShape(
{GetXlaShape(lhs), GetXlaShape(rhs)},
lower_for_shape_fn);
},
std::move(lower_fn));
}

torch::lazy::NodePtr AdaptiveMaxPool2dBackward(
@@ -380,14 +380,14 @@ torch::lazy::NodePtr AdaptiveMaxPool2dBackward(
/*input=*/operands[1],
/*pool_dim=*/2);
};
return GenericOp(
torch::lazy::OpKind(at::aten::adaptive_max_pool2d_backward),
{grad_output, input},
[&]() {
return InferOutputShape({GetXlaShape(grad_output), GetXlaShape(input)},
lower_for_shape_fn);
},
std::move(lower_fn));
return GenericOp(torch::lazy::OpKind(at::aten::adaptive_max_pool2d_backward),
{grad_output, input},
[&]() {
return InferOutputShape(
{GetXlaShape(grad_output), GetXlaShape(input)},
lower_for_shape_fn);
},
std::move(lower_fn));
}

torch::lazy::NodePtr ComparisonOp(c10::Symbol kind,
@@ -404,13 +404,13 @@ torch::lazy::NodePtr ComparisonOp(c10::Symbol kind,
[kind](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
return BuildComparisonOp(kind, operands[0], operands[1]);
};
return GenericOp(
torch::lazy::OpKind(kind), {input, other},
[&]() {
return InferOutputShape({GetXlaShape(input), GetXlaShape(other)},
lower_for_shape_fn);
},
std::move(lower_fn));
return GenericOp(torch::lazy::OpKind(kind), {input, other},
[&]() {
return InferOutputShape(
{GetXlaShape(input), GetXlaShape(other)},
lower_for_shape_fn);
},
std::move(lower_fn));
}

torch::lazy::NodePtr Where(const torch::lazy::Value& condition,
@@ -584,13 +584,13 @@ torch::lazy::NodePtr Pdist_forward(const torch::lazy::Value& input,
torch::lazy::NodePtr tmp = input - torch::lazy::MakeNode<Unsqueeze>(input, 1);
torch::lazy::NodePtr result_matrix = Norm(tmp, p, dtype, {2}, false);

return GenericOp(
torch::lazy::OpKind(at::aten::_pdist_forward), {result_matrix},
[&]() {
return InferOutputShape({GetXlaShape(result_matrix)},
lower_for_shape_fn);
},
std::move(lower_fn), 1);
return GenericOp(torch::lazy::OpKind(at::aten::_pdist_forward),
{result_matrix},
[&]() {
return InferOutputShape({GetXlaShape(result_matrix)},
lower_for_shape_fn);
},
std::move(lower_fn), 1);
}

torch::lazy::NodePtr LinalgVectorNorm(const torch::lazy::Value& input,
@@ -845,13 +845,13 @@ torch::lazy::NodePtr XLogY(const torch::lazy::Value& input,
XLA_CHECK_EQ(operands.size(), 2) << "Unexpected number of operands";
return BuildXLogY(operands[0], operands[1]);
};
return GenericOp(
torch::lazy::OpKind(at::aten::xlogy), {input, other},
[&]() {
return InferOutputShape({GetXlaShape(input), GetXlaShape(other)},
lower_for_shape_fn);
},
std::move(lower_fn));
return GenericOp(torch::lazy::OpKind(at::aten::xlogy), {input, other},
[&]() {
return InferOutputShape(
{GetXlaShape(input), GetXlaShape(other)},
lower_for_shape_fn);
},
std::move(lower_fn));
}

torch::lazy::NodePtr NanToNum(const torch::lazy::Value& input,
@@ -890,12 +890,12 @@ torch::lazy::NodePtr SLogDet(const torch::lazy::Value& input) {
return xla::Tuple(operands[0].builder(), {result.sign, result.logdet});
};

return GenericOp(
torch::lazy::OpKind(at::aten::slogdet), {input},
[&]() {
return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
},
std::move(lower_fn), /*num_outputs=*/2);
return GenericOp(torch::lazy::OpKind(at::aten::slogdet), {input},
[&]() {
return InferOutputShape({GetXlaShape(input)},
lower_for_shape_fn);
},
std::move(lower_fn), /*num_outputs=*/2);
}

torch::lazy::NodePtr Softplus(const torch::lazy::Value& input,
@@ -991,8 +991,8 @@ torch::lazy::NodePtr Rsub(const torch::lazy::Value& input,
}

torch::lazy::NodePtr Sub(const torch::lazy::Value& input,
const torch::lazy::Value& other,
const torch::lazy::Value& alpha) {
const torch::lazy::Value& other,
const torch::lazy::Value& alpha) {
torch::lazy::ScopePusher ir_scope(at::aten::sub.toQualString());
auto lower_fn = [](const XlaNode& node,
LoweringContext* loctx) -> XlaOpVector {
@@ -1017,4 +1017,55 @@ torch::lazy::NodePtr Sub(const torch::lazy::Value& input,
std::move(lower_fn));
}

torch::lazy::NodePtr Add(const torch::lazy::Value& input,
const torch::lazy::Value& other,
const torch::lazy::Value& alpha) {
torch::lazy::ScopePusher ir_scope(at::aten::add.toQualString());
auto lower_fn = [](const XlaNode& node,
LoweringContext* loctx) -> XlaOpVector {
xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
xla::XlaOp xla_other = loctx->GetOutputOp(node.operand(1));
xla::XlaOp xla_alpha = loctx->GetOutputOp(node.operand(2));
xla::XlaOp xla_output = BuildAdd(xla_input, xla_other, xla_alpha);
return node.ReturnOp(xla_output, loctx);
};
auto lower_for_shape_fn =
[](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
XLA_CHECK_EQ(operands.size(), 3) << "Unexpected number of operands";
return BuildAdd(operands[0], operands[1], operands[2]);
};
return GenericOp(
torch::lazy::OpKind(at::aten::add), {input, other, alpha},
[&]() {
return InferOutputShape(
{GetXlaShape(input), GetXlaShape(other), GetXlaShape(alpha)},
lower_for_shape_fn);
},
std::move(lower_fn));
}

torch::lazy::NodePtr Mul(const torch::lazy::Value& input,
const torch::lazy::Value& other) {
torch::lazy::ScopePusher ir_scope(at::aten::mul.toQualString());
auto lower_fn = [](const XlaNode& node,
LoweringContext* loctx) -> XlaOpVector {
xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
xla::XlaOp xla_other = loctx->GetOutputOp(node.operand(1));
xla::XlaOp xla_output = BuildMul(xla_input, xla_other);
return node.ReturnOp(xla_output, loctx);
};
auto lower_for_shape_fn =
[](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
XLA_CHECK_EQ(operands.size(), 2) << "Unexpected number of operands";
return BuildMul(operands[0], operands[1]);
};
return GenericOp(torch::lazy::OpKind(at::aten::mul), {input, other},
[&]() {
return InferOutputShape(
{GetXlaShape(input), GetXlaShape(other)},
lower_for_shape_fn);
},
std::move(lower_fn));
}

} // namespace torch_xla
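
Because the new Add and Mul nodes route their operands through XlaHelpers::Promote, mixed-dtype inputs follow PyTorch's type-promotion rules before the XLA op is emitted. A sketch of the observable behavior, assuming the same device setup as above (this is standard PyTorch promotion, not an API specific to this change):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

i = torch.tensor([1, 2], dtype=torch.int32, device=device)
f = torch.tensor([0.5, 0.5], dtype=torch.float32, device=device)

# int32 and float32 operands are promoted to float32 before xla::Add runs.
out = torch.add(i, f, alpha=3)
print(out.dtype)  # torch.float32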
7 changes: 7 additions & 0 deletions torch_xla/csrc/ops/ops.h
@@ -253,6 +253,13 @@ torch::lazy::NodePtr Sub(const torch::lazy::Value& input,
const torch::lazy::Value& other,
const torch::lazy::Value& alpha);

torch::lazy::NodePtr Add(const torch::lazy::Value& input,
const torch::lazy::Value& other,
const torch::lazy::Value& alpha);

torch::lazy::NodePtr Mul(const torch::lazy::Value& input,
const torch::lazy::Value& other);

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_OPS_OPS_H_
12 changes: 7 additions & 5 deletions torch_xla/csrc/tensor_methods.cpp
@@ -767,8 +767,9 @@ XLATensorPtr add(const XLATensorPtr& input, const XLATensorPtr& other,
sym_int_elements, logical_element_type, device);
}

return input->CreateFrom(input->GetIrValue() + other->GetIrValue() * constant,
logical_element_type);
return input->CreateFrom(
Add(input->GetIrValue(), other->GetIrValue(), constant),
logical_element_type);
}

XLATensorPtr add(const XLATensorPtr& input, const at::Scalar& other,
@@ -787,8 +788,9 @@ XLATensorPtr add(const XLATensorPtr& input, const at::Scalar& other,
xla::ShapeUtil::MakeScalarShape(
MakeXlaPrimitiveType(input->dtype(), &device)),
logical_element_type, device);

return input->CreateFrom(
input->GetIrValue() + other_constant * alpha_constant,
Add(input->GetIrValue(), other_constant, alpha_constant),
logical_element_type);
}

@@ -1878,7 +1880,7 @@ XLATensorPtr mse_loss_backward(const XLATensorPtr& grad_output,

XLATensorPtr mul(const XLATensorPtr& input, const XLATensorPtr& other,
c10::optional<at::ScalarType> logical_element_type) {
return input->CreateFrom(input->GetIrValue() * other->GetIrValue(),
return input->CreateFrom(Mul(input->GetIrValue(), other->GetIrValue()),
logical_element_type);
}

@@ -1890,7 +1892,7 @@ XLATensorPtr mul(const XLATensorPtr& input, const at::Scalar& other,
xla::ShapeUtil::MakeScalarShape(
MakeXlaPrimitiveType(input->dtype(), &device)),
logical_element_type, device);
return input->CreateFrom(input->GetIrValue() * constant,
return input->CreateFrom(Mul(input->GetIrValue(), constant),
logical_element_type);
}

