diff --git a/xla/BUILD b/xla/BUILD index 74f37b5f6537f9..04befaee8f0f57 100644 --- a/xla/BUILD +++ b/xla/BUILD @@ -466,6 +466,7 @@ xla_cc_test( ":shape_util", ":test", ":xla_data_proto_cc", + "//xla:status", "@com_google_absl//absl/hash:hash_testing", "@tsl//tsl/platform:test_benchmark", "@tsl//tsl/platform:test_main", diff --git a/xla/service/BUILD b/xla/service/BUILD index 5ca73eacc113dc..d3101d24e03fb3 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -4744,6 +4744,7 @@ xla_cc_test( "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/service/hlo_lexer.cc b/xla/service/hlo_lexer.cc index 0e53e6f844e99f..bd516129caa7a4 100644 --- a/xla/service/hlo_lexer.cc +++ b/xla/service/hlo_lexer.cc @@ -96,6 +96,7 @@ TokKind HloLexer::LexToken() { token_state_.token_start = current_ptr_; int current_char = GetNextChar(); + TokKind tmp; switch (current_char) { default: // [a-zA-Z_] @@ -132,7 +133,11 @@ TokKind HloLexer::LexToken() { current_ptr_++; return TokKind::kArrow; } - return LexNumberOrPattern(); + tmp = LexNumberOrPattern(); + if (tmp == TokKind::kError && current_char == '?') { + return TokKind::kQuestionMark; + } + return tmp; case '=': return TokKind::kEqual; case '<': @@ -569,6 +574,8 @@ std::string TokKindToString(TokKind kind) { return "kColon"; case TokKind::kAsterisk: return "kAsterisk"; + case TokKind::kQuestionMark: + return "kQuestionMark"; case TokKind::kOctothorp: return "kOctothorp"; case TokKind::kPlus: diff --git a/xla/service/hlo_lexer.h b/xla/service/hlo_lexer.h index 031ec1ae295330..5681818c07162c 100644 --- a/xla/service/hlo_lexer.h +++ b/xla/service/hlo_lexer.h @@ -39,13 +39,14 @@ enum class TokKind { kError, // Tokens with no info. - kEqual, // = - kComma, // , - kColon, // : - kAsterisk, // * - kOctothorp, // # - kPlus, // + - kTilde, // ~ + kEqual, // = + kComma, // , + kColon, // : + kAsterisk, // * + kQuestionMark, // ? + kOctothorp, // # + kPlus, // + + kTilde, // ~ kLsquare, kRsquare, // [ ] kLbrace, diff --git a/xla/service/hlo_parser.cc b/xla/service/hlo_parser.cc index 132d14aa876ff5..587997713b22a2 100644 --- a/xla/service/hlo_parser.cc +++ b/xla/service/hlo_parser.cc @@ -5382,6 +5382,7 @@ bool HloParserImpl::ParseParamList() { // dimension_sizes ::= '[' dimension_list ']' // dimension_list // ::= /*empty*/ +// ::= '?' // ::= <=? int64_t (',' param)* // param ::= name shape bool HloParserImpl::ParseDimensionSizes(std::vector* dimension_sizes, @@ -5389,12 +5390,18 @@ bool HloParserImpl::ParseDimensionSizes(std::vector* dimension_sizes, auto parse_and_add_item = [&]() { int64_t i; bool is_dynamic = false; - if (lexer_.GetKind() == TokKind::kLeq) { + if (lexer_.GetKind() == TokKind::kQuestionMark) { + i = Shape::kUnboundedSize; is_dynamic = true; lexer_.Lex(); - } - if (!ParseInt64(&i)) { - return false; + } else { + if (lexer_.GetKind() == TokKind::kLeq) { + is_dynamic = true; + lexer_.Lex(); + } + if (!ParseInt64(&i)) { + return false; + } } dimension_sizes->push_back(i); dynamic_dimensions->push_back(is_dynamic); diff --git a/xla/service/hlo_parser_test.cc b/xla/service/hlo_parser_test.cc index 4baf3be5bff223..34205f9524d542 100644 --- a/xla/service/hlo_parser_test.cc +++ b/xla/service/hlo_parser_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_sharding.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/verified_hlo_module.h" #include "xla/window_util.h" @@ -4069,6 +4070,16 @@ TEST_F(HloParserTest, ParseShapeStringR2F32) { << "actual: " << ShapeUtil::HumanString(actual); } +TEST_F(HloParserTest, ParseShapeStringUnbounded) { + std::string shape_string = "f32[?,784]"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = + ShapeUtil::MakeShape(F32, {Shape::kUnboundedSize, 784}, {true, false}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + TEST_F(HloParserTest, ParseShapeStringTupleOfArrays) { std::string shape_string = "(f32[1572864],s8[5120,1024])"; TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); diff --git a/xla/service/hlo_verifier.cc b/xla/service/hlo_verifier.cc index ab4eb980047004..17bd5c3d1512ee 100644 --- a/xla/service/hlo_verifier.cc +++ b/xla/service/hlo_verifier.cc @@ -160,6 +160,10 @@ Status ShapeVerifier::Preprocess(HloInstruction* hlo) { if (arity) { TF_RETURN_IF_ERROR(CheckOperandCount(hlo, *arity)); } + if (!opts_.allow_unbounded_dynamism && hlo->shape().is_unbounded_dynamic()) { + return InvalidArgument("Unbounded dynamism is disabled for instruction: %s", + hlo->ToString()); + } return OkStatus(); } diff --git a/xla/service/hlo_verifier.h b/xla/service/hlo_verifier.h index 813b3ba30d01b1..29744af10982bb 100644 --- a/xla/service/hlo_verifier.h +++ b/xla/service/hlo_verifier.h @@ -91,6 +91,11 @@ struct HloVerifierOpts { return std::move(*this); } + HloVerifierOpts&& WithAllowUnboundedDynamism(bool allow) { + allow_unbounded_dynamism = allow; + return std::move(*this); + } + bool IsLayoutSensitive() const { return layout_sensitive; } bool AllowMixedPrecision() const { return allow_mixed_precision; } @@ -131,6 +136,9 @@ struct HloVerifierOpts { // Whether bitcast should have the same size, including all paddings. bool allow_bitcast_to_have_different_size = false; + // Whether unbounded dynamic sizes should be allowed for shapes. + bool allow_unbounded_dynamism = false; + HloPredicate instruction_can_change_layout; // Returns a target-specific shape size. diff --git a/xla/service/hlo_verifier_test.cc b/xla/service/hlo_verifier_test.cc index fe36e7318ad956..d6faca7cc585ea 100644 --- a/xla/service/hlo_verifier_test.cc +++ b/xla/service/hlo_verifier_test.cc @@ -37,6 +37,7 @@ limitations under the License. #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -2933,5 +2934,33 @@ ENTRY entry { TF_ASSERT_OK(status); } +TEST_F(HloVerifierTest, UnboundedDynamism) { + const char* const hlo = R"( + HloModule Module + + ENTRY entry { + ROOT param0 = f32[?,784] parameter(0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.message(), HasSubstr("Unbounded dynamism is disabled")); +} + +TEST_F(HloVerifierTest, EnableUnboundedDynamism) { + const char* const hlo = R"( + HloModule Module + + ENTRY entry { + ROOT param0 = f32[?,784] parameter(0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + HloVerifier verifier{HloVerifierOpts{}.WithAllowUnboundedDynamism(true)}; + auto status = verifier.Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + } // namespace } // namespace xla diff --git a/xla/shape.cc b/xla/shape.cc index 1909291788aa92..0ad4897f320e9f 100644 --- a/xla/shape.cc +++ b/xla/shape.cc @@ -137,6 +137,16 @@ bool Shape::is_static() const { return !absl::c_any_of(dynamic_dimensions_, [](bool b) { return b; }); } +bool Shape::is_unbounded_dynamic() const { + if (IsTuple() && absl::c_any_of(tuple_shapes_, [](const Shape& subshape) { + return subshape.is_unbounded_dynamic(); + })) { + return true; + } + return absl::c_any_of(dimensions_, + [](int64_t dim) { return dim == kUnboundedSize; }); +} + void Shape::DeleteDimension(int64_t dim_to_delete) { CHECK(IsArray()); CHECK_GE(dim_to_delete, 0); diff --git a/xla/shape.h b/xla/shape.h index 214b87a0f3b505..9386fc18043ac7 100644 --- a/xla/shape.h +++ b/xla/shape.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef XLA_SHAPE_H_ #define XLA_SHAPE_H_ -#include +#include #include #include #include @@ -91,6 +91,17 @@ class Shape { bool is_dynamic() const { return !is_static(); } + // Unbounded dynamism. + // If `dimensions(axis) == kUnboundedSize && is_dynamic_dimension(axis)`, + // this means that the axis has unbounded dynamic size. + // The sentinel value for kUnboundedSize is chosen to be exactly the same + // as the sentinel value mlir::ShapedType::kDynamic. + static constexpr int64_t kUnboundedSize = std::numeric_limits::min(); + + // Returns true if the shape has one or more dimensions with unbounded sizes. + // Tuple shapes are traversed recursively. + bool is_unbounded_dynamic() const; + // Returns true if the given dimension is dynamically-sized. bool is_dynamic_dimension(int dimension) const { return dynamic_dimensions_.at(dimension); diff --git a/xla/shape_test.cc b/xla/shape_test.cc index d691ee64b17079..322f02e4773f67 100644 --- a/xla/shape_test.cc +++ b/xla/shape_test.cc @@ -41,11 +41,14 @@ class ShapeTest : public ::testing::Test { ShapeUtil::MakeTupleShape({tuple_, matrix_, token_}); const Shape dynamic_matrix_ = ShapeUtil::MakeShape(S32, {5, 2}, {true, false}); + const Shape unbounded_ = + ShapeUtil::MakeShape(F32, {Shape::kUnboundedSize, 784}, {true, false}); }; TEST_F(ShapeTest, ShapeToFromProto) { - for (const Shape& shape : {opaque_, token_, scalar_, matrix_, matrix2_, - tuple_, nested_tuple_, dynamic_matrix_}) { + for (const Shape& shape : + {opaque_, token_, scalar_, matrix_, matrix2_, tuple_, nested_tuple_, + dynamic_matrix_, unbounded_}) { Shape shape_copy(shape.ToProto()); EXPECT_TRUE(ShapeUtil::Equal(shape, shape_copy)) << shape << " != " << shape_copy; @@ -83,6 +86,8 @@ TEST_F(ShapeTest, DynamicShapeToString) { array_shape.set_dynamic_dimension(2, false); EXPECT_EQ("f32[<=23,44,55]", array_shape.ToString()); + + EXPECT_EQ("f32[?,784]", unbounded_.ToString()); } TEST_F(ShapeTest, EqualityTest) { @@ -120,6 +125,28 @@ TEST_F(ShapeTest, IsStatic) { ShapeUtil::GetMutableSubshape(&dynamic_tuple, {2}) ->set_dynamic_dimension(1, true); EXPECT_FALSE(dynamic_tuple.is_static()); + + EXPECT_FALSE(unbounded_.is_static()); +} + +TEST_F(ShapeTest, IsDynamic) { + EXPECT_FALSE(matrix_.is_dynamic()); + EXPECT_FALSE(matrix_.is_unbounded_dynamic()); + + EXPECT_TRUE(dynamic_matrix_.is_dynamic()); + EXPECT_FALSE(dynamic_matrix_.is_unbounded_dynamic()); + + EXPECT_TRUE(unbounded_.is_dynamic()); + EXPECT_TRUE(unbounded_.is_unbounded_dynamic()); + + Shape unbounded_tuple = tuple_; + EXPECT_FALSE(unbounded_tuple.is_unbounded_dynamic()); + ShapeUtil::GetMutableSubshape(&unbounded_tuple, {2}) + ->set_dynamic_dimension(1, true); + EXPECT_FALSE(unbounded_tuple.is_unbounded_dynamic()); + ShapeUtil::GetMutableSubshape(&unbounded_tuple, {2}) + ->set_dimensions(1, Shape::kUnboundedSize); + EXPECT_TRUE(unbounded_tuple.is_unbounded_dynamic()); } TEST_F(ShapeTest, IsDynamicDimension) { @@ -133,6 +160,9 @@ TEST_F(ShapeTest, IsDynamicDimension) { ShapeUtil::GetMutableSubshape(&dynamic_tuple, {2}) ->set_dynamic_dimension(1, true); EXPECT_FALSE(dynamic_tuple.is_static()); + + EXPECT_TRUE(unbounded_.is_dynamic_dimension(0)); + EXPECT_FALSE(unbounded_.is_dynamic_dimension(1)); } TEST_F(ShapeTest, ProgramShapeToFromProto) { diff --git a/xla/shape_util.cc b/xla/shape_util.cc index 55882a59e7cbdc..004fd94ea5fcdc 100644 --- a/xla/shape_util.cc +++ b/xla/shape_util.cc @@ -248,14 +248,18 @@ Shape MakeTupleShapeImpl(absl::Span shapes) { const int ndims = dimensions.size(); auto layout = shape->mutable_layout(); auto* minor_to_major = layout->mutable_minor_to_major(); + auto is_unbounded_dynamic = absl::c_any_of( + dimensions, [](int64_t dim) { return dim == Shape::kUnboundedSize; }); for (int i = 0; i < ndims; i++) { const int64_t d = dimensions[i]; - if (d < 0) { + if (d < 0 && d != Shape::kUnboundedSize) { return false; } - dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, d); - if (dense_shape_size < 0) { - return false; + if (!is_unbounded_dynamic) { + dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, d); + if (dense_shape_size < 0) { + return false; + } } shape->add_dimensions(d); @@ -698,9 +702,14 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { printer->Append("["); auto print_one = [&](int i) { if (shape.is_dynamic_dimension(i)) { - printer->Append("<="); + if (shape.dimensions(i) != Shape::kUnboundedSize) { + printer->Append(StrCat("<=", shape.dimensions(i))); + } else { + printer->Append("?"); + } + } else { + printer->Append(shape.dimensions(i)); } - printer->Append(shape.dimensions(i)); }; print_one(0); for (int i = 1, n = shape.dimensions_size(); i < n; ++i) { @@ -926,7 +935,7 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { for (int64_t i = 0; i < shape.rank(); ++i) { int64_t dimension = shape.dimensions(i); - if (dimension < 0) { + if (dimension < 0 && dimension != Shape::kUnboundedSize) { return InvalidArgument( "shape's dimensions must not be < 0; dimension at index %d was %d", i, dimension); @@ -944,6 +953,10 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { return OkStatus(); } + if (shape.is_unbounded_dynamic()) { + return OkStatus(); + } + int64_t shape_size = [&]() { int64_t dense_shape_size = 1; if (shape.dimensions().empty()) { diff --git a/xla/translate/hlo_to_mhlo/hlo_utils.h b/xla/translate/hlo_to_mhlo/hlo_utils.h index 42682a251385ae..275f59a43ebdba 100644 --- a/xla/translate/hlo_to_mhlo/hlo_utils.h +++ b/xla/translate/hlo_to_mhlo/hlo_utils.h @@ -59,22 +59,24 @@ static StatusOr ConvertTensorShapeToType(const Shape& xla_ty, ConvertPrimitiveTypeToMLIRType(xla_ty.element_type(), builder); if (!element_type_or.ok()) return element_type_or.status(); - bool is_dynamic = false; + bool is_bounded_dynamic = false; int64_t rank = xla_ty.rank(); llvm::SmallVector shape(rank, mlir::ShapedType::kDynamic); llvm::SmallVector bounds(rank, mlir::ShapedType::kDynamic); for (int64_t dim = 0; dim < rank; ++dim) { int64_t dim_size = xla_ty.dimensions(dim); if (xla_ty.is_dynamic_dimension(dim)) { - bounds[dim] = dim_size; - is_dynamic = true; + if (dim_size != Shape::kUnboundedSize) { + bounds[dim] = dim_size; + is_bounded_dynamic = true; + } } else { shape[dim] = dim_size; } } using mlir::mhlo::TypeExtensionsAttr; mlir::Attribute encoding; - if (is_dynamic) { + if (is_bounded_dynamic) { encoding = TypeExtensionsAttr::get(builder.getContext(), bounds); } @@ -89,7 +91,7 @@ static StatusOr ConvertTensorShapeToType(const Shape& xla_ty, if (xla_ty.has_layout()) { auto layout = xla_ty.layout(); if (LayoutUtil::IsSparse(layout)) { - if (is_dynamic) + if (is_bounded_dynamic) return Unimplemented( "MHLO doesn't support bounded dynamic shapes for sparse tensors"); llvm::SmallVector dlts; diff --git a/xla/translate/hlo_to_mhlo/tests/import.hlotxt b/xla/translate/hlo_to_mhlo/tests/import.hlotxt index 6e8ef58022478a..8344c204ab4f89 100644 --- a/xla/translate/hlo_to_mhlo/tests/import.hlotxt +++ b/xla/translate/hlo_to_mhlo/tests/import.hlotxt @@ -1838,3 +1838,12 @@ add { %b = (f32[2,4]{1,0}, (f32[2,4]{1,0})) parameter(1), parameter_replication={false,true} ROOT %tuple = (f32[], (f32[2,4]{1,0}, (f32[2,4]{1,0}))) tuple(f32[] %a, (f32[2,4]{1,0}, (f32[2,4]{1,0})) %b) } + +// CHECK-LABEL: func.func private @unbounded(%arg0: tensor) -> tensor { +// CHECK-NEXT: [[VAL0:%.*]] = mhlo.abs %arg0 : tensor +// CHECK-NEXT: return [[VAL0]] : tensor +// CHECK-NEXT: } +%unbounded (Arg_0.1: f32[?,784]) -> f32[?,784] { + %Arg_0.1 = f32[?,784] parameter(0) + ROOT %abs.2 = f32[?,784] abs(f32[?,784] %Arg_0.1) +} diff --git a/xla/translate/mhlo_to_hlo/tests/export.mlir b/xla/translate/mhlo_to_hlo/tests/export.mlir index 7cd4212027046d..9bcb03e1ea747a 100644 --- a/xla/translate/mhlo_to_hlo/tests/export.mlir +++ b/xla/translate/mhlo_to_hlo/tests/export.mlir @@ -3048,3 +3048,17 @@ func.func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<5x5xf32>) -> tensor<1x2x3x func.func @main(%arg0: tensor {mhlo.parameter_replication = [true]}, %arg1: tuple, tuple>> {mhlo.parameter_replication = [false, true]}) -> tensor { return %arg0 : tensor } + +// ----- + +func.func @main(%operand: tensor) -> tensor { + %0 = mhlo.abs %operand : tensor + func.return %0 : tensor +} + +// CHECK: HloModule {{.*}}, entry_computation_layout={(f32[?,784]{1,0})->f32[?,784]{1,0}} +// CHECK-EMPTY: +// CHECK-NEXT: ENTRY {{.*}} ([[ARG0:.*]]: f32[?,784]) -> f32[?,784] { +// CHECK-NEXT: [[ARG0]] = f32[?,784] parameter(0) +// CHECK-NEXT: ROOT {{.*}} = f32[?,784] abs(f32[?,784] %Arg_0.1), {{.*}} +// CHECK-NEXT: } diff --git a/xla/translate/mhlo_to_hlo/type_to_shape.cc b/xla/translate/mhlo_to_hlo/type_to_shape.cc index 8ccc406b756828..27fbcb2ad60e85 100644 --- a/xla/translate/mhlo_to_hlo/type_to_shape.cc +++ b/xla/translate/mhlo_to_hlo/type_to_shape.cc @@ -178,12 +178,11 @@ Shape TypeToShape(mlir::Type type) { llvm::SmallVector shape(rank, mlir::ShapedType::kDynamic); std::vector is_dynamic(rank, false); for (int64_t dim = 0; dim < rank; ++dim) { - // Only fully static shapes are supported. - // TODO(b/115638799): Update once xla::Shape can support dynamic shapes. int64_t size = t.getDimSize(dim); if (size == ShapedType::kDynamic) { - if (bounds[dim] == ShapedType::kDynamic) return {}; - shape[dim] = bounds[dim]; + shape[dim] = bounds[dim] != ShapedType::kDynamic + ? bounds[dim] + : Shape::kUnboundedSize; is_dynamic[dim] = true; } else { if (bounds[dim] != ShapedType::kDynamic) return {}; diff --git a/xla/translate/mhlo_to_hlo/type_to_shape_test.cc b/xla/translate/mhlo_to_hlo/type_to_shape_test.cc index 37d82730cb881f..e38dbc355d0426 100644 --- a/xla/translate/mhlo_to_hlo/type_to_shape_test.cc +++ b/xla/translate/mhlo_to_hlo/type_to_shape_test.cc @@ -138,8 +138,15 @@ TEST(TypeToShapeTest, ConvertTensorTypeToTypes) { ShapeUtil::MakeShape(PrimitiveType::F32, {8, 128}, {true, false}) .ToProto())); - // Shape cannot represent dynamic shapes. - // TODO(b/115638799): Update once Shape can support dynamic shapes. + EXPECT_THAT( + TypeToShape(RankedTensorType::get({mlir::ShapedType::kDynamic, 784}, + b.getF32Type())) + .ToProto(), + EqualsProto(ShapeUtil::MakeShape(PrimitiveType::F32, + {Shape::kUnboundedSize, 784}, + {true, false}) + .ToProto())); + EXPECT_THAT(TypeToShape(UnrankedTensorType::get(b.getF32Type())).ToProto(), EqualsProto(Shape().ToProto()));