Skip to content

Commit

Permalink
Fixup multiple model DS tests (#3131)
Browse files Browse the repository at this point in the history
* Multiple model multi GPU fixed, different issues than torch

* Fix multiple-model issues
  • Loading branch information
muellerzr authored Sep 26, 2024
1 parent 4305033 commit 018a99e
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup

from accelerate import Accelerator, DeepSpeedPlugin, DistributedType
from accelerate.state import AcceleratorState
from accelerate.utils.deepspeed import get_active_deepspeed_plugin


Expand Down Expand Up @@ -323,6 +324,7 @@ def main():
args = parser.parse_args()
config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16}
single_model_training(config, args)
AcceleratorState._reset_state(True)
multiple_model_training(config, args)


Expand Down
50 changes: 26 additions & 24 deletions tests/deepspeed/test_deepspeed_multiple_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
from transformers import AutoModelForCausalLM

from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.commands.launch import launch_command, launch_command_parser
from accelerate.test_utils.testing import (
AccelerateTestCase,
execute_subprocess_async,
path_in_accelerate_package,
require_deepspeed,
require_huggingface_suite,
Expand All @@ -42,6 +42,7 @@
@require_deepspeed
@require_non_cpu
class DeepSpeedConfigIntegration(AccelerateTestCase):
parser = launch_command_parser()
test_scripts_folder = path_in_accelerate_package("test_utils", "scripts", "external_deps")

def setUp(self):
Expand Down Expand Up @@ -145,32 +146,33 @@ def test_multiple_accelerators(self):
_ = Accelerator(deepspeed_plugin=ds_zero3)

def test_prepare_multiple_models_zero3_inference(self):
ds_plugins = self.get_ds_plugins(zero3_inference=True)
accelerator = Accelerator(deepspeed_plugin=ds_plugins)
# Using Zero-2 first
model1 = self.model_init()
optimizer = DummyOptim(model1.parameters())
scheduler = DummyScheduler(optimizer)

dataset = RegressionDataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)
model1, optimizer, scheduler, dataloader = accelerator.prepare(model1, optimizer, scheduler, dataloader)
accelerator.state.select_deepspeed_plugin("zero3")
model2 = self.model_init()
with self.assertLogs(level="WARNING") as captured:
model2 = accelerator.prepare(model2)
self.assertIn(
"A wrapped DeepSpeed engine reference is currently tied for this `Accelerator()` instance.",
captured.output[0],
)

assert accelerator.deepspeed_engine_wrapped.engine is model1
with patch_environment(**self.dist_env):
ds_plugins = self.get_ds_plugins(zero3_inference=True)
accelerator = Accelerator(deepspeed_plugin=ds_plugins)
# Using Zero-2 first
model1 = self.model_init()
optimizer = DummyOptim(model1.parameters())
scheduler = DummyScheduler(optimizer)

dataset = RegressionDataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)
model1, optimizer, scheduler, dataloader = accelerator.prepare(model1, optimizer, scheduler, dataloader)
accelerator.state.select_deepspeed_plugin("zero3")
model2 = self.model_init()
with self.assertLogs(level="WARNING") as captured:
model2 = accelerator.prepare(model2)
self.assertIn(
"A wrapped DeepSpeed engine reference is currently tied for this `Accelerator()` instance.",
captured.output[0],
)

assert accelerator.deepspeed_engine_wrapped.engine is model1

@require_huggingface_suite
@require_multi_device
@slow
def test_train_multiple_models(self):
self.test_file_path = self.test_scripts_folder / "test_ds_multiple_model.py"
cmd = ["accelerate", "launch", "--num_processes=2", "--num_machines=1", self.test_file_path]
with patch_environment(omp_num_threads=1):
execute_subprocess_async(cmd)
args = ["--num_processes=2", "--num_machines=1", "--main_process_port=10999", str(self.test_file_path)]
args = self.parser.parse_args(args)
launch_command(args)

0 comments on commit 018a99e

Please sign in to comment.