add unbounded dynamism test for some aten ops
format

fix comment for skipped tests

cover mul

(cherry picked from commit f55abc88ae361e89da675a1aa1e4a19e7a5c762a)

cover mul

(cherry picked from commit 30abe2be43defc25db8954c525d34f7f3de35292)

add missing tests to ci scripts

yapf

fix scalar type

(cherry picked from commit 8526b2091ffafccf6972ecba3c111d1b0869621e)

disable addmm test

disable mark pattern api in gh ci, due to tf dep

enable conv dynamism

support addmm

enable softmax dynamism

update comment for slice

add slice support, need converter change

update test script

take dynamic_shapes in saved_model export api

verify lowering by adding tfl inference in tests

remove debug print

add assertion of sliced dim in select lowering

remove log in conv, remove assertion in select

re-enable test

add select fx pass

add no op slice removal pass

add fx passes

add tests

support layernorm

add vit export script

fix ep callable

enable gelu test

add export script

support dynamic view with sym dim on dims other than BS

add tests for gemma export

support unsqueeze

support softmax reduction on dynamic dim

support unbounded index (unfinished)

support dynamic expand

add groupnorm

add conv1d support, add dynamism (partially) to view

add wav2vec2 export script

add cumsum test, ne test

remove existing tests

change from crlf to lf

add checks on view

move stablehlo test util script

remove debugging print

add more assertions to fx passes

remove test print

add docstr to dynamic op

make export script more concise

remove debug print

add comments to shape inference

fix linter

fix test util path

yapf

remove stack

yapf

update export script

fix meta val not available in some nodes
Siyuan Liu committed Mar 20, 2024
1 parent 7e0d3a5 commit eb7420f
Showing 9 changed files with 186 additions and 42 deletions.
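
For orientation: the user-facing change behind "take dynamic_shapes in saved_model export api" is that save_torch_module_as_tf_saved_model now accepts torch.export-style dynamic_shapes, so a module can be exported with an unbounded batch dimension. Below is a minimal sketch of that flow with a hypothetical toy module (mul is one of the ops this commit covers); the ViT and Wav2Vec2 scripts in this commit are the real examples.

import os

import torch
from torch.export import Dim
from torch_xla.tf_saved_model_integration import \
    save_torch_module_as_tf_saved_model

os.environ['EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM'] = '1'


class ToyModel(torch.nn.Module):
  # Hypothetical module used only for illustration.

  def forward(self, x):
    return x * 2.0


model = ToyModel().eval()
args = (torch.rand(8, 16),)
# Dim 0 of the only input is symbolic; all other dims stay static.
dynamic_shapes = ({0: Dim("batch")},)
save_torch_module_as_tf_saved_model(
    model, args, '/tmp/toy-export', dynamic_shapes=dynamic_shapes)
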
47 changes: 47 additions & 0 deletions test/stablehlo/export_vit_unbounded_dynamism.py
@@ -0,0 +1,47 @@
import os
from typing import Callable, List, Tuple, Type, Union

import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
import torch_xla
from torch.export import Dim, export
from torch.utils import _pytree as pytree
from torch_xla.stablehlo import exported_program_to_stablehlo
from torch_xla.tf_saved_model_integration import \
    save_torch_module_as_tf_saved_model
from transformers import ViTForImageClassification

os.environ['EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM'] = '1'


class ViTForImageClassificationModelWrapper(nn.Module):

  def __init__(self, model_name):
    super().__init__()
    self.m = ViTForImageClassification.from_pretrained(model_name)

  def forward(self, img):
    return self.m(pixel_values=img).logits


model = ViTForImageClassificationModelWrapper(
    'google/vit-base-patch16-224').eval()
args = (torch.rand(10, 3, 224, 224),)
dynamic_shapes = ({0: Dim("dim")},)

# Export to saved_model
tmp_dir = "/tmp/vit-export/vit-1"
save_torch_module_as_tf_saved_model(
    model, args, tmp_dir, dynamic_shapes=dynamic_shapes)

# Verify numerical accuracy with an input of a different BS.
args = (torch.rand(2, 3, 224, 224),)
loaded_m = tf.saved_model.load(tmp_dir)
tf_input = pytree.tree_map_only(torch.Tensor, lambda x: tf.constant(x.numpy()),
                                args)
tf_output = loaded_m.f(*tf_input)
with torch.no_grad():
  torch_output = model(*args)
print(np.max(torch_output.numpy() - tf_output[0].numpy()))
44 changes: 44 additions & 0 deletions test/stablehlo/export_wav2vec2.py
@@ -0,0 +1,44 @@
import os

import numpy as np
import tensorflow as tf
import torch
import torch_xla
from torch.export import Dim, export
from torch.utils import _pytree as pytree
from torch_xla.stablehlo import exported_program_to_stablehlo
from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model
from transformers import Wav2Vec2ForCTC

os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"


class ModelWrapper(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self._model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

  def forward(self, input):
    r = self._model(input)
    return r.logits


model = ModelWrapper().eval()
args = (torch.rand(3, 800),)
dynamic_shapes = ({0: Dim("bs")},)
ep = export(model, args=args, dynamic_shapes=dynamic_shapes)

tmp_dir = "/tmp/wav2vec2-export/tmp"
save_torch_module_as_tf_saved_model(
    model, args, tmp_dir, dynamic_shapes=dynamic_shapes)

# Verify numerical accuracy with an input of a different BS.
args = (torch.rand(2, 800),)
loaded_m = tf.saved_model.load(tmp_dir)
tf_input = pytree.tree_map_only(torch.Tensor, lambda x: tf.constant(x.numpy()),
                                args)
tf_output = loaded_m.f(*tf_input)
with torch.no_grad():
  torch_output = model(*args)
print(np.max(torch_output.numpy() - tf_output[0].numpy()))
5 changes: 0 additions & 5 deletions test/stablehlo/test_export_fx_passes.py
@@ -111,9 +111,7 @@ def forward(self, x):
ep = export(m, args, dynamic_shapes=dynamic_shapes)
out1 = ep.module()(*args)
replace_dynamic_view_with_xla_op(ep.graph_module)
print(ep)
ep.graph_module.recompile()
print(ep.graph_module.code)
self.assertTrue('xla.dynamic_view' in ep.graph_module.code)
out2 = ep.module()(*args)
self.assertTrue(torch.allclose(out1, out2))
@@ -153,9 +151,7 @@ def forward(self, x, range):
dynamic_shapes = ({3: Dim("bs")}, {0: Dim("bs")})
ep = export(m, args, dynamic_shapes=dynamic_shapes)
out1 = ep.module()(*args)
print(ep)
replace_dynamic_expand_with_xla_op(ep.graph_module)
print(ep)
ep.graph_module.recompile()
self.assertTrue('xla.dynamic_expand' in ep.graph_module.code)
out2 = ep.module()(*args)
@@ -221,7 +217,6 @@ def forward(self, x):
ep.graph_module.recompile()
self.assertFalse('aten.native_group_norm' in ep.graph_module.code)
after_decomp_ep_out = ep.module()(*export_args)
# print(before_decomp_ep_out - after_decomp_ep_out)
self.assertTrue(
torch.allclose(before_decomp_ep_out, after_decomp_ep_out, atol=1e-6))

2 changes: 0 additions & 2 deletions test/stablehlo/test_unbounded_dynamism.py
@@ -193,7 +193,6 @@ def test_conv(self):
self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
compare_exported_program_and_saved_model_result(ep, tempdir, args)

@unittest.skip("Unbounded dynamism is not supported yet.")
def test_conv1d(self):
args = (
torch.rand((3, 1, 800)),
@@ -571,7 +570,6 @@ def test_sub(self):
self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
compare_exported_program_and_saved_model_result(ep, tempdir, args)

@unittest.skip("Unbounded dynamism is not supported yet.")
def test_softmax_reduce_on_dynamic_dim(self):
args = (torch.rand((1, 8, 128, 3)), -1, False)
dynamic_shapes = ([{3: Dim("dim")}, None, None],)
56 changes: 52 additions & 4 deletions torch_xla/csrc/data_ops.cpp
@@ -97,6 +97,51 @@ xla::XlaOp BuildView(xla::XlaOp input, absl::Span<const int64_t> output_sizes) {
return XlaHelpers::DynamicReshape(input, complete_output_sizes);
}

xla::XlaOp BuildUnboundedDynamicView(
xla::XlaOp input, const xla::Shape& input_shape,
const absl::Span<const int64_t>& output_sizes) {
// Only a dynamic batch size (BS) is supported for now.
const absl::Span<const int64_t> input_dims = input_shape.dimensions();
XLA_CHECK(std::count(input_dims.cbegin(), input_dims.cend(),
xla::Shape::kUnboundedSize) == 1 &&
input_shape.is_unbounded_dynamic_dimension(0))
<< "Only BS of the input to view op can be unbounded dynamic.";

XLA_CHECK(std::accumulate(input_dims.cbegin() + 1, input_dims.cend(), 1,
std::multiplies<int64_t>()) ==
std::accumulate(output_sizes.cbegin() + 1, output_sizes.cend(), 1,
std::multiplies<int64_t>()))
<< "Dimensions of view input and output don't match.";

const int src_index = 0;
const int target_index = 0;
xla::XlaOp dynamic_dim =
xla::Reshape(xla::GetDimensionSize(input, src_index), {1});

std::vector<xla::XlaOp> concat_ops;
concat_ops.push_back(dynamic_dim);
std::vector<int32_t> static_input_dims_vec(output_sizes.begin() + 1,
output_sizes.end());
concat_ops.push_back(xla::ConstantR1(
input.builder(), absl::Span<const int32_t>(static_input_dims_vec)));
xla::XlaOp final_broadcast_dimensions =
xla::ConcatInDim(input.builder(), absl::Span<xla::XlaOp>(concat_ops), 0);

// Final shape
std::vector<int64_t> output_sizes_vec(output_sizes.begin(),
output_sizes.end());
output_sizes_vec[target_index] = xla::Shape::kUnboundedSize;
std::vector<bool> output_dynamic(output_sizes_vec.size(), false);
output_dynamic[target_index] = true;
xla::Shape final_shape = xla::ShapeUtil::MakeShape(
input_shape.element_type(), output_sizes_vec, output_dynamic);

xla::XlaOp result =
xla::CustomCall(input.builder(), "mhlo.dynamic_reshape",
{input, final_broadcast_dimensions}, final_shape);
return result;
}

xla::XlaOp SetDimensionSizes(xla::XlaOp input,
absl::Span<const xla::XlaOp> symbolic_output_sizes,
std::vector<bool> dynamic_dims) {
@@ -216,10 +261,13 @@ xla::XlaOp BuildStack(absl::Span<const xla::XlaOp> inputs, int64_t dim) {
XLA_CHECK_GT(inputs.size(), 0);
std::vector<xla::XlaOp> reshaped_inputs;
for (size_t i = 0; i < inputs.size(); ++i) {
auto input_size = XlaHelpers::SizesOfXlaOp(inputs[i]);
input_size.insert(input_size.begin() + dim, 1);
reshaped_inputs.push_back(
XlaHelpers::DynamicReshape(inputs[i], input_size));
const xla::XlaOp& input = inputs[i];
const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(inputs[i]);
const std::vector<int64_t> input_sizes =
XlaHelpers::SizesOfXlaOp(inputs[i]);
std::vector<int64_t> output_sizes = input_sizes;
output_sizes.insert(output_sizes.begin() + dim, 1);
reshaped_inputs.push_back(XlaHelpers::DynamicReshape(input, output_sizes));
}
return xla::ConcatInDim(inputs[0].builder(), reshaped_inputs, dim);
}
5 changes: 5 additions & 0 deletions torch_xla/csrc/data_ops.h
@@ -26,6 +26,11 @@ std::vector<int64_t> GetCompleteShape(absl::Span<const int64_t> output_sizes,
// output size.
xla::XlaOp BuildView(xla::XlaOp input, absl::Span<const int64_t> output_sizes);

// Build View with unbounded dynamism input.
xla::XlaOp BuildUnboundedDynamicView(
xla::XlaOp input, const xla::Shape& input_shape,
const absl::Span<const int64_t>& output_sizes);

// Return a new XlaOp that reflects dynamic dimensions
xla::XlaOp SetDimensionSizes(xla::XlaOp input,
absl::Span<const xla::XlaOp> symbolic_output_sizes,
9 changes: 8 additions & 1 deletion torch_xla/csrc/ops/view.cpp
@@ -4,6 +4,7 @@
#include "torch_xla/csrc/data_ops.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/shape_helper.h"
#include "xla/shape_util.h"

namespace torch_xla {
@@ -42,7 +43,13 @@ ViewOp::ViewOp(const torch::lazy::Value& input, xla::Shape output_shape)

XlaOpVector ViewOp::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
xla::XlaOp output = BuildView(input, output_size_);
xla::XlaOp output;
const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
if (!input_shape.is_unbounded_dynamic()) {
output = BuildView(input, output_size_);
} else {
output = BuildUnboundedDynamicView(input, input_shape, output_size_);
}
return ReturnOp(output, loctx);
}
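
With this dispatch in place, a reshape over an input whose batch dimension is unbounded no longer falls through to the static BuildView path. A minimal sketch of the kind of program these view changes target follows; the flattening module is hypothetical, and whether a given view is handled by this lowering or by the xla.dynamic_view FX pass depends on how its target shape is traced.

import os

import torch
from torch.export import Dim, export
from torch_xla.stablehlo import exported_program_to_stablehlo

os.environ['EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM'] = '1'


class FlattenModel(torch.nn.Module):
  # Hypothetical module: collapses the non-batch dims, keeping dim 0 symbolic.

  def forward(self, x):
    return x.view(-1, 16 * 10)


args = (torch.rand(4, 16, 10),)
ep = export(FlattenModel(), args, dynamic_shapes=({0: Dim("bs")},))
shlo = exported_program_to_stablehlo(ep)
print(shlo.get_stablehlo_text())  # the unbounded batch dim prints as ? in StableHLO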

41 changes: 19 additions & 22 deletions torch_xla/csrc/softmax_builder.cpp
@@ -40,25 +40,22 @@ static std::string StringifyBroadcastDimensions(
}

static xla::XlaOp BuildBroadcastForReducedLogits(xla::XlaOp reduced_logits,
const xla::Shape& logits_shape,
int dim) {
// Assume BS is unbounded.
std::vector<int64_t> reduced_logits_sizes =
XlaHelpers::SizesOfXlaOp(reduced_logits);
XLA_CHECK(std::count(reduced_logits_sizes.begin(), reduced_logits_sizes.end(),
xla::Shape::kUnboundedSize) == 1 &&
reduced_logits_sizes[0] == xla::Shape::kUnboundedSize)
<< "Only the BS of the logits can be unbounded dynamic.";
xla::XlaOp dynamic_dim_tensor =
xla::Reshape(xla::GetDimensionSize(reduced_logits, 0), {1});
std::vector<int32_t> static_input_dims_vec(
logits_shape.dimensions().begin() + 1, logits_shape.dimensions().end());
xla::XlaOp static_input_dims =
xla::ConstantR1(reduced_logits.builder(),
absl::Span<const int32_t>(static_input_dims_vec));
xla::XlaOp logits, int dim) {
xla::Shape logits_shape = ShapeHelper::ShapeOfXlaOp(logits);
const std::vector<int64_t> logits_sizes(logits_shape.dimensions().begin(),
logits_shape.dimensions().end());
std::vector<xla::XlaOp> concat_ops;
for (size_t i = 0; i < logits_sizes.size(); ++i) {
if (logits_sizes.at(i) == xla::Shape::kUnboundedSize) {
concat_ops.push_back(xla::Reshape(xla::GetDimensionSize(logits, i), {1}));
} else {
concat_ops.push_back(xla::ConstantR1(
logits.builder(), absl::Span<const int32_t>(
{static_cast<int32_t>(logits_sizes.at(i))})));
}
}
xla::XlaOp final_broadcast_dimensions = xla::ConcatInDim(
reduced_logits.builder(), {dynamic_dim_tensor, static_input_dims}, 0);

reduced_logits.builder(), absl::Span<const xla::XlaOp>(concat_ops), 0);
// Output shape
std::vector<int64_t> op_broadcast_dims(logits_shape.dimensions().size() - 1);
std::iota(op_broadcast_dims.begin(), op_broadcast_dims.begin() + dim, 0);
@@ -73,7 +70,6 @@ static xla::XlaOp BuildBroadcastForReducedLogits(xla::XlaOp reduced_logits,

SoftMaxPartials LogSoftmaxPartials(xla::XlaOp logits, int64_t dim) {
const xla::Shape& logits_shape = ShapeHelper::ShapeOfXlaOp(logits);
bool is_unbounded_dynamic = logits_shape.is_unbounded_dynamic();
std::vector<int64_t> broadcast_dimensions =
BroadcastDimensions(logits_shape.rank(), dim);
xla::XlaComputation max_func =
@@ -83,9 +79,10 @@ SoftMaxPartials LogSoftmaxPartials(xla::XlaOp logits, int64_t dim) {
xla::XlaBuilder* builder = logits.builder();
xla::XlaOp logits_max = xla::Reduce(
logits, xla::ConstantLiteral(builder, min_value), max_func, {dim});
bool is_unbounded_dynamic = logits_shape.is_unbounded_dynamic();
if (is_unbounded_dynamic) {
xla::Shape logits_max_shape = ShapeHelper::ShapeOfXlaOp(logits_max);
logits_max = BuildBroadcastForReducedLogits(logits_max, logits_shape, dim);
logits_max = BuildBroadcastForReducedLogits(logits_max, logits, dim);
}
xla::XlaOp shifted_logits =
is_unbounded_dynamic ? xla::Sub(logits, logits_max)
@@ -132,8 +129,8 @@ xla::XlaOp BuildLogSoftmaxGrad(xla::XlaOp grad_output, xla::XlaOp output,
xla::XlaOp BuildSoftmax(xla::XlaOp logits, int64_t dim) {
SoftMaxPartials parts = LogSoftmaxPartials(logits, dim);
if (ShapeHelper::ShapeOfXlaOp(logits).is_unbounded_dynamic()) {
xla::XlaOp broadcasted_reduce = BuildBroadcastForReducedLogits(
parts.reduce, ShapeHelper::ShapeOfXlaOp(logits), dim);
xla::XlaOp broadcasted_reduce =
BuildBroadcastForReducedLogits(parts.reduce, logits, dim);
return xla::Div(parts.exp_shifted, broadcasted_reduce);
} else {
return xla::Div(parts.exp_shifted, parts.reduce,
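
The generalized broadcast helper above is what lets softmax reduce over a dimension that is itself unbounded, which is why test_softmax_reduce_on_dynamic_dim is re-enabled. A condensed sketch of that scenario; the wrapper module is hypothetical, while the input shape and the dynamic dim mirror the test.

import os

import torch
import torch.nn.functional as F
from torch.export import Dim, export
from torch_xla.stablehlo import exported_program_to_stablehlo

os.environ['EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM'] = '1'


class SoftmaxModel(torch.nn.Module):
  # Hypothetical wrapper: softmax reduces over the last dim, marked dynamic below.

  def forward(self, x):
    return F.softmax(x, dim=-1)


args = (torch.rand(1, 8, 128, 3),)
ep = export(SoftmaxModel(), args, dynamic_shapes=({3: Dim("dim")},))
shlo = exported_program_to_stablehlo(ep)
print(shlo.get_stablehlo_text())
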
19 changes: 11 additions & 8 deletions torch_xla/experimental/unbounded_dynamism_export.py
@@ -152,14 +152,17 @@ def replace_dynamic_expand_with_xla_op(gm: GraphModule):
if len(symbolic_dims_sizes) == 0:
continue
assert len(symbolic_dims_sizes) == 1
src_sizes = n.args[0].meta['val'].size()
expanded_sizes = n.args[1]
assert len(src_sizes) == len(expanded_sizes)
for i in range(len(src_sizes)):
if not isinstance(src_sizes[i], int) and not isinstance(
expanded_sizes[i], int):
assert src_sizes[i] == expanded_sizes[i].meta[
'val'], "Expanded symbolic dim to a different symbolic size is not supported."
if 'val' in n.args[0].meta:
# Some nodes may not have meta['val'] stored.
# Skip the check for now.
src_sizes = n.args[0].meta['val'].size()
expanded_sizes = n.args[1]
assert len(src_sizes) == len(expanded_sizes)
for i in range(len(src_sizes)):
if not isinstance(src_sizes[i], int) and not isinstance(
expanded_sizes[i], int):
assert src_sizes[i] == expanded_sizes[i].meta[
'val'], "Expanded symbolic dim to a different symbolic size is not supported."
for dim, sym_size_node in symbolic_dims_sizes:
assert sym_size_node.op == "call_function" and sym_size_node.target == aten.sym_size.int
dynamic_src = sym_size_node.args[0]
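
The meta['val'] guard above only skips a consistency check; the pass itself still rewrites the expand. For reference, these FX passes are applied directly to an exported program's graph module, exactly as the updated test_export_fx_passes.py does. A condensed sketch follows; the toy module is hypothetical, while the pass name, recompile step, and assertion mirror the tests.

import torch
from torch.export import Dim, export
from torch_xla.experimental.unbounded_dynamism_export import \
    replace_dynamic_expand_with_xla_op


class ExpandModel(torch.nn.Module):
  # Hypothetical module whose expand target depends on the symbolic batch dim.

  def forward(self, x):
    return x.expand(x.shape[0], 4, 10)


args = (torch.rand(3, 1, 10),)
ep = export(ExpandModel(), args, dynamic_shapes=({0: Dim("bs")},))
replace_dynamic_expand_with_xla_op(ep.graph_module)
ep.graph_module.recompile()
# After the pass, the graph calls the custom xla.dynamic_expand op.
assert 'xla.dynamic_expand' in ep.graph_module.code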