Lower cummin op #8565

Open · wants to merge 5 commits into master
1 change: 1 addition & 0 deletions codegen/xla_native_functions.yaml
@@ -172,6 +172,7 @@ supported:
- count_nonzero.dim_IntList
- cross
- cummax
- cummin
- cumprod
- cumsum
- detach_copy
17 changes: 17 additions & 0 deletions test/cpp/test_aten_xla_tensor_2.cpp
@@ -2134,6 +2134,23 @@ TEST_F(AtenXlaTensorTest, TestCumMax) {
ExpectCounterChanged("xla::cummax", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestCumMin) {
torch::Tensor input = torch::rand({4, 3, 4});
int rank = input.dim();
for (int dim = -rank; dim < rank; ++dim) {
std::tuple<torch::Tensor, torch::Tensor> result = torch::cummin(input, dim);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
std::tuple<torch::Tensor, torch::Tensor> xla_result =
torch::cummin(xla_input, dim);
AllClose(std::get<0>(result), std::get<0>(xla_result));
AllClose(std::get<1>(result), std::get<1>(xla_result));
});
}
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::cummin", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestArgMin) {
torch::Tensor a = torch::rand({4, 4, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor b = torch::argmin(a, std::nullopt, /*keepdim=*/false);
2 changes: 1 addition & 1 deletion test/test_ops.py
@@ -86,7 +86,6 @@ def get_allowed_ops_map(
AllowedOpInfoEntry('complex'),
AllowedOpInfoEntry('copysign'),
AllowedOpInfoEntry('cross'),
AllowedOpInfoEntry('cummin'),
AllowedOpInfoEntry('deg2rad'),
AllowedOpInfoEntry('div', 'no_rounding_mode'),
AllowedOpInfoEntry('div', 'trunc_rounding'),
@@ -289,6 +288,7 @@ def get_allowed_ops_map(
# AllowedOpInfoEntry('cosh'),
# AllowedOpInfoEntry('cov'),
# AllowedOpInfoEntry('cummax'),
# AllowedOpInfoEntry('cummin'),
Collaborator:
@zyy-martin Thanks for your contribution.

Could you add a test for the correctness of the op's execution result?

Collaborator:

On this same note, could you add a test where dim is a 0-sized dimension?
Since this is using the same lowering as cummax, it could also have the same problems. ref: #8610

Contributor Author:

It has the same problem as cummax, as well as older ops such as cumprod and cumsum. I assume that is why those tests were excluded in the first place.

We could consider modifying the op tests to add an option that excludes scalar/0-sized inputs so the remaining cases are still covered. The C++ tests in this PR already verify correctness, though.

Contributor Author:

Right now SampleInput includes scalars by default here: https://github.com/pytorch/pytorch/blob/dddf52b1b91d473e249829dad6c705c13624b35f/torch/testing/_internal/common_methods_invocations.py#L6942. I believe we can override the sample function in our test_ops.py to exclude the zero-dimensional test cases. As for why the current implementation does not support scalars, I think we will need more investigation.
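
For illustration, a minimal sketch of what such an override could look like in test_ops.py; the wrapper name and the exact hook point are hypothetical, not something already in the tree:

```python
import functools


def _skip_scalar_samples(original_sample_inputs_func):
  """Wraps an OpInfo sample_inputs_func and drops 0-d (scalar) samples."""

  @functools.wraps(original_sample_inputs_func)
  def wrapped(op_info, device, dtype, requires_grad, **kwargs):
    for sample in original_sample_inputs_func(op_info, device, dtype,
                                              requires_grad, **kwargs):
      # Keep only samples whose primary input has at least one dimension.
      if sample.input.dim() > 0:
        yield sample

  return wrapped
```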

Collaborator (@yaochengji, Jan 23, 2025):

Although the test suite only supports 0-sized dims, we can write simple tests from scratch to compare the result of torch/xla with native PyTorch on CPU.
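
For example, a minimal sketch of such a test, assuming the usual torch_xla device setup (the test name and file placement are just illustrative):

```python
import torch
import torch_xla.core.xla_model as xm


def test_cummin_matches_cpu():
  device = xm.xla_device()
  cpu_input = torch.rand(4, 3, 4)
  xla_input = cpu_input.to(device)
  for dim in range(-cpu_input.dim(), cpu_input.dim()):
    cpu_values, cpu_indices = torch.cummin(cpu_input, dim)
    xla_values, xla_indices = torch.cummin(xla_input, dim)
    # Compare the XLA results against native PyTorch on CPU.
    torch.testing.assert_close(cpu_values, xla_values.cpu())
    torch.testing.assert_close(cpu_indices, xla_indices.cpu())
```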

Collaborator (@pgmoka, Jan 28, 2025):

I agree with @yaochengji here that we should include a test for result correctness.

Otherwise, LGTM.

Collaborator:

Moving the discussion about 0-sized dimensions here, so as not to flood the other thread.
Do you think it would be hard to add support for 0-sized dimensions?

I'm asking because I think this is a BC-breaking change. Before your PR, cummin would fall back to CUDA, which works with 0-sized dimensions. Now, it crashes (#8610).

That said, since other operations (e.g. cummax) also have this problem, I think we could merge this and work out a solution to this issue afterwards.

Contributor Author:

Do you think it would be hard to add support for 0-sized dimensions?

I am not sure; the error suggests it comes from the XLA backend: https://github.com/openxla/xla/blob/e795171be9897e1356b1d98588e4fe784e2fc1bb/xla/service/shape_inference.cc#L202. I will need more time to figure out a fix.

That said, since other operations (e.g. cummax) also have this problem, I think we could merge this and work out a solution to this issue afterwards.

Agreed. It also applies to core ops like cumprod/cumsum, which have been lowered since the beginning of ptxla.

Collaborator:

Since the XLA backend doesn't work with 0-sized dimensions, maybe we could catch this case while tracing (the shapes should be static) and simply return two 0-sized tensors.

That said, I believe this could be left for a future PR (perhaps one that also fixes the other ops with the same issue).
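
For reference, this is the CPU behavior such a tracing-time shortcut would need to reproduce (a sketch of the expected semantics only, not of the fix):

```python
import torch

x = torch.rand(0, 5)
values, indices = torch.cummin(x, dim=0)
# Both outputs keep the input shape, including the 0-sized dimension.
assert values.shape == (0, 5)
assert indices.shape == (0, 5)
assert indices.dtype == torch.int64
```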

Contributor Author:

Thanks. Is there anything else you'd like me to add to this PR?

Collaborator:

I also think it would be nice to have at least one test on the Python side, as @yaochengji suggested.
Other than that, LGTM.

# AllowedOpInfoEntry('cumsum'),
# AllowedOpInfoEntry('cumprod'),
# AllowedOpInfoEntry('diff'),
10 changes: 10 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1318,6 +1318,16 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::cummax(
bridge::AtenFromXlaTensor(std::get<1>(res)));
}

std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::cummin(
const at::Tensor& self, int64_t dim) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
std::tuple<XLATensorPtr, XLATensorPtr> res =
tensor_methods::cummin(self_tensor, dim);
return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(res)),
bridge::AtenFromXlaTensor(std::get<1>(res)));
}

at::Tensor XLANativeFunctions::cumprod(const at::Tensor& self, int64_t dim,
std::optional<at::ScalarType> dtype) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
6 changes: 6 additions & 0 deletions torch_xla/csrc/helpers.cpp
@@ -260,6 +260,12 @@ xla::XlaComputation XlaHelpers::CreateMaxAndArgMaxComputation(
index_type, /*is_min=*/false);
}

xla::XlaComputation XlaHelpers::CreateMinAndArgMinComputation(
xla::PrimitiveType value_type, xla::PrimitiveType index_type) {
return CreateMinMaxComputation("MinAndArgMinComputation", value_type,
index_type, /*is_min=*/true);
}

std::vector<int64_t> XlaHelpers::SizesOfXlaOp(xla::XlaOp op) {
const xla::Shape& op_shape = ShapeHelper::ShapeOfXlaOp(op);
return std::vector<int64_t>(op_shape.dimensions().begin(),
3 changes: 3 additions & 0 deletions torch_xla/csrc/helpers.h
@@ -233,6 +233,9 @@ class XlaHelpers {
static xla::XlaComputation CreateMaxAndArgMaxComputation(
xla::PrimitiveType value_type, xla::PrimitiveType index_type);

static xla::XlaComputation CreateMinAndArgMinComputation(
xla::PrimitiveType value_type, xla::PrimitiveType index_type);

// Returns an XLA operation which is a reshape to the expected rank, by
// appending 1s to the major dimension. If offset is greater than zero, 1s
// will be prepended to the minor dimension as well.
70 changes: 70 additions & 0 deletions torch_xla/csrc/ops/cummin.cpp
@@ -0,0 +1,70 @@
#include "torch_xla/csrc/ops/cummin.h"

#include <torch/csrc/lazy/core/tensor_util.h>

#include "torch_xla/csrc/convert_ops.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/reduction.h"
#include "torch_xla/csrc/shape_helper.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/torch_util.h"

namespace torch_xla {
namespace {

xla::XlaOp LowerCumMin(xla::XlaOp input, int64_t dim) {
const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
xla::XlaOp value_init_value = xla::ConstantLiteral(
input.builder(), xla::LiteralUtil::MaxValue(input_shape.element_type()));
xla::XlaOp index_init_value = xla::ConstantLiteral(
input.builder(), xla::LiteralUtil::Zero(xla::PrimitiveType::S32));
xla::XlaOp iota =
xla::Iota(input.builder(),
xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32,
input_shape.dimensions()),
dim);
xla::XlaComputation reducer = XlaHelpers::CreateMinAndArgMinComputation(
input_shape.element_type(), xla::PrimitiveType::S32);
return BuildCumulativeComputationWithIndices(
input, iota, dim, reducer, value_init_value, index_init_value);
}

xla::Shape NodeOutputShape(const torch::lazy::Value& input, int64_t dim) {
auto lower_for_shape_fn =
[&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
xla::XlaOp values_and_indices = LowerCumMin(operands[0], dim);
return values_and_indices;
};
return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
}

} // namespace

CumMin::CumMin(const torch::lazy::Value& input, int64_t dim)
: XlaNode(
torch::lazy::OpKind(at::aten::cummin), {input},
[&]() { return NodeOutputShape(input, dim); },
/*num_outputs=*/2, torch::lazy::MHash(dim)),
dim_(dim) {}

torch::lazy::NodePtr CumMin::Clone(torch::lazy::OpList operands) const {
return torch_xla::MakeNode<CumMin>(operands.at(0), dim_);
}

XlaOpVector CumMin::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
xla::XlaOp values_and_indices = LowerCumMin(input, dim_);
return ReturnOps({xla::GetTupleElement(values_and_indices, 0),
xla::GetTupleElement(values_and_indices, 1)},
loctx);
}

std::string CumMin::ToString() const {
std::stringstream ss;
ss << XlaNode::ToString() << ", dim=" << dim_;
return ss.str();
}

} // namespace torch_xla
28 changes: 28 additions & 0 deletions torch_xla/csrc/ops/cummin.h
@@ -0,0 +1,28 @@
#ifndef XLA_TORCH_XLA_CSRC_OPS_CUMMIN_H_
#define XLA_TORCH_XLA_CSRC_OPS_CUMMIN_H_

#include <c10/core/ScalarType.h>

#include "torch_xla/csrc/ir.h"

namespace torch_xla {

class CumMin : public XlaNode {
public:
CumMin(const torch::lazy::Value& input, int64_t dim);

std::string ToString() const override;

torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

XlaOpVector Lower(LoweringContext* loctx) const override;

int64_t dim() const { return dim_; }

private:
int64_t dim_;
};

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_OPS_CUMMIN_H_
20 changes: 20 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -41,6 +41,7 @@
#include "torch_xla/csrc/ops/convolution_overrideable.h"
#include "torch_xla/csrc/ops/count_nonzero.h"
#include "torch_xla/csrc/ops/cummax.h"
#include "torch_xla/csrc/ops/cummin.h"
#include "torch_xla/csrc/ops/cumprod.h"
#include "torch_xla/csrc/ops/cumsum.h"
#include "torch_xla/csrc/ops/custom_call.h"
@@ -1314,6 +1315,25 @@ std::tuple<XLATensorPtr, XLATensorPtr> cummax(const XLATensorPtr& input,
return std::make_tuple(t_value, t_index);
}

std::tuple<XLATensorPtr, XLATensorPtr> cummin(const XLATensorPtr& input,
int64_t dim) {
torch::lazy::NodePtr node = torch_xla::MakeNode<CumMin>(
input->GetIrValue(), torch::lazy::GetCanonicalDimensionIndex(
dim, input->shape().get().rank()));
XLATensorPtr t_value = input->CreateFrom(torch::lazy::Value(node, 0),
/*delay_eager_executation=*/true);
XLATensorPtr t_index =
input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long,
/*delay_eager_executation=*/true);
XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
if (graph_executor->UseEagerMode()) {
// Execute the HLO that runs `cummin` and produces both outputs in one computation.
std::vector<XLATensorPtr> tensors_to_sync = {t_value, t_index};
graph_executor->ApplyEagerSync(tensors_to_sync);
}
return std::make_tuple(t_value, t_index);
}

XLATensorPtr cumprod(const XLATensorPtr& input, int64_t dim,
std::optional<at::ScalarType> dtype) {
int64_t canonical_dim =
5 changes: 5 additions & 0 deletions torch_xla/csrc/tensor_methods.h
@@ -366,6 +366,11 @@ XLATensorPtr cross(const XLATensorPtr& input, const XLATensorPtr& other,
std::tuple<XLATensorPtr, XLATensorPtr> cummax(const XLATensorPtr& input,
int64_t dim);

// Returns a tuple of the cumulative min of elements and the corresponding
// indices of input in the given dimension.
std::tuple<XLATensorPtr, XLATensorPtr> cummin(const XLATensorPtr& input,
int64_t dim);

// Returns the cumulative product of elements of input in the given dimension.
XLATensorPtr cumprod(const XLATensorPtr& input, int64_t dim,
std::optional<at::ScalarType> dtype);