Lower cummax op #8491

Merged (5 commits, Jan 10, 2025)
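For context, a minimal usage sketch of what this lowering enables. The device setup via `torch_xla.core.xla_model` and the CPU comparison are assumptions about the local environment, not part of this PR; once lowered, the op records an `xla::cummax` counter instead of an `aten::` fallback, which is what the C++ test below asserts.

```python
import torch
import torch_xla.core.xla_model as xm  # assumes an installed torch_xla build

device = xm.xla_device()
x = torch.rand(4, 3, 4, device=device)

# cummax scans along `dim`, returning the running maximum and the index of
# that maximum.
values, indices = torch.cummax(x, dim=1)

# Mirrors the C++ test below: the XLA result should match the CPU reference.
ref_values, ref_indices = torch.cummax(x.cpu(), dim=1)
assert torch.allclose(values.cpu(), ref_values)
assert torch.equal(indices.cpu().long(), ref_indices.long())
```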
1 change: 1 addition & 0 deletions codegen/xla_native_functions.yaml
@@ -171,6 +171,7 @@ supported:
- count_nonzero
- count_nonzero.dim_IntList
- cross
- cummax
- cumprod
- cumsum
- detach_copy
18 changes: 18 additions & 0 deletions test/cpp/test_aten_xla_tensor_2.cpp
@@ -2,6 +2,7 @@
#include <torch/torch.h>

#include <iostream>
#include <tuple>

#include "test/cpp/cpp_test_util.h"
#include "test/cpp/torch_xla_test.h"
@@ -2116,6 +2117,23 @@ TEST_F(AtenXlaTensorTest, TestCumProdCastLong) {
}
}

TEST_F(AtenXlaTensorTest, TestCumMax) {
  torch::Tensor input = torch::rand({4, 3, 4});
  int rank = input.dim();
  for (int dim = -rank; dim < rank; ++dim) {
    std::tuple<torch::Tensor, torch::Tensor> result = torch::cummax(input, dim);
    ForEachDevice([&](const torch::Device& device) {
      torch::Tensor xla_input = CopyToDevice(input, device);
      std::tuple<torch::Tensor, torch::Tensor> xla_result =
          torch::cummax(xla_input, dim);
      AllClose(std::get<0>(result), std::get<0>(xla_result));
      AllClose(std::get<1>(result), std::get<1>(xla_result));
    });
  }
  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
  ExpectCounterChanged("xla::cummax", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestArgMin) {
  torch::Tensor a = torch::rand({4, 4, 4}, torch::TensorOptions(torch::kFloat));
  torch::Tensor b = torch::argmin(a, std::nullopt, /*keepdim=*/false);
2 changes: 1 addition & 1 deletion test/test_ops.py
@@ -86,7 +86,6 @@ def get_allowed_ops_map(
AllowedOpInfoEntry('complex'),
AllowedOpInfoEntry('copysign'),
AllowedOpInfoEntry('cross'),
AllowedOpInfoEntry('cummax'),
AllowedOpInfoEntry('cummin'),
AllowedOpInfoEntry('deg2rad'),
AllowedOpInfoEntry('div', 'no_rounding_mode'),
@@ -289,6 +288,7 @@ def get_allowed_ops_map(
# AllowedOpInfoEntry('cos'),
# AllowedOpInfoEntry('cosh'),
# AllowedOpInfoEntry('cov'),
# AllowedOpInfoEntry('cummax'),
# AllowedOpInfoEntry('cumsum'),
# AllowedOpInfoEntry('cumprod'),
# AllowedOpInfoEntry('diff'),
10 changes: 10 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1308,6 +1308,16 @@ at::Tensor XLANativeFunctions::cross(const at::Tensor& self,
XlaHelpers::I64Optional(dim)));
}

std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::cummax(
    const at::Tensor& self, int64_t dim) {
  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
  XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
  std::tuple<XLATensorPtr, XLATensorPtr> res =
      tensor_methods::cummax(self_tensor, dim);
  return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(res)),
                         bridge::AtenFromXlaTensor(std::get<1>(res)));
}

at::Tensor XLANativeFunctions::cumprod(const at::Tensor& self, int64_t dim,
                                       std::optional<at::ScalarType> dtype) {
  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
31 changes: 31 additions & 0 deletions torch_xla/csrc/helpers.cpp
@@ -44,6 +44,31 @@ xla::XlaComputation CreateComputation(
  return ConsumeValue(builder.Build(op(x, y)));
}

xla::XlaComputation CreateMinMaxComputation(const std::string& name,
                                            xla::PrimitiveType value_type,
                                            xla::PrimitiveType index_type,
                                            bool is_min) {
  xla::XlaBuilder builder(name);
  xla::XlaOp lhs_value = xla::Parameter(
      &builder, 0, xla::ShapeUtil::MakeShape(value_type, {}), "lhs_value");
  xla::XlaOp lhs_index = xla::Parameter(
      &builder, 1, xla::ShapeUtil::MakeShape(index_type, {}), "lhs_index");
  xla::XlaOp rhs_value = xla::Parameter(
      &builder, 2, xla::ShapeUtil::MakeShape(value_type, {}), "rhs_value");
  xla::XlaOp rhs_index = xla::Parameter(
      &builder, 3, xla::ShapeUtil::MakeShape(index_type, {}), "rhs_index");

  xla::XlaOp cmp =
      is_min ? xla::Le(lhs_value, rhs_value) : xla::Ge(lhs_value, rhs_value);
  xla::XlaOp max = xla::Select(cmp, lhs_value, rhs_value);
  xla::XlaOp arg_max = xla::Select(cmp, lhs_index, rhs_index);
  xla::XlaOp eq = xla::Eq(lhs_value, rhs_value);
  xla::XlaOp tie_id = xla::Min(lhs_index, rhs_index);
  arg_max = xla::Select(eq, tie_id, arg_max);
  xla::Tuple(&builder, {max, arg_max});
  return ConsumeValue(builder.Build());
}

} // namespace

xla::PrecisionConfig::Precision XlaHelpers::s_mat_mul_precision =
@@ -229,6 +254,12 @@ xla::XlaComputation XlaHelpers::CreateOrComputation(xla::PrimitiveType type) {
      [&](xla::XlaOp x, xla::XlaOp y) { return xla::Or(x, y); });
}

xla::XlaComputation XlaHelpers::CreateMaxAndArgMaxComputation(
    xla::PrimitiveType value_type, xla::PrimitiveType index_type) {
  return CreateMinMaxComputation("MaxAndArgMaxComputation", value_type,
                                 index_type, /*is_min=*/false);
}

std::vector<int64_t> XlaHelpers::SizesOfXlaOp(xla::XlaOp op) {
  const xla::Shape& op_shape = ShapeHelper::ShapeOfXlaOp(op);
  return std::vector<int64_t>(op_shape.dimensions().begin(),
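Read as a pairwise combiner, the reducer that `CreateMaxAndArgMaxComputation` builds keeps the larger value and, when the two values tie, the smaller index. A rough Python rendering of that logic (illustrative only, not part of the PR):

```python
# Illustrative rendering of the (value, index) reducer built above
# (is_min=false); the XLA version takes the four scalars as parameters 0..3.
def max_and_argmax_reducer(lhs_value, lhs_index, rhs_value, rhs_index):
    cmp = lhs_value >= rhs_value              # xla::Ge for the max variant
    value = lhs_value if cmp else rhs_value   # Select(cmp, lhs_value, rhs_value)
    index = lhs_index if cmp else rhs_index   # Select(cmp, lhs_index, rhs_index)
    if lhs_value == rhs_value:                # on ties, prefer the smaller index
        index = min(lhs_index, rhs_index)     # tie_id = Min(lhs_index, rhs_index)
    return value, index
```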
3 changes: 3 additions & 0 deletions torch_xla/csrc/helpers.h
@@ -230,6 +230,9 @@ class XlaHelpers {

  static xla::XlaComputation CreateOrComputation(xla::PrimitiveType type);

  static xla::XlaComputation CreateMaxAndArgMaxComputation(
      xla::PrimitiveType value_type, xla::PrimitiveType index_type);

  // Returns an XLA operation which is a reshape to the expected rank, by
  // appending 1s to the major dimension. If offset is greater than zero, 1s
  // will be prepended to the minor dimension as well.
70 changes: 70 additions & 0 deletions torch_xla/csrc/ops/cummax.cpp
@@ -0,0 +1,70 @@
#include "torch_xla/csrc/ops/cummax.h"

#include <torch/csrc/lazy/core/tensor_util.h>

#include "torch_xla/csrc/convert_ops.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/reduction.h"
#include "torch_xla/csrc/shape_helper.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/torch_util.h"

namespace torch_xla {
namespace {

xla::XlaOp LowerCumMax(xla::XlaOp input, int64_t dim) {
const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
xla::XlaOp value_init_value = xla::ConstantLiteral(
input.builder(), xla::LiteralUtil::MinValue(input_shape.element_type()));
xla::XlaOp index_init_value = xla::ConstantLiteral(
input.builder(), xla::LiteralUtil::Zero(xla::PrimitiveType::S32));
xla::XlaOp iota =
xla::Iota(input.builder(),
xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32,
input_shape.dimensions()),
dim);
xla::XlaComputation reducer = XlaHelpers::CreateMaxAndArgMaxComputation(
input_shape.element_type(), xla::PrimitiveType::S32);
return BuildCumulativeComputationWithIndices(
input, iota, dim, reducer, value_init_value, index_init_value);
}

xla::Shape NodeOutputShape(const torch::lazy::Value& input, int64_t dim) {
auto lower_for_shape_fn =
[&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
xla::XlaOp values_and_indices = LowerCumMax(operands[0], dim);
return values_and_indices;
};
return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
}

} // namespace

CumMax::CumMax(const torch::lazy::Value& input, int64_t dim)
: XlaNode(
torch::lazy::OpKind(at::aten::cummax), {input},
[&]() { return NodeOutputShape(input, dim); },
/*num_outputs=*/2, torch::lazy::MHash(dim)),
dim_(dim) {}

torch::lazy::NodePtr CumMax::Clone(torch::lazy::OpList operands) const {
return torch_xla::MakeNode<CumMax>(operands.at(0), dim_);
}

XlaOpVector CumMax::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
xla::XlaOp values_and_indices = LowerCumMax(input, dim_);
return ReturnOps({xla::GetTupleElement(values_and_indices, 0),
xla::GetTupleElement(values_and_indices, 1)},
loctx);
}

std::string CumMax::ToString() const {
std::stringstream ss;
ss << XlaNode::ToString() << ", dim=" << dim_;
return ss.str();
}

} // namespace torch_xla
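`LowerCumMax` expresses the scan as a reduce-window: the scanned dimension is padded on the low side with `n - 1` init elements, so a length-`n` window at stride 1 covers exactly the prefix ending at each output position. Below is a 1-D NumPy sketch of that windowing, purely illustrative (the function name is made up here, it assumes a float input, and it models the reducer's smaller-index tie-break rather than any particular ATen behavior):

```python
import numpy as np

# Illustrative 1-D model (not the actual XLA lowering) of the reduce-window
# trick in LowerCumMax: pad the scanned dimension on the low side with n - 1
# init elements so that the length-n window at output position i covers
# exactly the prefix input[0..i].
def cummax_1d_via_windows(values):
    n = len(values)
    init_val = np.finfo(values.dtype).min             # value_init_value (MinValue)
    padded_vals = np.concatenate(
        [np.full(n - 1, init_val, dtype=values.dtype), values])
    padded_idx = np.concatenate(
        [np.zeros(n - 1, dtype=np.int32),              # index_init_value (Zero)
         np.arange(n, dtype=np.int32)])                # the Iota operand
    out_vals = np.empty(n, dtype=values.dtype)
    out_idx = np.empty(n, dtype=np.int32)
    for i in range(n):
        window_vals = padded_vals[i:i + n]             # window_dims[dim] = n, stride 1
        best = int(window_vals.argmax())               # first maximum, i.e. smallest
        out_vals[i] = window_vals[best]                # index on ties, matching the
        out_idx[i] = padded_idx[i + best]              # reducer's Min tie-break
    return out_vals, out_idx

# cummax_1d_via_windows(np.array([1.0, 3.0, 2.0, 3.0]))
# -> (array([1., 3., 3., 3.]), array([0, 1, 1, 1], dtype=int32))
```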
28 changes: 28 additions & 0 deletions torch_xla/csrc/ops/cummax.h
@@ -0,0 +1,28 @@
#ifndef XLA_TORCH_XLA_CSRC_OPS_CUMMAX_H_
#define XLA_TORCH_XLA_CSRC_OPS_CUMMAX_H_

#include <c10/core/ScalarType.h>

#include "torch_xla/csrc/ir.h"

namespace torch_xla {

class CumMax : public XlaNode {
 public:
  CumMax(const torch::lazy::Value& input, int64_t dim);

  std::string ToString() const override;

  torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

  XlaOpVector Lower(LoweringContext* loctx) const override;

  int64_t dim() const { return dim_; }

 private:
  int64_t dim_;
};

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_OPS_CUMMAX_H_
16 changes: 16 additions & 0 deletions torch_xla/csrc/reduction.cpp
@@ -284,6 +284,22 @@ xla::XlaOp BuildCumulativeComputation(xla::XlaOp input, int64_t dim,
      /*base_dilations=*/{}, /*window_dilations=*/{}, padding);
}

xla::XlaOp BuildCumulativeComputationWithIndices(
    xla::XlaOp value_input, xla::XlaOp index_input, int64_t dim,
    const xla::XlaComputation& reducer, xla::XlaOp value_init,
    xla::XlaOp index_init) {
  const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(value_input);
  std::vector<int64_t> window_strides(input_shape.rank(), 1);
  std::vector<int64_t> window_dims(input_shape.rank(), 1);
  window_dims[dim] = input_shape.dimensions(dim);
  std::vector<std::pair<int64_t, int64_t>> padding(input_shape.rank());
  padding[dim].first = input_shape.dimensions(dim) - 1;
  return xla::ReduceWindowWithGeneralPadding(
      {value_input, index_input}, {value_init, index_init}, reducer,
      window_dims, window_strides,
      /*base_dilations=*/{}, /*window_dilations=*/{}, padding);
}

xla::XlaOp BuildMean(xla::XlaOp input, absl::Span<const int64_t> dimensions,
                     bool keep_reduced_dimensions) {
  return CreateSummation(input, dimensions, keep_reduced_dimensions,
8 changes: 8 additions & 0 deletions torch_xla/csrc/reduction.h
@@ -88,6 +88,14 @@ xla::XlaOp BuildCumulativeComputation(xla::XlaOp input, int64_t dim,
                                      const xla::XlaComputation& reducer,
                                      xla::XlaOp init);

// Computes the cumulative computation specified by "reducer" and "init" in the
// given dimension "dim".
// Returns a tuple XlaOp (values, indices).
xla::XlaOp BuildCumulativeComputationWithIndices(
    xla::XlaOp value_input, xla::XlaOp index_input, int64_t dim,
    const xla::XlaComputation& reducer, xla::XlaOp value_init,
    xla::XlaOp index_init);

xla::XlaOp BuildAll(xla::XlaOp input, absl::Span<const int64_t> dimensions,
                    bool keep_reduced_dimensions);
