add dynamo config skip_input_data_check (pytorch#7913)
JackCaoG authored and yitongh committed Oct 11, 2024
1 parent b3e2e11 commit 86f8b0c
Showing 6 changed files with 56 additions and 12 deletions.
27 changes: 27 additions & 0 deletions test/dynamo/test_dynamo_config.py
@@ -0,0 +1,27 @@
+import torch
+import torch_xla
+import unittest
+from torch_xla._dynamo import config
+
+
+class DynamoconfigTest(unittest.TestCase):
+
+  def dummy_test(self, a):
+    return a.cos().sin()
+
+  def test_config_skip_input_data_check(self):
+    device = torch_xla.device()
+    print(config.skip_input_data_check)
+    config.skip_input_data_check = True
+    compiled_dummy = torch.compile(self.dummy_test, backend="openxla")
+    t1 = torch.randn(3, 4, device=device)
+    compiled_dummy(t1)
+    t2 = torch.randn(3, 4, device=device)
+    t2 += 5
+    with self.assertRaisesRegex(
+        RuntimeError, r'input data to dynamo graph can not be a pending ir'):
+      compiled_dummy(t2)
+
+
+if __name__ == '__main__':
+  test = unittest.main()
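The test exercises the new flag end to end: `t2 += 5` is an in-place update on an XLA tensor, so afterwards `t2` is backed by a pending IR node rather than materialized device data, and because `skip_input_data_check` is set to True the bridge no longer syncs such inputs, letting the new C++ guard raise. Below is a minimal sketch of the same scenario outside unittest, assuming that explicitly materializing the input with `xm.mark_step()` is enough to satisfy the check:

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    from torch_xla._dynamo import config

    config.skip_input_data_check = True
    compiled = torch.compile(lambda a: a.cos().sin(), backend="openxla")

    t = torch.randn(3, 4, device=torch_xla.device())
    t += 5           # in-place op leaves t as a pending IR node
    xm.mark_step()   # materialize the pending IR into device data
    compiled(t)      # should now pass the input data check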
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -192,6 +192,7 @@ function run_xla_op_tests1 {
   run_test "$CDIR/dynamo/test_bridge.py"
   run_test "$CDIR/dynamo/test_num_output.py"
   run_test "$CDIR/dynamo/test_graph_input_matcher.py"
+  run_test "$CDIR/dynamo/test_dynamo_config.py"
   run_save_tensor_ir "$CDIR/dynamo/test_dynamo_graph_dump.py"
   run_use_bf16 "$CDIR/test_data_type.py"
   run_xla_ir_debug "$CDIR/test_env_var_mapper.py"
1 change: 1 addition & 0 deletions torch_xla/_dynamo/__init__.py
@@ -0,0 +1 @@
+import torch_xla._dynamo.config as config
6 changes: 6 additions & 0 deletions torch_xla/_dynamo/config.py
@@ -0,0 +1,6 @@
+import torch_xla
+
+# Whether optimized_mod should skip checking that each input is device data.
+# Enabling it reduces the overhead a bit, but a runtime error is thrown if an
+# input is a pending IR.
+skip_input_data_check = False
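Because the new `torch_xla/_dynamo/__init__.py` re-exports this module, the flag is reachable both through the dotted path quoted in the C++ error message and via a direct import, as the test does. A small sketch using only names introduced by this commit:

    import torch_xla._dynamo
    torch_xla._dynamo.config.skip_input_data_check = True   # path used in the error message

    # equivalent, matching test_dynamo_config.py
    from torch_xla._dynamo import config
    config.skip_input_data_check = True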
26 changes: 14 additions & 12 deletions torch_xla/_dynamo/dynamo_bridge.py
@@ -20,6 +20,7 @@
 from torch._inductor.fx_passes.post_grad import ConstructorMoverPass
 
 from torch.utils import _pytree as pytree
+from torch_xla._dynamo import config
 
 import torch_xla
 import torch_xla.core.xla_model as xm
@@ -558,18 +559,19 @@ def optimized_mod(*args: tuple):
     if is_cuda_args:
       args = _maybe_move_tensors_to_device(args, xm.xla_device())
 
-    # mark_step needs to be blocking since we want to access args's XLADatas
-    # and they can't be placeholder.
-    input_tensors_to_sync = [
-        xla_args_tensor_only[i] for i, x in enumerate(
-            torch_xla._XLAC._check_tensor_need_materialization(
-                xla_args_tensor_only)) if x
-    ]
-
-    if len(input_tensors_to_sync) > 0:
-      torch_xla._XLAC._xla_increment_counter('DynamoSyncInputExecuteTime', 1)
-      torch_xla._XLAC._xla_sync_multi(
-          input_tensors_to_sync, devices=[], wait=True, sync_xla_data=True)
+    if not config.skip_input_data_check:
+      # mark_step needs to be blocking since we want to access args's XLADatas
+      # and they can't be placeholder.
+      input_tensors_to_sync = [
+          xla_args_tensor_only[i] for i, x in enumerate(
+              torch_xla._XLAC._check_tensor_need_materialization(
+                  xla_args_tensor_only)) if x
+      ]
+
+      if len(input_tensors_to_sync) > 0:
+        torch_xla._XLAC._xla_increment_counter('DynamoSyncInputExecuteTime', 1)
+        torch_xla._XLAC._xla_sync_multi(
+            input_tensors_to_sync, devices=[], wait=True, sync_xla_data=True)
 
     # If input sharding has changed from the previous program, dynamo current can
     # not detect this. It will mistakenly believe the program is the same. We need
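This hunk gates the existing blocking input sync behind the new flag: with the check enabled (the default), any pending-IR inputs are materialized via `_xla_sync_multi` and the `DynamoSyncInputExecuteTime` counter is bumped; with the flag set, that whole branch is skipped. One way to observe which path was taken from user code, assuming the usual `torch_xla.debug.metrics` counter API:

    import torch_xla.debug.metrics as met

    # With skip_input_data_check left at False, feeding a pending-IR input to a
    # compiled function should increment this counter instead of raising.
    print(met.counter_value('DynamoSyncInputExecuteTime'))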
7 changes: 7 additions & 0 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -765,6 +765,13 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
   for (auto& ivalue : graph_inputs) {
     torch::lazy::BackendDataPtr dataptr;
     if (auto xla_tensor_ptr = bridge::TryGetXlaTensor(ivalue.toTensor())) {
+      bool is_non_data_ir =
+          xla_tensor_ptr->CurrentIrValue().node != nullptr &&
+          (torch_xla::DeviceData::Cast(
+               xla_tensor_ptr->CurrentIrValue().node.get()) == nullptr);
+      XLA_CHECK(!is_non_data_ir)
+          << "input data to dynamo graph can not be a pending ir, please set "
+             "`torch_xla._dynamo.config.skip_input_data_check` to False";
       dataptr = xla_tensor_ptr->GetXlaData();
     } else {
       XLA_CHECK(device.type() != (int8_t)XlaDeviceType::SPMD)
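The guard fires only when the input tensor currently holds an IR value that is not a `DeviceData` node, which is exactly the pending-IR case the Python-side sync used to cover. Since the check runs before execution, a caller who opted into the skip can presumably recover by materializing the offending input and retrying; a self-contained sketch of that pattern (the error text is taken from this diff, the retry itself is an assumption):

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    from torch_xla._dynamo import config

    config.skip_input_data_check = True
    compiled = torch.compile(lambda a: a.cos().sin(), backend="openxla")
    t2 = torch.randn(3, 4, device=torch_xla.device())
    t2 += 5  # t2 now holds a pending IR node

    try:
      out = compiled(t2)
    except RuntimeError as e:
      if 'pending ir' in str(e):
        xm.mark_step()       # flush the pending IR into device data
        out = compiled(t2)   # retry with a materialized input
      else:
        raise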
