diff --git a/src/accelerate/test_utils/scripts/external_deps/test_ds_multiple_model.py b/src/accelerate/test_utils/scripts/external_deps/test_ds_multiple_model.py index bdfb08c715f..3729ecf4c72 100644 --- a/src/accelerate/test_utils/scripts/external_deps/test_ds_multiple_model.py +++ b/src/accelerate/test_utils/scripts/external_deps/test_ds_multiple_model.py @@ -30,6 +30,7 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup from accelerate import Accelerator, DeepSpeedPlugin, DistributedType +from accelerate.state import AcceleratorState from accelerate.utils.deepspeed import get_active_deepspeed_plugin @@ -323,6 +324,7 @@ def main(): args = parser.parse_args() config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16} single_model_training(config, args) + AcceleratorState._reset_state(True) multiple_model_training(config, args) diff --git a/tests/deepspeed/test_deepspeed_multiple_model.py b/tests/deepspeed/test_deepspeed_multiple_model.py index f26f27b6f5b..f26f40debb3 100644 --- a/tests/deepspeed/test_deepspeed_multiple_model.py +++ b/tests/deepspeed/test_deepspeed_multiple_model.py @@ -21,9 +21,9 @@ from transformers import AutoModelForCausalLM from accelerate import Accelerator, DeepSpeedPlugin +from accelerate.commands.launch import launch_command, launch_command_parser from accelerate.test_utils.testing import ( AccelerateTestCase, - execute_subprocess_async, path_in_accelerate_package, require_deepspeed, require_huggingface_suite, @@ -42,6 +42,7 @@ @require_deepspeed @require_non_cpu class DeepSpeedConfigIntegration(AccelerateTestCase): + parser = launch_command_parser() test_scripts_folder = path_in_accelerate_package("test_utils", "scripts", "external_deps") def setUp(self): @@ -145,32 +146,33 @@ def test_multiple_accelerators(self): _ = Accelerator(deepspeed_plugin=ds_zero3) def test_prepare_multiple_models_zero3_inference(self): - ds_plugins = self.get_ds_plugins(zero3_inference=True) - accelerator = Accelerator(deepspeed_plugin=ds_plugins) - # Using Zero-2 first - model1 = self.model_init() - optimizer = DummyOptim(model1.parameters()) - scheduler = DummyScheduler(optimizer) - - dataset = RegressionDataset() - dataloader = torch.utils.data.DataLoader(dataset, batch_size=1) - model1, optimizer, scheduler, dataloader = accelerator.prepare(model1, optimizer, scheduler, dataloader) - accelerator.state.select_deepspeed_plugin("zero3") - model2 = self.model_init() - with self.assertLogs(level="WARNING") as captured: - model2 = accelerator.prepare(model2) - self.assertIn( - "A wrapped DeepSpeed engine reference is currently tied for this `Accelerator()` instance.", - captured.output[0], - ) - - assert accelerator.deepspeed_engine_wrapped.engine is model1 + with patch_environment(**self.dist_env): + ds_plugins = self.get_ds_plugins(zero3_inference=True) + accelerator = Accelerator(deepspeed_plugin=ds_plugins) + # Using Zero-2 first + model1 = self.model_init() + optimizer = DummyOptim(model1.parameters()) + scheduler = DummyScheduler(optimizer) + + dataset = RegressionDataset() + dataloader = torch.utils.data.DataLoader(dataset, batch_size=1) + model1, optimizer, scheduler, dataloader = accelerator.prepare(model1, optimizer, scheduler, dataloader) + accelerator.state.select_deepspeed_plugin("zero3") + model2 = self.model_init() + with self.assertLogs(level="WARNING") as captured: + model2 = accelerator.prepare(model2) + self.assertIn( + "A wrapped DeepSpeed engine reference is currently tied for this `Accelerator()` instance.", + captured.output[0], + ) + + assert accelerator.deepspeed_engine_wrapped.engine is model1 @require_huggingface_suite @require_multi_device @slow def test_train_multiple_models(self): self.test_file_path = self.test_scripts_folder / "test_ds_multiple_model.py" - cmd = ["accelerate", "launch", "--num_processes=2", "--num_machines=1", self.test_file_path] - with patch_environment(omp_num_threads=1): - execute_subprocess_async(cmd) + args = ["--num_processes=2", "--num_machines=1", "--main_process_port=10999", str(self.test_file_path)] + args = self.parser.parse_args(args) + launch_command(args)