Lower cummin op #8565
base: master
@@ -172,6 +172,7 @@ supported:
 - count_nonzero.dim_IntList
 - cross
 - cummax
+- cummin
 - cumprod
 - cumsum
 - detach_copy
@@ -86,7 +86,6 @@ def get_allowed_ops_map(
     AllowedOpInfoEntry('complex'),
     AllowedOpInfoEntry('copysign'),
     AllowedOpInfoEntry('cross'),
-    AllowedOpInfoEntry('cummin'),
     AllowedOpInfoEntry('deg2rad'),
     AllowedOpInfoEntry('div', 'no_rounding_mode'),
     AllowedOpInfoEntry('div', 'trunc_rounding'),
@@ -289,6 +288,7 @@ def get_allowed_ops_map(
     # AllowedOpInfoEntry('cosh'),
     # AllowedOpInfoEntry('cov'),
     # AllowedOpInfoEntry('cummax'),
+    # AllowedOpInfoEntry('cummin'),
     # AllowedOpInfoEntry('cumsum'),
     # AllowedOpInfoEntry('cumprod'),
     # AllowedOpInfoEntry('diff'),

Moving the discussion about 0-sized dimensions here, so as not to flood the other thread. I'm asking for that because I think this is a BC-breaking change. Before your PR, […]. That said, since other operations (e.g. […]) […]

I am not sure; the error suggests it might have come from the XLA backend: https://github.com/openxla/xla/blob/e795171be9897e1356b1d98588e4fe784e2fc1bb/xla/service/shape_inference.cc#L202. I will need more time to figure out a fix.

Agree. It also applies to core ops like […].

Since the XLA backend doesn't work with 0-sized dimensions, maybe we could catch this while tracing (since the shapes should be static) and return two 0-sized tensors. That said, I believe this could be left for a future PR (maybe one that also solves the other ops that have the same issue).

Thanks. Is there anything else you'd want me to add in this PR?

I also think it would be nice to have at least one test on the Python side, as @yaochengji suggested.
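For concreteness, a minimal sketch reproducing the 0-sized-dimension issue discussed above (assuming a standard torch_xla setup; the exact failure would come from XLA shape inference, per the link above):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# On CPU, cummin over a 0-sized dimension returns two empty tensors.
x = torch.empty(0, 3)
cpu_values, cpu_indices = torch.cummin(x, dim=0)  # works: both shaped [0, 3]

# With this lowering, the same call on an XLA device may be rejected by the
# backend's shape inference instead of returning empty tensors.
xla_values, xla_indices = torch.cummin(x.to(device), dim=0)
```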
@@ -0,0 +1,70 @@
#include "torch_xla/csrc/ops/cummin.h" | ||
|
||
#include <torch/csrc/lazy/core/tensor_util.h> | ||
|
||
#include "torch_xla/csrc/convert_ops.h" | ||
#include "torch_xla/csrc/helpers.h" | ||
#include "torch_xla/csrc/lowering_context.h" | ||
#include "torch_xla/csrc/ops/infer_output_shape.h" | ||
#include "torch_xla/csrc/reduction.h" | ||
#include "torch_xla/csrc/shape_helper.h" | ||
#include "torch_xla/csrc/tensor_util.h" | ||
#include "torch_xla/csrc/torch_util.h" | ||
|
||
namespace torch_xla { | ||
namespace { | ||
|
||
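// Lowers cummin by pairing each element of `input` with its position along
// `dim` (an iota tensor) and running a cumulative reduction whose reducer
// keeps the minimum value together with its argmin. Values are seeded with
// the element type's maximum so any real element wins the first comparison;
// indices are seeded with zero.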
xla::XlaOp LowerCumMin(xla::XlaOp input, int64_t dim) {
  const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
  xla::XlaOp value_init_value = xla::ConstantLiteral(
      input.builder(), xla::LiteralUtil::MaxValue(input_shape.element_type()));
  xla::XlaOp index_init_value = xla::ConstantLiteral(
      input.builder(), xla::LiteralUtil::Zero(xla::PrimitiveType::S32));
  xla::XlaOp iota =
      xla::Iota(input.builder(),
                xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32,
                                          input_shape.dimensions()),
                dim);
  xla::XlaComputation reducer = XlaHelpers::CreateMinAndArgMinComputation(
      input_shape.element_type(), xla::PrimitiveType::S32);
  return BuildCumulativeComputationWithIndices(
      input, iota, dim, reducer, value_init_value, index_init_value);
}
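// Infers the node's output shape (a tuple of values and indices) by tracing
// LowerCumMin on shape-only placeholder operands.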
xla::Shape NodeOutputShape(const torch::lazy::Value& input, int64_t dim) {
  auto lower_for_shape_fn =
      [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
    xla::XlaOp values_and_indices = LowerCumMin(operands[0], dim);
    return values_and_indices;
  };
  return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
}

}  // namespace
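// num_outputs=2: cummin produces both the cumulative minimum values and the
// indices at which each minimum was attained.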
CumMin::CumMin(const torch::lazy::Value& input, int64_t dim)
    : XlaNode(
          torch::lazy::OpKind(at::aten::cummin), {input},
          [&]() { return NodeOutputShape(input, dim); },
          /*num_outputs=*/2, torch::lazy::MHash(dim)),
      dim_(dim) {}

torch::lazy::NodePtr CumMin::Clone(torch::lazy::OpList operands) const {
  return torch_xla::MakeNode<CumMin>(operands.at(0), dim_);
}

XlaOpVector CumMin::Lower(LoweringContext* loctx) const {
  xla::XlaOp input = loctx->GetOutputOp(operand(0));
  xla::XlaOp values_and_indices = LowerCumMin(input, dim_);
  return ReturnOps({xla::GetTupleElement(values_and_indices, 0),
                    xla::GetTupleElement(values_and_indices, 1)},
                   loctx);
}

std::string CumMin::ToString() const {
  std::stringstream ss;
  ss << XlaNode::ToString() << ", dim=" << dim_;
  return ss.str();
}

}  // namespace torch_xla
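For reference, the semantics this node implements match `torch.cummin`: a running minimum along `dim`, together with the index where each running minimum occurs.

```python
import torch

x = torch.tensor([3., 1., 2., 0.])
values, indices = torch.cummin(x, dim=0)
print(values)   # tensor([3., 1., 1., 0.]) -- running minimum along dim 0
print(indices)  # tensor([0, 1, 1, 3])     -- position of each running minimum
```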
@@ -0,0 +1,28 @@
#ifndef XLA_TORCH_XLA_CSRC_OPS_CUMMIN_H_
#define XLA_TORCH_XLA_CSRC_OPS_CUMMIN_H_

#include <c10/core/ScalarType.h>

#include "torch_xla/csrc/ir.h"

namespace torch_xla {

class CumMin : public XlaNode {
 public:
  CumMin(const torch::lazy::Value& input, int64_t dim);

  std::string ToString() const override;

  torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

  XlaOpVector Lower(LoweringContext* loctx) const override;

  int64_t dim() const { return dim_; }

 private:
  int64_t dim_;
};

}  // namespace torch_xla

#endif  // XLA_TORCH_XLA_CSRC_OPS_CUMMIN_H_
@zyy-martin Thanks for your contribution.
Could you add a test for the correctness of the op's execution results?
On this same note, could you add a test where `dim` is a 0-sized dimension? Since this is using the same lowering as `cummax`, it could also have the same problems. ref: #8610
It has the same problem as `cummax`, as well as older ops such as `cumprod` and `cumsum`. I assume that was the reason those tests were excluded in the first place. We could consider modifying the op tests to add an option that excludes scalars/0-sized dimensions, so the remaining cases are still covered. The C++ tests in this PR have verified correctness, though.
Right now, SampleInput includes scalars by default: https://github.com/pytorch/pytorch/blob/dddf52b1b91d473e249829dad6c705c13624b35f/torch/testing/_internal/common_methods_invocations.py#L6942. I believe we can override the sample function in our test_ops.py to exclude the zero-sized test cases. As for why the current implementation does not support scalars, I think we will need more investigation.
Although the test suite only supports 0-sized dims, we can write simple tests from scratch to compare the result of torch/xla with native PyTorch on CPU.
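A minimal sketch of such a test, assuming the standard torch_xla device helper (test-harness integration left out):

```python
import torch
import torch_xla.core.xla_model as xm

def test_cummin_matches_cpu():
    device = xm.xla_device()
    x = torch.randn(4, 5)
    # Compute on CPU (eager) and on the XLA device, then compare.
    cpu_values, cpu_indices = torch.cummin(x, dim=1)
    xla_values, xla_indices = torch.cummin(x.to(device), dim=1)
    torch.testing.assert_close(xla_values.cpu(), cpu_values)
    torch.testing.assert_close(xla_indices.cpu(), cpu_indices)
```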
I agree with @yaochengji here: please include a test for result correctness.
Otherwise, LGTM.