From a7411b22fd59b47e2217ab1d6262a4b09e94360e Mon Sep 17 00:00:00 2001
From: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Date: Wed, 24 Jul 2024 11:36:03 -0700
Subject: [PATCH] Optimize dynamo dynamic shape caching (#7726)

---
 test/dynamo/test_dynamo_dynamic_shape.py | 16 ++++++++
 torch_xla/core/dynamo_bridge.py          | 48 ++++++++++++++----------
 2 files changed, 45 insertions(+), 19 deletions(-)

diff --git a/test/dynamo/test_dynamo_dynamic_shape.py b/test/dynamo/test_dynamo_dynamic_shape.py
index df99139ed912..676b0494faaa 100644
--- a/test/dynamo/test_dynamo_dynamic_shape.py
+++ b/test/dynamo/test_dynamo_dynamic_shape.py
@@ -183,6 +183,22 @@ def test_dynamic_shape_mix_with_non_dynamic(self):
     self.assertEqual(met.metric_data('CompileTime')[0], 1)
     self.assertEqual(met.metric_data('ExecuteTime')[0], 1)
 
+  def test_dynamic_shape_no_retracing(self):
+    device = torch_xla.device()
+    # model setup
+    _, dummy_linear_xla, _, input_xla = self._get_linear_and_input(
+        8, 10, 20, device)
+    compiled_linear_xla = torch.compile(
+        dummy_linear_xla, backend="openxla", dynamic=True)
+    xm.wait_device_ops()
+    met.clear_all()
+
+    # first run
+    res_xla = compiled_linear_xla(input_xla)
+    # Dynamo execution should not trigger the `CachedCompile` counter. If it
+    # does, it likely means we retraced the same FX graph multiple times.
+    self.assertNotIn('CachedCompile', met.counter_names())
+
   def test_dynamic_shape_resnet18(self):
     device = torch_xla.device()
 
diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py
index 0a3f7ce07abb..95b77eff0da3 100644
--- a/torch_xla/core/dynamo_bridge.py
+++ b/torch_xla/core/dynamo_bridge.py
@@ -294,7 +294,12 @@ def is_xla_tensor(tensor: torch.Tensor) -> bool:
   return tensor.device.type == "xla"
 
 
-def extract_graph_helper(xla_model: torch.fx.GraphModule):
+def extract_graph_helper(xla_model: torch.fx.GraphModule,
+                         shapes_to_graph_vars: Dict[Tuple[int, ...],
+                                                    Tuple[Any, ...]]):
+  # Don't reset the scope as we might be under some profiler trace scope.
+  xm.mark_step(reset_scope=False)
+  arg_input_shapes = _get_arg_input_shapes(xla_model.xla_args)
   # FX Graph inputs passed from Dynamo. xla_args are XLA Tensors.
   xla_args = xla_model.xla_args
   xla_args_tensor_ids = set(
@@ -371,6 +376,9 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
   dumb_return_handler = DumbReturnHandler(xla_args, args_and_out,
                                           xla_args_need_update_bool)
 
+  # There is a `mark_step` at the beginning of this function; we need to wait
+  # for it to finish before retrieving the device data nodes.
+  xm.wait_device_ops()
   # Collect all device data nodes that is needed to compute the args_and_out
   # and wrap those device data nodes inside a at::tensor(graph_input_xla_values).
   # Return the tensor_id that is corresponding to every device_data node as
@@ -423,9 +431,16 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
   # should be removed to avoid extra computation executed and in place updates op
   # mistakenlly update the input tensors.
   torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
-  return (xla_args_sharding_spec, args_and_out, graph_hash,
-          arg_index_to_need_update_index, none_remover, graph_input_matcher,
-          dumb_return_handler, xla_args_need_update)
+
+  vars_to_return = (xla_args_sharding_spec, args_and_out, graph_hash,
+                    arg_index_to_need_update_index, none_remover,
+                    graph_input_matcher, dumb_return_handler,
+                    xla_args_need_update)
+  # populate the cache if the model is compiled with `dynamic=True`
+  if not torch._dynamo.config.assume_static_by_default:
+    shapes_to_graph_vars[arg_input_shapes] = vars_to_return
+
+  return vars_to_return
 
 
 def extract_internal(xla_model: torch.fx.GraphModule):
@@ -437,8 +452,6 @@ def extract_internal(xla_model: torch.fx.GraphModule):
     for xla_arg in xla_model.xla_args:
       if isinstance(xla_arg, torch.Tensor):
         print(torch_xla._XLAC._get_xla_tensor_debug_info(xla_arg))
-  # Don't reset the scope as we might be under some profiler trace scope.
-  xm.mark_step(reset_scope=False)
 
   # [Note: Dynamo real-time input-shape cache look-up]
   # We maintain a mapping of input shapes to outputs of extract_graph_helper.
@@ -449,11 +462,12 @@ def extract_internal(xla_model: torch.fx.GraphModule):
   # Values: tuple of (xla_args_sharding_spec, args_and_out, graph_hash,
   # arg_index_to_need_update_index, none_remover, graph_input_matcher,
   # dumb_return_handler, xla_args_need_update).
-  input_shape_mappings: Dict[Tuple[int, ...], Tuple[Any, ...]] = {}
+  shapes_to_graph_vars: Dict[Tuple[int, ...], Tuple[Any, ...]] = {}
 
   (xla_args_sharding_spec, args_and_out, graph_hash,
    arg_index_to_need_update_index, none_remover, graph_input_matcher,
-   dumb_return_handler, xla_args_need_update) = extract_graph_helper(xla_model)
+   dumb_return_handler,
+   xla_args_need_update) = extract_graph_helper(xla_model, shapes_to_graph_vars)
 
   skip_checking_input_sharding_threashold = xu.getenv_as(
       'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)
@@ -468,28 +482,23 @@ def optimized_mod(*args: tuple):
     nonlocal dumb_return_handler
     nonlocal xla_args_need_update
     nonlocal skip_checking_input_sharding_threashold
-    nonlocal input_shape_mappings
+    nonlocal shapes_to_graph_vars
 
     # See [Note: Dynamo real-time input-shape cache look-up] above.
     if not torch._dynamo.config.assume_static_by_default:
       xla_model.xla_args = args
       arg_input_shapes = _get_arg_input_shapes(xla_model.xla_args)
-      if arg_input_shapes in input_shape_mappings:
+      if arg_input_shapes in shapes_to_graph_vars:
         (xla_args_sharding_spec, args_and_out, graph_hash,
          arg_index_to_need_update_index, none_remover, graph_input_matcher,
          dumb_return_handler,
-         xla_args_need_update) = input_shape_mappings[arg_input_shapes]
+         xla_args_need_update) = shapes_to_graph_vars[arg_input_shapes]
       else:
-        # First time seeing these tensors since dynamic=True. Like we do in extract_compiled_graph_helper, explicitly call mark_step for the first time.
-        xm.mark_step()
         (xla_args_sharding_spec, args_and_out, graph_hash,
          arg_index_to_need_update_index, none_remover, graph_input_matcher,
          dumb_return_handler,
-         xla_args_need_update) = extract_graph_helper(xla_model)
-        input_shape_mappings[arg_input_shapes] = (
-            xla_args_sharding_spec, args_and_out, graph_hash,
-            arg_index_to_need_update_index, none_remover, graph_input_matcher,
-            dumb_return_handler, xla_args_need_update)
+         xla_args_need_update) = extract_graph_helper(xla_model,
+                                                      shapes_to_graph_vars)
 
     original_device: torch.device = _get_input_arg_device(args)
     is_cuda_args: bool = False
@@ -526,7 +535,8 @@ def optimized_mod(*args: tuple):
       (xla_args_sharding_spec, args_and_out_copy, graph_hash,
        arg_index_to_need_update_index, none_remover, graph_input_matcher,
        dumb_return_handler,
-       xla_args_need_update) = extract_graph_helper(xla_model)
+       xla_args_need_update) = extract_graph_helper(xla_model,
+                                                    shapes_to_graph_vars)
       skip_checking_input_sharding_threashold = xu.getenv_as(
           'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)
     else:
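
Note on the caching scheme this patch implements: a model compiled with `dynamic=True` keeps a per-shape cache, keyed by the tuple of concrete input shapes, whose values are the tuples returned by extract_graph_helper; extract_graph_helper now populates that cache itself, so the first trace of a new shape combination is not followed by a redundant second trace in optimized_mod. The standalone sketch below illustrates the lookup-or-trace pattern under simplified assumptions; ShapeKeyedCache, GraphVars, and trace_fn are illustrative stand-ins, not the actual dynamo_bridge API.

from typing import Any, Callable, Dict, Sequence, Tuple

import torch

# Illustrative stand-in for the tuple extract_graph_helper returns
# (graph hash, graph input matcher, return handler, etc.).
GraphVars = Tuple[Any, ...]


def _arg_input_shapes(
    args: Sequence[torch.Tensor]) -> Tuple[Tuple[int, ...], ...]:
  # The cache key is the concrete shape of every tensor argument.
  return tuple(tuple(a.shape) for a in args)


class ShapeKeyedCache:
  """Lookup-or-trace cache: retrace only when a new shape combination appears."""

  def __init__(self, trace_fn: Callable[[Sequence[torch.Tensor]], GraphVars]):
    self._trace_fn = trace_fn  # the expensive FX-to-XLA lowering step
    self._shapes_to_graph_vars: Dict[Tuple[Tuple[int, ...], ...],
                                     GraphVars] = {}

  def lookup_or_trace(self, args: Sequence[torch.Tensor]) -> GraphVars:
    key = _arg_input_shapes(args)
    cached = self._shapes_to_graph_vars.get(key)
    if cached is None:
      # First time this shape combination is seen: trace once and store the
      # result, mirroring how extract_graph_helper now fills
      # shapes_to_graph_vars itself instead of the caller doing a second pass.
      cached = self._trace_fn(args)
      self._shapes_to_graph_vars[key] = cached
    return cached

With this structure, repeated calls whose inputs have shapes already seen hit the dictionary directly, which is why the new test asserts that the `CachedCompile` counter never fires.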