From 4f02bb764a0729d2bce5fea463806532abad55a2 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Mon, 15 Jul 2024 11:13:23 -0400
Subject: [PATCH] Fix import test (#2931)

* Fix import test

* Tweak threash
---
 src/accelerate/utils/megatron_lm.py | 16 +++++++---------
 tests/test_imports.py               | 20 ++++++++++----------
 2 files changed, 17 insertions(+), 19 deletions(-)

diff --git a/src/accelerate/utils/megatron_lm.py b/src/accelerate/utils/megatron_lm.py
index b6c4653d8df..552cb6d35f2 100644
--- a/src/accelerate/utils/megatron_lm.py
+++ b/src/accelerate/utils/megatron_lm.py
@@ -25,18 +25,10 @@
 from ..optimizer import AcceleratedOptimizer
 from ..scheduler import AcceleratedScheduler
-from .imports import is_megatron_lm_available, is_transformers_available
+from .imports import is_megatron_lm_available
 from .operations import recursively_apply, send_to_device
 
 
-if is_transformers_available():
-    from transformers.modeling_outputs import (
-        CausalLMOutputWithCrossAttentions,
-        Seq2SeqLMOutput,
-        SequenceClassifierOutput,
-    )
-
-
 if is_megatron_lm_available():
     from megatron import (
         get_args,
@@ -467,6 +459,8 @@ def __init__(self, accelerator, args):
         if not args.model_return_dict:
             self.model_output_class = None
         else:
+            from transformers.modeling_outputs import SequenceClassifierOutput
+
             self.model_output_class = SequenceClassifierOutput
 
     def get_batch_func(self, accelerator, megatron_dataset_flag):
@@ -614,6 +608,8 @@ def __init__(self, accelerator, args):
         if not args.model_return_dict:
             self.model_output_class = None
         else:
+            from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+
             self.model_output_class = CausalLMOutputWithCrossAttentions
 
     def get_batch_func(self, accelerator, megatron_dataset_flag):
@@ -737,6 +733,8 @@ def __init__(self, accelerator, args):
         if not args.model_return_dict:
             self.model_output_class = None
         else:
+            from transformers.modeling_outputs import Seq2SeqLMOutput
+
             self.model_output_class = Seq2SeqLMOutput
 
     @staticmethod
diff --git a/tests/test_imports.py b/tests/test_imports.py
index 490c409e774..d3ad7caa1fa 100644
--- a/tests/test_imports.py
+++ b/tests/test_imports.py
@@ -62,22 +62,22 @@ def test_base_import(self):
         output = run_import_time("import accelerate")
         data = read_import_profile(output)
         total_time = calculate_total_time(data)
-        pct_more = total_time / self.pytorch_time
-        # Base import should never be more than 10% slower than raw torch import
-        err_msg = f"Base import is more than 20% slower than raw torch import ({pct_more * 100:.2f}%), please check the attached `tuna` profile:\n"
+        pct_more = (total_time - self.pytorch_time) / self.pytorch_time * 100
+        # Base import should never be more than 20% slower than raw torch import
+        err_msg = f"Base import is more than 20% slower than raw torch import ({pct_more:.2f}%), please check the attached `tuna` profile:\n"
         sorted_data = sort_nodes_by_total_time(data)
-        paths_above_threshold = get_paths_above_threshold(sorted_data, 0.1, max_depth=7)
+        paths_above_threshold = get_paths_above_threshold(sorted_data, 0.05, max_depth=7)
         err_msg += f"\n{convert_list_to_string(paths_above_threshold)}"
-        self.assertLess(pct_more, 1.2, err_msg)
+        self.assertLess(pct_more, 20, err_msg)
 
     def test_cli_import(self):
         output = run_import_time("from accelerate.commands.launch import launch_command_parser")
         data = read_import_profile(output)
         total_time = calculate_total_time(data)
-        pct_more = total_time / self.pytorch_time
-        # Base import should never be more than 10% slower than raw torch import
-        err_msg = f"Base import is more than 20% slower than raw torch import ({pct_more * 100:.2f}%), please check the attached `tuna` profile:\n"
+        pct_more = (total_time - self.pytorch_time) / self.pytorch_time * 100
+        # Base import should never be more than 20% slower than raw torch import
+        err_msg = f"Base import is more than 20% slower than raw torch import ({pct_more:.2f}%), please check the attached `tuna` profile:\n"
         sorted_data = sort_nodes_by_total_time(data)
-        paths_above_threshold = get_paths_above_threshold(sorted_data, 0.1, max_depth=7)
+        paths_above_threshold = get_paths_above_threshold(sorted_data, 0.05, max_depth=7)
         err_msg += f"\n{convert_list_to_string(paths_above_threshold)}"
-        self.assertLess(pct_more, 1.2, err_msg)
+        self.assertLess(pct_more, 20, err_msg)
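
The test change replaces a raw ratio with a relative-overhead percentage so that the value compared against the threshold matches what the error message prints. A minimal standalone sketch of that arithmetic, assuming illustrative timing values and a hypothetical helper name (import_overhead_pct is not part of the test suite):

def import_overhead_pct(candidate_time: float, baseline_time: float) -> float:
    # How much slower the candidate import is than the baseline, in percent.
    return (candidate_time - baseline_time) / baseline_time * 100

# Example: torch alone imports in 1.00s, `import accelerate` in 1.15s -> 15% overhead.
pct_more = import_overhead_pct(1.15, 1.00)
assert pct_more < 20, f"Base import is more than 20% slower than raw torch import ({pct_more:.2f}%)"
# The old code compared the raw ratio (1.15 / 1.00 = 1.15) against 1.2 while
# formatting it as "115.00%", so the printed percentage never matched the check.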
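
The megatron_lm.py change defers the transformers.modeling_outputs imports into the wrapper constructors, so `import accelerate` no longer pulls in transformers at module import time. A rough sketch of that lazy-import pattern under illustrative names (GPTWrapperSketch and the SimpleNamespace args stand in for the real Megatron-LM wrappers):

from types import SimpleNamespace


class GPTWrapperSketch:
    def __init__(self, args):
        if not args.model_return_dict:
            self.model_output_class = None
        else:
            # Deferred import: transformers is only loaded if this branch runs.
            from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

            self.model_output_class = CausalLMOutputWithCrossAttentions


# Building a wrapper with model_return_dict=False never touches transformers.
wrapper = GPTWrapperSketch(SimpleNamespace(model_return_dict=False))
assert wrapper.model_output_class is None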