
Commit

Optimize dynamo dynamic shape caching (pytorch#7726)
JackCaoG authored and yitongh committed Oct 11, 2024
1 parent 3704c3b commit a7411b2
Showing 2 changed files with 45 additions and 19 deletions.
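
At a high level, this commit keys the result of Dynamo graph extraction on the input shapes, so a model compiled with dynamic=True reuses the already-extracted graph when it sees a shape again instead of re-tracing the FX graph. A minimal, self-contained sketch of the idea, using hypothetical names rather than the torch_xla internals (the real cache stores the tuple returned by extract_graph_helper):

# Sketch only: stand-ins for the real trace/compile machinery.
from typing import Any, Dict, Tuple

_shape_cache: Dict[Tuple[Tuple[int, ...], ...], Any] = {}
_trace_count = 0


def _trace(shapes):
  # Stand-in for the expensive FX trace + XLA compile step.
  global _trace_count
  _trace_count += 1
  return f"compiled-graph-for-{shapes}"


def run(*arg_shapes: Tuple[int, ...]) -> Any:
  key = tuple(arg_shapes)
  if key not in _shape_cache:  # only trace shapes we have never seen before
    _shape_cache[key] = _trace(key)
  return _shape_cache[key]


run((8, 10)); run((8, 10)); run((16, 10))
assert _trace_count == 2  # the second (8, 10) call reused the cached graph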
16 changes: 16 additions & 0 deletions test/dynamo/test_dynamo_dynamic_shape.py
@@ -183,6 +183,22 @@ def test_dynamic_shape_mix_with_non_dynamic(self):
self.assertEqual(met.metric_data('CompileTime')[0], 1)
self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

+  def test_dynamic_shape_no_retracing(self):
+    device = torch_xla.device()
+    # model setup
+    _, dummy_linear_xla, _, input_xla = self._get_linear_and_input(
+        8, 10, 20, device)
+    compiled_linear_xla = torch.compile(
+        dummy_linear_xla, backend="openxla", dynamic=True)
+    xm.wait_device_ops()
+    met.clear_all()
+
+    # first run
+    res_xla = compiled_linear_xla(input_xla)
+    # Dynamo execution should not trigger the `CachedCompile` counter. If it
+    # does, it likely means we retraced the same FX graph multiple times.
+    self.assertNotIn('CachedCompile', met.counter_names())
+
def test_dynamic_shape_resnet18(self):
device = torch_xla.device()

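
For context, the counters the new test checks can be inspected the same way outside the test harness. A standalone sketch (not part of the commit; the model and input here are arbitrary examples):

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = torch_xla.device()
linear = torch.nn.Linear(10, 20).to(device)
compiled = torch.compile(linear, backend="openxla", dynamic=True)

met.clear_all()
out = compiled(torch.randn(8, 10, device=device))
xm.wait_device_ops()
print(met.metric_data('CompileTime')[0])  # number of compilations recorded
print('CachedCompile' in met.counter_names())  # the new test expects False here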
48 changes: 29 additions & 19 deletions torch_xla/core/dynamo_bridge.py
@@ -294,7 +294,12 @@ def is_xla_tensor(tensor: torch.Tensor) -> bool:
return tensor.device.type == "xla"


-def extract_graph_helper(xla_model: torch.fx.GraphModule):
+def extract_graph_helper(xla_model: torch.fx.GraphModule,
+                         shapes_to_graph_vars: Dict[Tuple[int, ...],
+                                                    Tuple[Any, ...]]):
+  # Don't reset the scope as we might be under some profiler trace scope.
+  xm.mark_step(reset_scope=False)
+  arg_input_shapes = _get_arg_input_shapes(xla_model.xla_args)
# FX Graph inputs passed from Dynamo. xla_args are XLA Tensors.
xla_args = xla_model.xla_args
xla_args_tensor_ids = set(
@@ -371,6 +376,9 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
dumb_return_handler = DumbReturnHandler(xla_args, args_and_out,
xla_args_need_update_bool)

+  # There is a `mark_step` at the beginning of this function, so we need to
+  # wait for it to finish before retrieving the device data nodes.
+  xm.wait_device_ops()
  # Collect all device data nodes that are needed to compute the args_and_out
  # and wrap those device data nodes inside an at::tensor(graph_input_xla_values).
  # Return the tensor_id corresponding to every device_data node as
@@ -423,9 +431,16 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
  # should be removed to avoid extra computation being executed and in-place
  # update ops mistakenly updating the input tensors.
torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
-  return (xla_args_sharding_spec, args_and_out, graph_hash,
-          arg_index_to_need_update_index, none_remover, graph_input_matcher,
-          dumb_return_handler, xla_args_need_update)
+
+  vars_to_return = (xla_args_sharding_spec, args_and_out, graph_hash,
+                    arg_index_to_need_update_index, none_remover,
+                    graph_input_matcher, dumb_return_handler,
+                    xla_args_need_update)
+  # Populate the cache if the model is compiled with `dynamic=True`.
+  if not torch._dynamo.config.assume_static_by_default:
+    shapes_to_graph_vars[arg_input_shapes] = vars_to_return
+
+  return vars_to_return
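
_get_arg_input_shapes is referenced above but not shown in this diff; presumably it reduces the current arguments to a hashable tuple of shapes that can key shapes_to_graph_vars. A hedged sketch of what such a helper could look like (an illustrative assumption, not the actual torch_xla code):

from typing import Tuple

import torch


def _get_arg_input_shapes_sketch(args) -> Tuple[Tuple[int, ...], ...]:
  # Hypothetical stand-in for _get_arg_input_shapes: collect the shape of
  # every tensor argument into a hashable tuple usable as a dict key.
  shapes = []
  for a in args:
    if isinstance(a, torch.Tensor):
      shapes.append(tuple(a.size()))
  return tuple(shapes)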


def extract_internal(xla_model: torch.fx.GraphModule):
@@ -437,8 +452,6 @@ def extract_internal(xla_model: torch.fx.GraphModule):
for xla_arg in xla_model.xla_args:
if isinstance(xla_arg, torch.Tensor):
print(torch_xla._XLAC._get_xla_tensor_debug_info(xla_arg))
-  # Don't reset the scope as we might be under some profiler trace scope.
-  xm.mark_step(reset_scope=False)

# [Note: Dynamo real-time input-shape cache look-up]
# We maintain a mapping of input shapes to outputs of extract_graph_helper.
@@ -449,11 +462,12 @@ def extract_internal(xla_model: torch.fx.GraphModule):
# Values: tuple of (xla_args_sharding_spec, args_and_out, graph_hash,
# arg_index_to_need_update_index, none_remover, graph_input_matcher,
# dumb_return_handler, xla_args_need_update).
-  input_shape_mappings: Dict[Tuple[int, ...], Tuple[Any, ...]] = {}
+  shapes_to_graph_vars: Dict[Tuple[int, ...], Tuple[Any, ...]] = {}

(xla_args_sharding_spec, args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
-   dumb_return_handler, xla_args_need_update) = extract_graph_helper(xla_model)
+   dumb_return_handler,
+   xla_args_need_update) = extract_graph_helper(xla_model, shapes_to_graph_vars)
skip_checking_input_sharding_threashold = xu.getenv_as(
'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)

@@ -468,28 +482,23 @@ def optimized_mod(*args: tuple):
nonlocal dumb_return_handler
nonlocal xla_args_need_update
nonlocal skip_checking_input_sharding_threashold
-    nonlocal input_shape_mappings
+    nonlocal shapes_to_graph_vars

# See [Note: Dynamo real-time input-shape cache look-up] above.
if not torch._dynamo.config.assume_static_by_default:
xla_model.xla_args = args
arg_input_shapes = _get_arg_input_shapes(xla_model.xla_args)
-      if arg_input_shapes in input_shape_mappings:
+      if arg_input_shapes in shapes_to_graph_vars:
(xla_args_sharding_spec, args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
dumb_return_handler,
-         xla_args_need_update) = input_shape_mappings[arg_input_shapes]
+         xla_args_need_update) = shapes_to_graph_vars[arg_input_shapes]
else:
-        # First time seeing these tensors since dynamic=True. Like we do in extract_compiled_graph_helper, explicitly call mark_step for the first time.
-        xm.mark_step()
        (xla_args_sharding_spec, args_and_out, graph_hash,
         arg_index_to_need_update_index, none_remover, graph_input_matcher,
         dumb_return_handler,
-         xla_args_need_update) = extract_graph_helper(xla_model)
-        input_shape_mappings[arg_input_shapes] = (
-            xla_args_sharding_spec, args_and_out, graph_hash,
-            arg_index_to_need_update_index, none_remover, graph_input_matcher,
-            dumb_return_handler, xla_args_need_update)
+         xla_args_need_update) = extract_graph_helper(xla_model,
+                                                      shapes_to_graph_vars)

original_device: torch.device = _get_input_arg_device(args)
is_cuda_args: bool = False
@@ -526,7 +535,8 @@ def optimized_mod(*args: tuple):
(xla_args_sharding_spec, args_and_out_copy, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
dumb_return_handler,
-         xla_args_need_update) = extract_graph_helper(xla_model)
+         xla_args_need_update) = extract_graph_helper(xla_model,
+                                                      shapes_to_graph_vars)
skip_checking_input_sharding_threashold = xu.getenv_as(
'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)
else:
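
Putting the dynamo_bridge.py changes together, the steady-state path for a model compiled with dynamic=True now looks roughly as follows. This is a simplified paraphrase, not the verbatim code: sharding checks, the CUDA fallback, and output handling are omitted, and run_cached_graph is a hypothetical stand-in for replaying the cached graph.

def make_optimized_mod(xla_model, shapes_to_graph_vars, extract_graph_helper,
                       _get_arg_input_shapes, run_cached_graph):

  def optimized_mod(*args):
    xla_model.xla_args = args
    key = _get_arg_input_shapes(args)
    if key in shapes_to_graph_vars:
      graph_vars = shapes_to_graph_vars[key]  # cache hit: no re-tracing
    else:
      # Cache miss: extract_graph_helper traces once and, with
      # assume_static_by_default off, stores the result under this shape key.
      graph_vars = extract_graph_helper(xla_model, shapes_to_graph_vars)
    return run_cached_graph(graph_vars, args)

  return optimized_mod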
