From 6ae6bc556e69b1f94f3ef36077caa7a8eb6ed3a8 Mon Sep 17 00:00:00 2001 From: JackCaoG <59073027+JackCaoG@users.noreply.github.com> Date: Tue, 27 Aug 2024 13:26:04 -0700 Subject: [PATCH] In dynamo optim_mode avoid unnecessary set_attr (#7915) --- torch_xla/_dynamo/dynamo_bridge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/_dynamo/dynamo_bridge.py b/torch_xla/_dynamo/dynamo_bridge.py index 299a9320307d..89c25701cdb6 100644 --- a/torch_xla/_dynamo/dynamo_bridge.py +++ b/torch_xla/_dynamo/dynamo_bridge.py @@ -536,7 +536,6 @@ def optimized_mod(*args: tuple): nonlocal skip_checking_input_sharding_threashold nonlocal sym_constants_to_graph_vars - xla_model.xla_args = args # See [Note: Dynamo real-time input-shape cache look-up] above. xla_args_tensor_only, sym_constants = _split_xla_args_tensor_sym_constant( args) @@ -546,6 +545,7 @@ def optimized_mod(*args: tuple): special_return_handler, xla_args_need_update) = sym_constants_to_graph_vars[sym_constants] else: + xla_model.xla_args = args (xla_args_sharding_spec, args_and_out, graph_hash, arg_index_to_need_update_index, none_remover, graph_input_matcher, special_return_handler, xla_args_need_update) = extract_graph_helper(