From 9b27adc83179d0ae8900e09cd9f3fe2f41c0e63c Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Mon, 6 Nov 2023 17:06:05 -0800 Subject: [PATCH] Roundtrip unbounded dimension sizes between MHLO and HLO This adds a new capability to roundtrip unbounded (but not unranked!) tensor types between MHLO and HLO. Unbounded dimensions are modelled as dynamic dimensions of `Shape::kUnboundedSize` size. As discussed, the unbounded size is set to `std::numeric_limits<int64_t>::min()`, matching the sentinel value `mlir::ShapedType::kDynamic`. In order to enable creation of unbounded shapes, this CL disables validation of `dense_shape_size` in `ShapeUtil::FillNewShape` and `ShapeUtil::ValidateShapeSize` if an unbounded size is found among shape's dimensions. As discussed, it is expected that there are going to be many locations in the code which will blow up when seeing unbounded shapes. Auditing and addressing them is out of scope of this CL - in this CL, we only introduce the representational capability, as well as roundtripping between MHLO and HLO. Determining which parts of XLA will be adapted to unbounded dynamism and actually adapting them is going to be a long journey. 
PiperOrigin-RevId: 580002339 --- xla/BUILD | 1 + xla/service/BUILD | 1 + xla/service/hlo_lexer.cc | 9 ++++- xla/service/hlo_lexer.h | 15 ++++---- xla/service/hlo_parser.cc | 15 +++++--- xla/service/hlo_parser_test.cc | 11 ++++++ xla/service/hlo_verifier.cc | 4 +++ xla/service/hlo_verifier.h | 8 +++++ xla/service/hlo_verifier_test.cc | 29 ++++++++++++++++ xla/shape.cc | 10 ++++++ xla/shape.h | 13 ++++++- xla/shape_test.cc | 34 +++++++++++++++++-- xla/shape_util.cc | 27 +++++++++++---- xla/translate/hlo_to_mhlo/hlo_utils.h | 12 ++++--- xla/translate/hlo_to_mhlo/tests/import.hlotxt | 9 +++++ xla/translate/mhlo_to_hlo/tests/export.mlir | 14 ++++++++ xla/translate/mhlo_to_hlo/type_to_shape.cc | 7 ++-- .../mhlo_to_hlo/type_to_shape_test.cc | 11 ++++-- 18 files changed, 197 insertions(+), 33 deletions(-) diff --git a/xla/BUILD b/xla/BUILD index 74f37b5f6537f9..04befaee8f0f57 100644 --- a/xla/BUILD +++ b/xla/BUILD @@ -466,6 +466,7 @@ xla_cc_test( ":shape_util", ":test", ":xla_data_proto_cc", + "//xla:status", "@com_google_absl//absl/hash:hash_testing", "@tsl//tsl/platform:test_benchmark", "@tsl//tsl/platform:test_main", diff --git a/xla/service/BUILD b/xla/service/BUILD index 5ca73eacc113dc..d3101d24e03fb3 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -4744,6 +4744,7 @@ xla_cc_test( "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/service/hlo_lexer.cc b/xla/service/hlo_lexer.cc index 0e53e6f844e99f..bd516129caa7a4 100644 --- a/xla/service/hlo_lexer.cc +++ b/xla/service/hlo_lexer.cc @@ -96,6 +96,7 @@ TokKind HloLexer::LexToken() { token_state_.token_start = current_ptr_; int current_char = GetNextChar(); + TokKind tmp; switch (current_char) { default: // [a-zA-Z_] @@ -132,7 +133,11 @@ TokKind HloLexer::LexToken() { current_ptr_++; return TokKind::kArrow; } - return LexNumberOrPattern(); + tmp = LexNumberOrPattern(); + if (tmp == TokKind::kError && 
current_char == '?') { + return TokKind::kQuestionMark; + } + return tmp; case '=': return TokKind::kEqual; case '<': @@ -569,6 +574,8 @@ std::string TokKindToString(TokKind kind) { return "kColon"; case TokKind::kAsterisk: return "kAsterisk"; + case TokKind::kQuestionMark: + return "kQuestionMark"; case TokKind::kOctothorp: return "kOctothorp"; case TokKind::kPlus: diff --git a/xla/service/hlo_lexer.h b/xla/service/hlo_lexer.h index 031ec1ae295330..5681818c07162c 100644 --- a/xla/service/hlo_lexer.h +++ b/xla/service/hlo_lexer.h @@ -39,13 +39,14 @@ enum class TokKind { kError, // Tokens with no info. - kEqual, // = - kComma, // , - kColon, // : - kAsterisk, // * - kOctothorp, // # - kPlus, // + - kTilde, // ~ + kEqual, // = + kComma, // , + kColon, // : + kAsterisk, // * + kQuestionMark, // ? + kOctothorp, // # + kPlus, // + + kTilde, // ~ kLsquare, kRsquare, // [ ] kLbrace, diff --git a/xla/service/hlo_parser.cc b/xla/service/hlo_parser.cc index 132d14aa876ff5..587997713b22a2 100644 --- a/xla/service/hlo_parser.cc +++ b/xla/service/hlo_parser.cc @@ -5382,6 +5382,7 @@ bool HloParserImpl::ParseParamList() { // dimension_sizes ::= '[' dimension_list ']' // dimension_list // ::= /*empty*/ +// ::= '?' // ::= <=? 
int64_t (',' param)* // param ::= name shape bool HloParserImpl::ParseDimensionSizes(std::vector* dimension_sizes, @@ -5389,12 +5390,18 @@ bool HloParserImpl::ParseDimensionSizes(std::vector* dimension_sizes, auto parse_and_add_item = [&]() { int64_t i; bool is_dynamic = false; - if (lexer_.GetKind() == TokKind::kLeq) { + if (lexer_.GetKind() == TokKind::kQuestionMark) { + i = Shape::kUnboundedSize; is_dynamic = true; lexer_.Lex(); - } - if (!ParseInt64(&i)) { - return false; + } else { + if (lexer_.GetKind() == TokKind::kLeq) { + is_dynamic = true; + lexer_.Lex(); + } + if (!ParseInt64(&i)) { + return false; + } } dimension_sizes->push_back(i); dynamic_dimensions->push_back(is_dynamic); diff --git a/xla/service/hlo_parser_test.cc b/xla/service/hlo_parser_test.cc index 4baf3be5bff223..34205f9524d542 100644 --- a/xla/service/hlo_parser_test.cc +++ b/xla/service/hlo_parser_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_sharding.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/verified_hlo_module.h" #include "xla/window_util.h" @@ -4069,6 +4070,16 @@ TEST_F(HloParserTest, ParseShapeStringR2F32) { << "actual: " << ShapeUtil::HumanString(actual); } +TEST_F(HloParserTest, ParseShapeStringUnbounded) { + std::string shape_string = "f32[?,784]"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = + ShapeUtil::MakeShape(F32, {Shape::kUnboundedSize, 784}, {true, false}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + TEST_F(HloParserTest, ParseShapeStringTupleOfArrays) { std::string shape_string = "(f32[1572864],s8[5120,1024])"; TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); diff --git a/xla/service/hlo_verifier.cc b/xla/service/hlo_verifier.cc index 
ab4eb980047004..17bd5c3d1512ee 100644 --- a/xla/service/hlo_verifier.cc +++ b/xla/service/hlo_verifier.cc @@ -160,6 +160,10 @@ Status ShapeVerifier::Preprocess(HloInstruction* hlo) { if (arity) { TF_RETURN_IF_ERROR(CheckOperandCount(hlo, *arity)); } + if (!opts_.allow_unbounded_dynamism && hlo->shape().is_unbounded_dynamic()) { + return InvalidArgument("Unbounded dynamism is disabled for instruction: %s", + hlo->ToString()); + } return OkStatus(); } diff --git a/xla/service/hlo_verifier.h b/xla/service/hlo_verifier.h index 813b3ba30d01b1..29744af10982bb 100644 --- a/xla/service/hlo_verifier.h +++ b/xla/service/hlo_verifier.h @@ -91,6 +91,11 @@ struct HloVerifierOpts { return std::move(*this); } + HloVerifierOpts&& WithAllowUnboundedDynamism(bool allow) { + allow_unbounded_dynamism = allow; + return std::move(*this); + } + bool IsLayoutSensitive() const { return layout_sensitive; } bool AllowMixedPrecision() const { return allow_mixed_precision; } @@ -131,6 +136,9 @@ struct HloVerifierOpts { // Whether bitcast should have the same size, including all paddings. bool allow_bitcast_to_have_different_size = false; + // Whether unbounded dynamic sizes should be allowed for shapes. + bool allow_unbounded_dynamism = false; + HloPredicate instruction_can_change_layout; // Returns a target-specific shape size. diff --git a/xla/service/hlo_verifier_test.cc b/xla/service/hlo_verifier_test.cc index fe36e7318ad956..d6faca7cc585ea 100644 --- a/xla/service/hlo_verifier_test.cc +++ b/xla/service/hlo_verifier_test.cc @@ -37,6 +37,7 @@ limitations under the License. 
#include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -2933,5 +2934,33 @@ ENTRY entry { TF_ASSERT_OK(status); } +TEST_F(HloVerifierTest, UnboundedDynamism) { + const char* const hlo = R"( + HloModule Module + + ENTRY entry { + ROOT param0 = f32[?,784] parameter(0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.message(), HasSubstr("Unbounded dynamism is disabled")); +} + +TEST_F(HloVerifierTest, EnableUnboundedDynamism) { + const char* const hlo = R"( + HloModule Module + + ENTRY entry { + ROOT param0 = f32[?,784] parameter(0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + HloVerifier verifier{HloVerifierOpts{}.WithAllowUnboundedDynamism(true)}; + auto status = verifier.Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + } // namespace } // namespace xla diff --git a/xla/shape.cc b/xla/shape.cc index 1909291788aa92..0ad4897f320e9f 100644 --- a/xla/shape.cc +++ b/xla/shape.cc @@ -137,6 +137,16 @@ bool Shape::is_static() const { return !absl::c_any_of(dynamic_dimensions_, [](bool b) { return b; }); } +bool Shape::is_unbounded_dynamic() const { + if (IsTuple() && absl::c_any_of(tuple_shapes_, [](const Shape& subshape) { + return subshape.is_unbounded_dynamic(); + })) { + return true; + } + return absl::c_any_of(dimensions_, + [](int64_t dim) { return dim == kUnboundedSize; }); +} + void Shape::DeleteDimension(int64_t dim_to_delete) { CHECK(IsArray()); CHECK_GE(dim_to_delete, 0); diff --git a/xla/shape.h b/xla/shape.h index 214b87a0f3b505..9386fc18043ac7 100644 --- a/xla/shape.h +++ b/xla/shape.h @@ -16,7 +16,7 @@ limitations under the License. 
#ifndef XLA_SHAPE_H_ #define XLA_SHAPE_H_ -#include +#include #include #include #include @@ -91,6 +91,17 @@ class Shape { bool is_dynamic() const { return !is_static(); } + // Unbounded dynamism. + // If `dimensions(axis) == kUnboundedSize && is_dynamic_dimension(axis)`, + // this means that the axis has unbounded dynamic size. + // The sentinel value for kUnboundedSize is chosen to be exactly the same + // as the sentinel value mlir::ShapedType::kDynamic. + static constexpr int64_t kUnboundedSize = std::numeric_limits::min(); + + // Returns true if the shape has one or more dimensions with unbounded sizes. + // Tuple shapes are traversed recursively. + bool is_unbounded_dynamic() const; + // Returns true if the given dimension is dynamically-sized. bool is_dynamic_dimension(int dimension) const { return dynamic_dimensions_.at(dimension); diff --git a/xla/shape_test.cc b/xla/shape_test.cc index d691ee64b17079..322f02e4773f67 100644 --- a/xla/shape_test.cc +++ b/xla/shape_test.cc @@ -41,11 +41,14 @@ class ShapeTest : public ::testing::Test { ShapeUtil::MakeTupleShape({tuple_, matrix_, token_}); const Shape dynamic_matrix_ = ShapeUtil::MakeShape(S32, {5, 2}, {true, false}); + const Shape unbounded_ = + ShapeUtil::MakeShape(F32, {Shape::kUnboundedSize, 784}, {true, false}); }; TEST_F(ShapeTest, ShapeToFromProto) { - for (const Shape& shape : {opaque_, token_, scalar_, matrix_, matrix2_, - tuple_, nested_tuple_, dynamic_matrix_}) { + for (const Shape& shape : + {opaque_, token_, scalar_, matrix_, matrix2_, tuple_, nested_tuple_, + dynamic_matrix_, unbounded_}) { Shape shape_copy(shape.ToProto()); EXPECT_TRUE(ShapeUtil::Equal(shape, shape_copy)) << shape << " != " << shape_copy; @@ -83,6 +86,8 @@ TEST_F(ShapeTest, DynamicShapeToString) { array_shape.set_dynamic_dimension(2, false); EXPECT_EQ("f32[<=23,44,55]", array_shape.ToString()); + + EXPECT_EQ("f32[?,784]", unbounded_.ToString()); } TEST_F(ShapeTest, EqualityTest) { @@ -120,6 +125,28 @@ TEST_F(ShapeTest, IsStatic) 
{ ShapeUtil::GetMutableSubshape(&dynamic_tuple, {2}) ->set_dynamic_dimension(1, true); EXPECT_FALSE(dynamic_tuple.is_static()); + + EXPECT_FALSE(unbounded_.is_static()); +} + +TEST_F(ShapeTest, IsDynamic) { + EXPECT_FALSE(matrix_.is_dynamic()); + EXPECT_FALSE(matrix_.is_unbounded_dynamic()); + + EXPECT_TRUE(dynamic_matrix_.is_dynamic()); + EXPECT_FALSE(dynamic_matrix_.is_unbounded_dynamic()); + + EXPECT_TRUE(unbounded_.is_dynamic()); + EXPECT_TRUE(unbounded_.is_unbounded_dynamic()); + + Shape unbounded_tuple = tuple_; + EXPECT_FALSE(unbounded_tuple.is_unbounded_dynamic()); + ShapeUtil::GetMutableSubshape(&unbounded_tuple, {2}) + ->set_dynamic_dimension(1, true); + EXPECT_FALSE(unbounded_tuple.is_unbounded_dynamic()); + ShapeUtil::GetMutableSubshape(&unbounded_tuple, {2}) + ->set_dimensions(1, Shape::kUnboundedSize); + EXPECT_TRUE(unbounded_tuple.is_unbounded_dynamic()); } TEST_F(ShapeTest, IsDynamicDimension) { @@ -133,6 +160,9 @@ TEST_F(ShapeTest, IsDynamicDimension) { ShapeUtil::GetMutableSubshape(&dynamic_tuple, {2}) ->set_dynamic_dimension(1, true); EXPECT_FALSE(dynamic_tuple.is_static()); + + EXPECT_TRUE(unbounded_.is_dynamic_dimension(0)); + EXPECT_FALSE(unbounded_.is_dynamic_dimension(1)); } TEST_F(ShapeTest, ProgramShapeToFromProto) { diff --git a/xla/shape_util.cc b/xla/shape_util.cc index 55882a59e7cbdc..004fd94ea5fcdc 100644 --- a/xla/shape_util.cc +++ b/xla/shape_util.cc @@ -248,14 +248,18 @@ Shape MakeTupleShapeImpl(absl::Span shapes) { const int ndims = dimensions.size(); auto layout = shape->mutable_layout(); auto* minor_to_major = layout->mutable_minor_to_major(); + auto is_unbounded_dynamic = absl::c_any_of( + dimensions, [](int64_t dim) { return dim == Shape::kUnboundedSize; }); for (int i = 0; i < ndims; i++) { const int64_t d = dimensions[i]; - if (d < 0) { + if (d < 0 && d != Shape::kUnboundedSize) { return false; } - dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, d); - if (dense_shape_size < 0) { - return false; + if 
(!is_unbounded_dynamic) { + dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, d); + if (dense_shape_size < 0) { + return false; + } } shape->add_dimensions(d); @@ -698,9 +702,14 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { printer->Append("["); auto print_one = [&](int i) { if (shape.is_dynamic_dimension(i)) { - printer->Append("<="); + if (shape.dimensions(i) != Shape::kUnboundedSize) { + printer->Append(StrCat("<=", shape.dimensions(i))); + } else { + printer->Append("?"); + } + } else { + printer->Append(shape.dimensions(i)); } - printer->Append(shape.dimensions(i)); }; print_one(0); for (int i = 1, n = shape.dimensions_size(); i < n; ++i) { @@ -926,7 +935,7 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { for (int64_t i = 0; i < shape.rank(); ++i) { int64_t dimension = shape.dimensions(i); - if (dimension < 0) { + if (dimension < 0 && dimension != Shape::kUnboundedSize) { return InvalidArgument( "shape's dimensions must not be < 0; dimension at index %d was %d", i, dimension); @@ -944,6 +953,10 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { return OkStatus(); } + if (shape.is_unbounded_dynamic()) { + return OkStatus(); + } + int64_t shape_size = [&]() { int64_t dense_shape_size = 1; if (shape.dimensions().empty()) { diff --git a/xla/translate/hlo_to_mhlo/hlo_utils.h b/xla/translate/hlo_to_mhlo/hlo_utils.h index 42682a251385ae..275f59a43ebdba 100644 --- a/xla/translate/hlo_to_mhlo/hlo_utils.h +++ b/xla/translate/hlo_to_mhlo/hlo_utils.h @@ -59,22 +59,24 @@ static StatusOr ConvertTensorShapeToType(const Shape& xla_ty, ConvertPrimitiveTypeToMLIRType(xla_ty.element_type(), builder); if (!element_type_or.ok()) return element_type_or.status(); - bool is_dynamic = false; + bool is_bounded_dynamic = false; int64_t rank = xla_ty.rank(); llvm::SmallVector shape(rank, mlir::ShapedType::kDynamic); llvm::SmallVector bounds(rank, mlir::ShapedType::kDynamic); for (int64_t dim = 0; dim < rank; 
++dim) { int64_t dim_size = xla_ty.dimensions(dim); if (xla_ty.is_dynamic_dimension(dim)) { - bounds[dim] = dim_size; - is_dynamic = true; + if (dim_size != Shape::kUnboundedSize) { + bounds[dim] = dim_size; + is_bounded_dynamic = true; + } } else { shape[dim] = dim_size; } } using mlir::mhlo::TypeExtensionsAttr; mlir::Attribute encoding; - if (is_dynamic) { + if (is_bounded_dynamic) { encoding = TypeExtensionsAttr::get(builder.getContext(), bounds); } @@ -89,7 +91,7 @@ static StatusOr ConvertTensorShapeToType(const Shape& xla_ty, if (xla_ty.has_layout()) { auto layout = xla_ty.layout(); if (LayoutUtil::IsSparse(layout)) { - if (is_dynamic) + if (is_bounded_dynamic) return Unimplemented( "MHLO doesn't support bounded dynamic shapes for sparse tensors"); llvm::SmallVector dlts; diff --git a/xla/translate/hlo_to_mhlo/tests/import.hlotxt b/xla/translate/hlo_to_mhlo/tests/import.hlotxt index 6e8ef58022478a..8344c204ab4f89 100644 --- a/xla/translate/hlo_to_mhlo/tests/import.hlotxt +++ b/xla/translate/hlo_to_mhlo/tests/import.hlotxt @@ -1838,3 +1838,12 @@ add { %b = (f32[2,4]{1,0}, (f32[2,4]{1,0})) parameter(1), parameter_replication={false,true} ROOT %tuple = (f32[], (f32[2,4]{1,0}, (f32[2,4]{1,0}))) tuple(f32[] %a, (f32[2,4]{1,0}, (f32[2,4]{1,0})) %b) } + +// CHECK-LABEL: func.func private @unbounded(%arg0: tensor) -> tensor { +// CHECK-NEXT: [[VAL0:%.*]] = mhlo.abs %arg0 : tensor +// CHECK-NEXT: return [[VAL0]] : tensor +// CHECK-NEXT: } +%unbounded (Arg_0.1: f32[?,784]) -> f32[?,784] { + %Arg_0.1 = f32[?,784] parameter(0) + ROOT %abs.2 = f32[?,784] abs(f32[?,784] %Arg_0.1) +} diff --git a/xla/translate/mhlo_to_hlo/tests/export.mlir b/xla/translate/mhlo_to_hlo/tests/export.mlir index 7cd4212027046d..9bcb03e1ea747a 100644 --- a/xla/translate/mhlo_to_hlo/tests/export.mlir +++ b/xla/translate/mhlo_to_hlo/tests/export.mlir @@ -3048,3 +3048,17 @@ func.func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<5x5xf32>) -> tensor<1x2x3x func.func @main(%arg0: tensor 
{mhlo.parameter_replication = [true]}, %arg1: tuple, tuple>> {mhlo.parameter_replication = [false, true]}) -> tensor { return %arg0 : tensor } + +// ----- + +func.func @main(%operand: tensor) -> tensor { + %0 = mhlo.abs %operand : tensor + func.return %0 : tensor +} + +// CHECK: HloModule {{.*}}, entry_computation_layout={(f32[?,784]{1,0})->f32[?,784]{1,0}} +// CHECK-EMPTY: +// CHECK-NEXT: ENTRY {{.*}} ([[ARG0:.*]]: f32[?,784]) -> f32[?,784] { +// CHECK-NEXT: [[ARG0]] = f32[?,784] parameter(0) +// CHECK-NEXT: ROOT {{.*}} = f32[?,784] abs(f32[?,784] %Arg_0.1), {{.*}} +// CHECK-NEXT: } diff --git a/xla/translate/mhlo_to_hlo/type_to_shape.cc b/xla/translate/mhlo_to_hlo/type_to_shape.cc index 8ccc406b756828..27fbcb2ad60e85 100644 --- a/xla/translate/mhlo_to_hlo/type_to_shape.cc +++ b/xla/translate/mhlo_to_hlo/type_to_shape.cc @@ -178,12 +178,11 @@ Shape TypeToShape(mlir::Type type) { llvm::SmallVector shape(rank, mlir::ShapedType::kDynamic); std::vector is_dynamic(rank, false); for (int64_t dim = 0; dim < rank; ++dim) { - // Only fully static shapes are supported. - // TODO(b/115638799): Update once xla::Shape can support dynamic shapes. int64_t size = t.getDimSize(dim); if (size == ShapedType::kDynamic) { - if (bounds[dim] == ShapedType::kDynamic) return {}; - shape[dim] = bounds[dim]; + shape[dim] = bounds[dim] != ShapedType::kDynamic + ? bounds[dim] + : Shape::kUnboundedSize; is_dynamic[dim] = true; } else { if (bounds[dim] != ShapedType::kDynamic) return {}; diff --git a/xla/translate/mhlo_to_hlo/type_to_shape_test.cc b/xla/translate/mhlo_to_hlo/type_to_shape_test.cc index 37d82730cb881f..e38dbc355d0426 100644 --- a/xla/translate/mhlo_to_hlo/type_to_shape_test.cc +++ b/xla/translate/mhlo_to_hlo/type_to_shape_test.cc @@ -138,8 +138,15 @@ TEST(TypeToShapeTest, ConvertTensorTypeToTypes) { ShapeUtil::MakeShape(PrimitiveType::F32, {8, 128}, {true, false}) .ToProto())); - // Shape cannot represent dynamic shapes. 
- // TODO(b/115638799): Update once Shape can support dynamic shapes. + EXPECT_THAT( + TypeToShape(RankedTensorType::get({mlir::ShapedType::kDynamic, 784}, + b.getF32Type())) + .ToProto(), + EqualsProto(ShapeUtil::MakeShape(PrimitiveType::F32, + {Shape::kUnboundedSize, 784}, + {true, false}) + .ToProto())); + EXPECT_THAT(TypeToShape(UnrankedTensorType::get(b.getF32Type())).ToProto(), EqualsProto(Shape().ToProto()));