Skip to content

Commit

Permalink
Roundtrip unbounded dimension sizes between MHLO and HLO
Browse files Browse the repository at this point in the history
This adds a new capability to roundtrip unbounded (but not unranked!) tensor types between MHLO and HLO. Unbounded dimensions are modelled as dynamic dimensions of `Shape::kUnboundedSize` size. As discussed, the unbounded size is set to `std::numeric_limits<int64_t>::min()`, the same sentinel value as `mlir::ShapedType::kDynamic`.

In order to enable creation of unbounded shapes, this CL disables validation of `dense_shape_size` in `ShapeUtil::FillNewShape` and `ShapeUtil::ValidateShapeSize` if an unbounded size is found among shape's dimensions.

As discussed, it is expected that there are going to be many locations in the code which will blow up when seeing unbounded shapes. Auditing and addressing them is out of scope of this CL - in this CL, we only introduce the representational capability, as well as roundtripping between MHLO and HLO. Determining which parts of XLA will be adapted to unbounded dynamism and actually adapting them is going to be a long journey.

PiperOrigin-RevId: 580002339
  • Loading branch information
ghpvnist authored and copybara-github committed Nov 9, 2023
1 parent 2870303 commit 9b27adc
Show file tree
Hide file tree
Showing 18 changed files with 197 additions and 33 deletions.
1 change: 1 addition & 0 deletions xla/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ xla_cc_test(
":shape_util",
":test",
":xla_data_proto_cc",
"//xla:status",
"@com_google_absl//absl/hash:hash_testing",
"@tsl//tsl/platform:test_benchmark",
"@tsl//tsl/platform:test_main",
Expand Down
1 change: 1 addition & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4744,6 +4744,7 @@ xla_cc_test(
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest",
"@tsl//tsl/lib/core:status_test_util",
"@tsl//tsl/platform:statusor",
],
)

