Skip to content

Commit

Permalink
Add patch for unbounded dynamism (#6783)
Browse files Browse the repository at this point in the history
Co-authored-by: Siyuan Liu <lsiyuan@google.com>
  • Loading branch information
ghpvnist and lsy323 authored Mar 21, 2024
1 parent 3eeb15d commit 8d579f9
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 2 deletions.
1 change: 1 addition & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ http_archive(
"//openxla_patches:cache_urls.diff",
"//openxla_patches:gpu_race_condition.diff",
"//openxla_patches:f16_abi_clang.diff",
"//openxla_patches:unbounded_dynamism.diff",
],
strip_prefix = "xla-25c8a6781af6be51d3bc43a0953b07803ab761ea",
urls = [
Expand Down
40 changes: 40 additions & 0 deletions openxla_patches/unbounded_dynamism.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
diff --git a/xla/client/xla_builder.cc b/xla/client/xla_builder.cc
index 57f0529b5..5f8a1c582 100644
--- a/xla/client/xla_builder.cc
+++ b/xla/client/xla_builder.cc
@@ -1182,12 +1182,16 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
this, rhs, lhs, *lhs_shape));
}
} else {
- TF_ASSIGN_OR_RETURN(UnboundedBroadcastResult broadcast_result,
- BroadcastToOutputShapeWithUnbounded(
- this, lhs, *lhs_shape, rhs, *rhs_shape, shape,
- broadcast_dimensions));
- updated_lhs = broadcast_result.lhs;
- updated_rhs = broadcast_result.rhs;
+ if (!ShapeUtil::SameDimensions(*lhs_shape, *rhs_shape)) {
+ Shape output_shape = shape;
+ output_shape.set_element_type(lhs_shape->element_type());
+ TF_ASSIGN_OR_RETURN(UnboundedBroadcastResult broadcast_result,
+ BroadcastToOutputShapeWithUnbounded(
+ this, lhs, *lhs_shape, rhs, *rhs_shape,
+ output_shape, broadcast_dimensions));
+ updated_lhs = broadcast_result.lhs;
+ updated_rhs = broadcast_result.rhs;
+ }
}
}

diff --git a/xla/client/xla_builder_test.cc b/xla/client/xla_builder_test.cc
index ee9d100a3..231cd6baf 100644
--- a/xla/client/xla_builder_test.cc
+++ b/xla/client/xla_builder_test.cc
@@ -2436,6 +2436,8 @@ INSTANTIATE_TEST_SUITE_P(
/*broadcast_dimensions=*/{}, "f32[?, ?, 2, 2, <=2, <=2, ?]", &Mul},
{"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array,
"f32[?, 10]", &Mul},
+ {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array,
+ "pred[?, 10]", &Ne},
{"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]",
/*broadcast_dimensions=*/{}, "f32[?, ?, 2, 2, <=2, <=2, ?]", &Pow},
{"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array,
3 changes: 1 addition & 2 deletions test/stablehlo/test_unbounded_dynamism.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,6 @@ def test_mul_scalar(self):
self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
compare_exported_program_and_saved_model_result(ep, tempdir, args)

@unittest.skip("Unbounded dynamism is not supported.")
def test_ne_scalar(self):

class M(torch.nn.Module):
Expand All @@ -392,7 +391,7 @@ def forward(self, x):
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()
self.assertTrue(
re.search(r"%arg.: tensor<\?x5xf32>.*->.*tensor<\?x5xi32>", shlo_text)
re.search(r"%arg.: tensor<\?x5xi64>.*->.*tensor<\?x5xi32>", shlo_text)
is not None)
if has_tf_package():
with tempfile.TemporaryDirectory() as tempdir:
Expand Down

0 comments on commit 8d579f9

Please sign in to comment.