Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cherry-pick Move where clear pending IR is called to avoid crash (#5552) #5582

Merged
merged 1 commit into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions test/dynamo/test_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,21 +273,21 @@ def fn_fallback(t):
xla_dynamo_res = dynamo_fn(t_xla)
self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res.cpu()))
self.assertEqual(met.metric_data('CompileTime')[0], 3)
self.assertEqual(met.metric_data('ExecuteTime')[0], 10)
self.assertEqual(met.metric_data('ExecuteTime')[0], 11)

# Second tracing
met.clear_counters()
xla_dynamo_res_2 = dynamo_fn(t_xla)
self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res_2.cpu()))
self.assertEqual(met.metric_data('CompileTime')[0], 3)
self.assertEqual(met.metric_data('ExecuteTime')[0], 12)
self.assertEqual(met.metric_data('ExecuteTime')[0], 13)

# Verify that dynamo can handle different inputs
xla_dynamo_res_3 = dynamo_fn(t_xla * 3)
cpu_res_3 = fn_fallback(t * 3)
self.assertTrue(torch.allclose(cpu_res_3, xla_dynamo_res_3.cpu()))
self.assertEqual(met.metric_data('CompileTime')[0], 4)
self.assertEqual(met.metric_data('ExecuteTime')[0], 15)
self.assertEqual(met.metric_data('ExecuteTime')[0], 16)


class DynamoTrainingBasicTest(unittest.TestCase):
Expand Down Expand Up @@ -539,9 +539,10 @@ def test_all_cpu_tensor(self):
# there should be 18 paramters + 1 input
self.assertGreater(len(w), 15)
self.assertIn('Found tensor with shape torch.Size', str(w[0].message))
# no XLA operation should happens. Partitioner should offload all CPU
# no XLA operation should happens except a empty mark_step. Partitioner should offload all CPU
# ops to CPU.
self.assertEqual(len(met.counter_names()), 0)
self.assertEqual(len(met.counter_names()), 1)
self.assertIn('MarkStep', met.counter_names())


if __name__ == '__main__':
Expand Down
38 changes: 22 additions & 16 deletions torch_xla/core/dynamo_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch_xla.runtime as xr
import torch_xla.utils.utils as xu

debug = os.environ.get("TORCH_XLA_DEBUG") == "1"
debug = os.environ.get("XLA_DYNAMO_DEBUG") == "1"


@dataclasses.dataclass
Expand Down Expand Up @@ -322,6 +322,10 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):


def extract_internal(xla_model: torch.fx.GraphModule):
if debug:
for xla_arg in xla_model.xla_args:
print(torch_xla._XLAC._get_xla_tensor_debug_info(xla_arg))
xm.mark_step()
(xla_args_sharding_spec, args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
dumb_return_handler, xla_args_need_update) = extract_graph_helper(xla_model)
Expand Down Expand Up @@ -471,6 +475,23 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
collector = FallBackNodeCollector(xla_model)
collector.run(*xla_args)
fallback_ops = collector.get_fallback_ops()
if debug and len(fallback_ops) > 0:
print('fallback ops are' + str(fallback_ops))

# This logic, needed for supporting in-place operations, is a duplicate of
# the one in the main `extract_internal` function above. We need to do this
# check for fetching fallback ops as well.
# TODO (@wonjoo): Make this duplicate code a bit cleaner.
args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
all_xla_args)

# Again, same logic in the `extract_internal` above to support in-place operations.
# TODO (@wonjoo): Make this duplicate code a bit cleaner.
for i, need_update in enumerate(args_need_update_bool):
if need_update and isinstance(all_xla_args[i], torch.Tensor):
all_xla_args[i].copy_(cloned_args[i])

torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))

class XlaOperatorSupport(torch.fx.passes.operator_support.OperatorSupport):

Expand All @@ -493,21 +514,6 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
partitioned_graph = partitioner.fuse_partitions(partitions)
InputCollector(partitioned_graph).run(*xla_args)

# This logic, needed for supporting in-place operations, is a duplicate of
# the one in the main `extract_internal` function above. We need to do this
# check for fetching fallback ops as well.
# TODO (@wonjoo): Make this duplicate code a bit cleaner.
args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
all_xla_args)

# Again, same logic in the `extract_internal` above to support in-place operations.
# TODO (@wonjoo): Make this duplicate code a bit cleaner.
for i, need_update in enumerate(args_need_update_bool):
if need_update and isinstance(all_xla_args[i], torch.Tensor):
all_xla_args[i].copy_(cloned_args[i])

torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))

# compile each submodule and replace it with a call
for node in partitioned_graph.graph.nodes:
if node.op == "call_module" and "fused_" in node.name:
Expand Down
3 changes: 3 additions & 0 deletions torch_xla/csrc/xla_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,9 @@ void XLAGraphExecutor::ClearPendingIrs(
runtime::GetComputationClient()->CreateDataPlaceholder(
device.toString(), std::move(shape)));
tensors[i]->data()->handle = handle;
TF_VLOG(4) << "Replacing the IR " << ir_value.node.get()->ToString()
<< " of Tensor with ID " << tensors[i]->GetUniqueId()
<< " with placeholder";
}
tensors[i]->AssignIrValue(torch::lazy::Value());
tensors[i]->data()->view = nullptr;
Expand Down