Expand Down
9 changes: 8 additions & 1 deletion xla/service/hlo_lexer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ TokKind HloLexer::LexToken() {
token_state_.token_start = current_ptr_;

int current_char = GetNextChar();
TokKind tmp;
switch (current_char) {
default:
// [a-zA-Z_]
Expand Down Expand Up @@ -132,7 +133,11 @@ TokKind HloLexer::LexToken() {
current_ptr_++;
return TokKind::kArrow;
}
return LexNumberOrPattern();
tmp = LexNumberOrPattern();
if (tmp == TokKind::kError && current_char == '?') {
return TokKind::kQuestionMark;
}
return tmp;
case '=':
return TokKind::kEqual;
case '<':
Expand Down Expand Up @@ -569,6 +574,8 @@ std::string TokKindToString(TokKind kind) {
return "kColon";
case TokKind::kAsterisk:
return "kAsterisk";
case TokKind::kQuestionMark:
return "kQuestionMark";
case TokKind::kOctothorp:
return "kOctothorp";
case TokKind::kPlus:
Expand Down
15 changes: 8 additions & 7 deletions xla/service/hlo_lexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,14 @@ enum class TokKind {
kError,

// Tokens with no info.
kEqual, // =
kComma, // ,
kColon, // :
kAsterisk, // *
kOctothorp, // #
kPlus, // +
kTilde, // ~
kEqual, // =
kComma, // ,
kColon, // :
kAsterisk, // *
kQuestionMark, // ?
kOctothorp, // #
kPlus, // +
kTilde, // ~
kLsquare,
kRsquare, // [ ]
kLbrace,
Expand Down
15 changes: 11 additions & 4 deletions xla/service/hlo_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5382,19 +5382,26 @@ bool HloParserImpl::ParseParamList() {
// dimension_sizes ::= '[' dimension_list ']'
// dimension_list
// ::= /*empty*/
// ::= dimension (',' dimension)*
// dimension
// ::= '?'
// ::= <=? int64_t
bool HloParserImpl::ParseDimensionSizes(std::vector<int64_t>* dimension_sizes,
std::vector<bool>* dynamic_dimensions) {
auto parse_and_add_item = [&]() {
int64_t i;
bool is_dynamic = false;
if (lexer_.GetKind() == TokKind::kLeq) {
if (lexer_.GetKind() == TokKind::kQuestionMark) {
i = Shape::kUnboundedSize;
is_dynamic = true;
lexer_.Lex();
}
if (!ParseInt64(&i)) {
return false;
} else {
if (lexer_.GetKind() == TokKind::kLeq) {
is_dynamic = true;
lexer_.Lex();
}
if (!ParseInt64(&i)) {
return false;
}
}
dimension_sizes->push_back(i);
dynamic_dimensions->push_back(is_dynamic);
Expand Down
11 changes: 11 additions & 0 deletions xla/service/hlo_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/service/pattern_matcher.h"
#include "xla/service/pattern_matcher_gmock.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tests/verified_hlo_module.h"
#include "xla/window_util.h"
Expand Down Expand Up @@ -4069,6 +4070,16 @@ TEST_F(HloParserTest, ParseShapeStringR2F32) {
<< "actual: " << ShapeUtil::HumanString(actual);
}

// Checks that a '?' dimension in shape text is parsed as a dynamic dimension
// whose size is Shape::kUnboundedSize, alongside an ordinary static dimension.
TEST_F(HloParserTest, ParseShapeStringUnbounded) {
  const std::string text = "f32[?,784]";
  TF_ASSERT_OK_AND_ASSIGN(Shape parsed, ParseShape(text));

  // Dimension 0 is unbounded and dynamic; dimension 1 is a static 784.
  const Shape want =
      ShapeUtil::MakeShape(F32, {Shape::kUnboundedSize, 784}, {true, false});

  ASSERT_TRUE(ShapeUtil::Equal(want, parsed))
      << "expected: " << ShapeUtil::HumanString(want)
      << "actual: " << ShapeUtil::HumanString(parsed);
}

TEST_F(HloParserTest, ParseShapeStringTupleOfArrays) {
std::string shape_string = "(f32[1572864],s8[5120,1024])";
TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
Expand Down
4 changes: 4 additions & 0 deletions xla/service/hlo_verifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ Status ShapeVerifier::Preprocess(HloInstruction* hlo) {
if (arity) {
TF_RETURN_IF_ERROR(CheckOperandCount(hlo, *arity));
}
if (!opts_.allow_unbounded_dynamism && hlo->shape().is_unbounded_dynamic()) {
return InvalidArgument("Unbounded dynamism is disabled for instruction: %s",
hlo->ToString());
}
return OkStatus();
}

Expand Down
8 changes: 8 additions & 0 deletions xla/service/hlo_verifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ struct HloVerifierOpts {
return std::move(*this);
}

// Builder-style setter for `allow_unbounded_dynamism`.
// Stores `allow` in the option and returns this options object as an rvalue
// reference so that further With...() calls can be chained onto a temporary.
HloVerifierOpts&& WithAllowUnboundedDynamism(bool allow) {
allow_unbounded_dynamism = allow;
return std::move(*this);
}

bool IsLayoutSensitive() const { return layout_sensitive; }

bool AllowMixedPrecision() const { return allow_mixed_precision; }
Expand Down Expand Up @@ -131,6 +136,9 @@ struct HloVerifierOpts {
// Whether bitcast should have the same size, including all paddings.
bool allow_bitcast_to_have_different_size = false;

// Whether unbounded dynamic sizes should be allowed for shapes.
bool allow_unbounded_dynamism = false;

HloPredicate instruction_can_change_layout;

// Returns a target-specific shape size.
Expand Down
29 changes: 29 additions & 0 deletions xla/service/hlo_verifier_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ limitations under the License.
#include "xla/xla.pb.h"
#include "xla/xla_data.pb.h"
#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace {
Expand Down Expand Up @@ -2933,5 +2934,33 @@ ENTRY entry {
TF_ASSERT_OK(status);
}

// Runs the fixture's verifier over a module whose parameter shape has an
// unbounded ('?') dimension and expects verification to fail with the
// "Unbounded dynamism is disabled" error.
TEST_F(HloVerifierTest, UnboundedDynamism) {
  const char* const kHloText = R"(
HloModule Module
ENTRY entry {
ROOT param0 = f32[?,784] parameter(0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(parsed_module.get()).status();

  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.message(),
              HasSubstr("Unbounded dynamism is disabled"));
}

// Builds a verifier with WithAllowUnboundedDynamism(true) and checks that a
// module whose parameter shape has an unbounded ('?') dimension passes
// verification.
TEST_F(HloVerifierTest, EnableUnboundedDynamism) {
  const char* const hlo = R"(
HloModule Module
ENTRY entry {
ROOT param0 = f32[?,784] parameter(0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
  HloVerifier verifier{HloVerifierOpts{}.WithAllowUnboundedDynamism(true)};
  auto status = verifier.Run(module.get()).status();
  // Use TF_ASSERT_OK rather than ASSERT_TRUE(status.ok()) so that a failing
  // run reports the verifier's error message instead of a bare "false".
  TF_ASSERT_OK(status);
}

} // namespace
} // namespace xla
10 changes: 10 additions & 0 deletions xla/shape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,16 @@ bool Shape::is_static() const {
return !absl::c_any_of(dynamic_dimensions_, [](bool b) { return b; });
}

bool Shape::is_unbounded_dynamic() const {
if (IsTuple() && absl::c_any_of(tuple_shapes_, [](const Shape& subshape) {
return subshape.is_unbounded_dynamic();
})) {
return true;
}
return absl::c_any_of(dimensions_,
[](int64_t dim) { return dim == kUnboundedSize; });
}

void Shape::DeleteDimension(int64_t dim_to_delete) {
CHECK(IsArray());
CHECK_GE(dim_to_delete, 0);
Expand Down
13 changes: 12 additions & 1 deletion xla/shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ limitations under the License.
#ifndef XLA_SHAPE_H_
#define XLA_SHAPE_H_

#include <cstdint>
#include <limits>
#include <optional>
#include <ostream>
#include <string>
Expand Down Expand Up @@ -91,6 +91,17 @@ class Shape {

bool is_dynamic() const { return !is_static(); }

// Unbounded dynamism.
// If `dimensions(axis) == kUnboundedSize && is_dynamic_dimension(axis)`,
// this means that the axis has unbounded dynamic size.
// The sentinel value for kUnboundedSize is chosen to be exactly the same
// as the sentinel value mlir::ShapedType::kDynamic.
static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();

// Returns true if the shape has one or more dimensions with unbounded sizes.
// Tuple shapes are traversed recursively.
bool is_unbounded_dynamic() const;

// Returns true if the given dimension is dynamically-sized.
bool is_dynamic_dimension(int dimension) const {
return dynamic_dimensions_.at(dimension);
Expand Down
34 changes: 32 additions & 2 deletions xla/shape_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,14 @@ class ShapeTest : public ::testing::Test {
ShapeUtil::MakeTupleShape({tuple_, matrix_, token_});
const Shape dynamic_matrix_ =
ShapeUtil::MakeShape(S32, {5, 2}, {true, false});
const Shape unbounded_ =
ShapeUtil::MakeShape(F32, {Shape::kUnboundedSize, 784}, {true, false});
};

TEST_F(ShapeTest, ShapeToFromProto) {
for (const Shape& shape : {opaque_, token_, scalar_, matrix_, matrix2_,
tuple_, nested_tuple_, dynamic_matrix_}) {
for (const Shape& shape :
{opaque_, token_, scalar_, matrix_, matrix2_, tuple_, nested_tuple_,
dynamic_matrix_, unbounded_}) {
Shape shape_copy(shape.ToProto());
EXPECT_TRUE(ShapeUtil::Equal(shape, shape_copy))
<< shape << " != " << shape_copy;
Expand Down Expand Up @@ -83,6 +86,8 @@ TEST_F(ShapeTest, DynamicShapeToString) {

array_shape.set_dynamic_dimension(2, false);
EXPECT_EQ("f32[<=23,44,55]", array_shape.ToString());

EXPECT_EQ("f32[?,784]", unbounded_.ToString());
}

TEST_F(ShapeTest, EqualityTest) {
Expand Down Expand Up @@ -120,6 +125,28 @@ TEST_F(ShapeTest, IsStatic) {
ShapeUtil::GetMutableSubshape(&dynamic_tuple, {2})
->set_dynamic_dimension(1, true);
EXPECT_FALSE(dynamic_tuple.is_static());

EXPECT_FALSE(unbounded_.is_static());
}

// Exercises is_dynamic() and is_unbounded_dynamic() on the fixture's static,
// bounded-dynamic and unbounded shapes, and on a tuple whose element is first
// made dynamic and then given an unbounded dimension size.
TEST_F(ShapeTest, IsDynamic) {
  // Neither dynamic nor unbounded.
  EXPECT_FALSE(matrix_.is_dynamic());
  EXPECT_FALSE(matrix_.is_unbounded_dynamic());

  // Dynamic, but with a concrete bound.
  EXPECT_TRUE(dynamic_matrix_.is_dynamic());
  EXPECT_FALSE(dynamic_matrix_.is_unbounded_dynamic());

  // Dynamic and unbounded.
  EXPECT_TRUE(unbounded_.is_dynamic());
  EXPECT_TRUE(unbounded_.is_unbounded_dynamic());

  Shape tuple_copy = tuple_;
  EXPECT_FALSE(tuple_copy.is_unbounded_dynamic());

  // Marking a dimension dynamic alone must not make the tuple unbounded.
  auto* element = ShapeUtil::GetMutableSubshape(&tuple_copy, {2});
  element->set_dynamic_dimension(1, true);
  EXPECT_FALSE(tuple_copy.is_unbounded_dynamic());

  // Setting the dimension size to the sentinel does.
  element->set_dimensions(1, Shape::kUnboundedSize);
  EXPECT_TRUE(tuple_copy.is_unbounded_dynamic());
}

TEST_F(ShapeTest, IsDynamicDimension) {
Expand All @@ -133,6 +160,9 @@ TEST_F(ShapeTest, IsDynamicDimension) {
ShapeUtil::GetMutableSubshape(&dynamic_tuple, {2})
->set_dynamic_dimension(1, true);
EXPECT_FALSE(dynamic_tuple.is_static());

EXPECT_TRUE(unbounded_.is_dynamic_dimension(0));
EXPECT_FALSE(unbounded_.is_dynamic_dimension(1));
}

TEST_F(ShapeTest, ProgramShapeToFromProto) {
Expand Down
27 changes: 20 additions & 7 deletions xla/shape_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,14 +248,18 @@ Shape MakeTupleShapeImpl(absl::Span<ShapePtrOrRef> shapes) {
const int ndims = dimensions.size();
auto layout = shape->mutable_layout();
auto* minor_to_major = layout->mutable_minor_to_major();
auto is_unbounded_dynamic = absl::c_any_of(
dimensions, [](int64_t dim) { return dim == Shape::kUnboundedSize; });
for (int i = 0; i < ndims; i++) {
const int64_t d = dimensions[i];
if (d < 0) {
if (d < 0 && d != Shape::kUnboundedSize) {
return false;
}
dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, d);
if (dense_shape_size < 0) {
return false;
if (!is_unbounded_dynamic) {
dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, d);
if (dense_shape_size < 0) {
return false;
}
}

shape->add_dimensions(d);
Expand Down Expand Up @@ -698,9 +702,14 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) {
printer->Append("[");
auto print_one = [&](int i) {
if (shape.is_dynamic_dimension(i)) {
printer->Append("<=");
if (shape.dimensions(i) != Shape::kUnboundedSize) {
printer->Append(StrCat("<=", shape.dimensions(i)));
} else {
printer->Append("?");
}
} else {
printer->Append(shape.dimensions(i));
}
printer->Append(shape.dimensions(i));
};
print_one(0);
for (int i = 1, n = shape.dimensions_size(); i < n; ++i) {
Expand Down Expand Up @@ -926,7 +935,7 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) {

for (int64_t i = 0; i < shape.rank(); ++i) {
int64_t dimension = shape.dimensions(i);
if (dimension < 0) {
if (dimension < 0 && dimension != Shape::kUnboundedSize) {
return InvalidArgument(
"shape's dimensions must not be < 0; dimension at index %d was %d", i,
dimension);
Expand All @@ -944,6 +953,10 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) {
return OkStatus();
}

if (shape.is_unbounded_dynamic()) {
return OkStatus();
}

int64_t shape_size = [&]() {
int64_t dense_shape_size = 1;
if (shape.dimensions().empty()) {
Expand Down
Loading

0 comments on commit 9b27adc

Please sign in to comment.