Skip to content

Commit

Permalink
Fix import test (#2931)
Browse files Browse the repository at this point in the history
* Fix import test

* Tweak thresh
  • Loading branch information
muellerzr authored Jul 15, 2024
1 parent 709fd1e commit 4f02bb7
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 19 deletions.
16 changes: 7 additions & 9 deletions src/accelerate/utils/megatron_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,10 @@

from ..optimizer import AcceleratedOptimizer
from ..scheduler import AcceleratedScheduler
from .imports import is_megatron_lm_available, is_transformers_available
from .imports import is_megatron_lm_available
from .operations import recursively_apply, send_to_device


if is_transformers_available():
from transformers.modeling_outputs import (
CausalLMOutputWithCrossAttentions,
Seq2SeqLMOutput,
SequenceClassifierOutput,
)


if is_megatron_lm_available():
from megatron import (
get_args,
Expand Down Expand Up @@ -467,6 +459,8 @@ def __init__(self, accelerator, args):
if not args.model_return_dict:
self.model_output_class = None
else:
from transformers.modeling_outputs import SequenceClassifierOutput

self.model_output_class = SequenceClassifierOutput

def get_batch_func(self, accelerator, megatron_dataset_flag):
Expand Down Expand Up @@ -614,6 +608,8 @@ def __init__(self, accelerator, args):
if not args.model_return_dict:
self.model_output_class = None
else:
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

self.model_output_class = CausalLMOutputWithCrossAttentions

def get_batch_func(self, accelerator, megatron_dataset_flag):
Expand Down Expand Up @@ -737,6 +733,8 @@ def __init__(self, accelerator, args):
if not args.model_return_dict:
self.model_output_class = None
else:
from transformers.modeling_outputs import Seq2SeqLMOutput

self.model_output_class = Seq2SeqLMOutput

@staticmethod
Expand Down
20 changes: 10 additions & 10 deletions tests/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,22 +62,22 @@ def test_base_import(self):
output = run_import_time("import accelerate")
data = read_import_profile(output)
total_time = calculate_total_time(data)
pct_more = total_time / self.pytorch_time
# Base import should never be more than 10% slower than raw torch import
err_msg = f"Base import is more than 20% slower than raw torch import ({pct_more * 100:.2f}%), please check the attached `tuna` profile:\n"
pct_more = (total_time - self.pytorch_time) / self.pytorch_time * 100
# Base import should never be more than 20% slower than raw torch import
err_msg = f"Base import is more than 20% slower than raw torch import ({pct_more:.2f}%), please check the attached `tuna` profile:\n"
sorted_data = sort_nodes_by_total_time(data)
paths_above_threshold = get_paths_above_threshold(sorted_data, 0.1, max_depth=7)
paths_above_threshold = get_paths_above_threshold(sorted_data, 0.05, max_depth=7)
err_msg += f"\n{convert_list_to_string(paths_above_threshold)}"
self.assertLess(pct_more, 1.2, err_msg)
self.assertLess(pct_more, 20, err_msg)

def test_cli_import(self):
output = run_import_time("from accelerate.commands.launch import launch_command_parser")
data = read_import_profile(output)
total_time = calculate_total_time(data)
pct_more = total_time / self.pytorch_time
# Base import should never be more than 10% slower than raw torch import
err_msg = f"Base import is more than 20% slower than raw torch import ({pct_more * 100:.2f}%), please check the attached `tuna` profile:\n"
pct_more = (total_time - self.pytorch_time) / self.pytorch_time * 100
# Base import should never be more than 20% slower than raw torch import
err_msg = f"Base import is more than 20% slower than raw torch import ({pct_more:.2f}%), please check the attached `tuna` profile:\n"
sorted_data = sort_nodes_by_total_time(data)
paths_above_threshold = get_paths_above_threshold(sorted_data, 0.1, max_depth=7)
paths_above_threshold = get_paths_above_threshold(sorted_data, 0.05, max_depth=7)
err_msg += f"\n{convert_list_to_string(paths_above_threshold)}"
self.assertLess(pct_more, 1.2, err_msg)
self.assertLess(pct_more, 20, err_msg)

0 comments on commit 4f02bb7

Please sign in to comment.