Fixup DS issue with weakref
muellerzr committed Oct 7, 2024
1 parent 127818f commit 90baa96
Showing 2 changed files with 45 additions and 9 deletions.
18 changes: 9 additions & 9 deletions src/accelerate/utils/dataclasses.py
@@ -1280,15 +1280,15 @@ def set_deepspeed_weakref(self):
                     "When `zero3_init_flag` is set, it requires Transformers to be installed. "
                     "Please run `pip install transformers`."
                 )
-        if "gradient_accumulation_steps" not in ds_config or ds_config["gradient_accumulation_steps"] == "auto":
-            ds_config["gradient_accumulation_steps"] = 1
-        if (
-            "train_micro_batch_size_per_gpu" not in ds_config
-            or ds_config["train_micro_batch_size_per_gpu"] == "auto"
-        ):
-            ds_config["train_micro_batch_size_per_gpu"] = 1
-        if ds_config.get("train_batch_size", None) == "auto":
-            del ds_config["train_batch_size"]
+            if "gradient_accumulation_steps" not in ds_config or ds_config["gradient_accumulation_steps"] == "auto":
+                ds_config["gradient_accumulation_steps"] = 1
+            if (
+                "train_micro_batch_size_per_gpu" not in ds_config
+                or ds_config["train_micro_batch_size_per_gpu"] == "auto"
+            ):
+                ds_config["train_micro_batch_size_per_gpu"] = 1
+            if ds_config.get("train_batch_size", None) == "auto":
+                del ds_config["train_batch_size"]
 
         if compare_versions("transformers", "<", "4.33"):
             from transformers.deepspeed import HfDeepSpeedConfig, unset_hf_deepspeed_config
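For context: the weakref in question is the module-level `HfDeepSpeedConfig` reference that Transformers consults (via `is_deepspeed_zero3_enabled`) when `from_pretrained` decides whether to build a model under `deepspeed.zero.Init`. A minimal sketch of that mechanism, assuming `transformers>=4.33` with `deepspeed` and `accelerate` installed; this is illustration only, not part of the commit:

    # Illustration only: how the global weakref behaves.
    from transformers.integrations.deepspeed import (
        HfDeepSpeedConfig,
        is_deepspeed_zero3_enabled,
        unset_hf_deepspeed_config,
    )

    ds_config = {
        "zero_optimization": {"stage": 3},
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 1,
    }

    dschf = HfDeepSpeedConfig(ds_config)  # registers the module-level weakref
    assert is_deepspeed_zero3_enabled()   # True while `dschf` stays alive

    unset_hf_deepspeed_config()           # drop the weakref again
    assert not is_deepspeed_zero3_enabled()

This is why a leftover weakref built from a ZeRO-3 config with unresolved "auto" values can interfere with ordinary model loading, which is the scenario the new test below exercises.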
36 changes: 36 additions & 0 deletions tests/deepspeed/test_deepspeed.py
@@ -805,6 +805,42 @@ def test_ds_config_assertions(self):
             in str(cm.exception)
         )
 
+    def test_ds_zero3_no_init_autofill(self):
+        ds_config = {
+            "bf16": {
+                "enabled": True
+            },
+            "zero_optimization": {
+                "stage": 3,
+                "allgather_partitions": True,
+                "allgather_bucket_size": 5e8,
+                "overlap_comm": True,
+                "reduce_scatter": True,
+                "reduce_bucket_size": "auto",
+                "contiguous_gradients": True,
+                "stage3_gather_16bit_weights_on_model_save": False,
+
+                "offload_optimizer": {
+                    "device": "none"
+                },
+                "offload_param": {
+                    "device": "none"
+                }
+            },
+            "gradient_clipping": 1.0,
+            "gradient_accumulation_steps": 1,
+            "train_batch_size": "auto",
+            "train_micro_batch_size_per_gpu": "auto",
+            "steps_per_print": 2000000
+        }
+        deepspeed_plugin = DeepSpeedPlugin(
+            hf_ds_config=ds_config,
+            zero3_init_flag=False,
+        )
+        with mockenv_context(**self.dist_env):
+            _ = Accelerator(deepspeed_plugin=deepspeed_plugin)
+            _ = AutoModelForCausalLM.from_pretrained("gpt2")
+
     @parameterized.expand(stages, name_func=parameterized_custom_name_func)
     def test_ds_config(self, stage):
         deepspeed_plugin = DeepSpeedPlugin(
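In user terms, the scenario this test locks in looks roughly like the sketch below. This is a hypothetical standalone repro, not the test itself: the real test fakes a single-process distributed environment with `mockenv_context` (setting variables such as `ACCELERATE_USE_DEEPSPEED`), and the config here is trimmed to the fields that matter:

    # Hypothetical repro: a ZeRO-3 config with "auto" batch-size fields plus
    # zero3_init_flag=False should not break plain model loading. Assumes
    # deepspeed is installed and the DeepSpeed env vars are already set.
    from accelerate import Accelerator
    from accelerate.utils import DeepSpeedPlugin
    from transformers import AutoModelForCausalLM

    ds_config = {
        "zero_optimization": {"stage": 3},
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
        "gradient_accumulation_steps": 1,
    }

    plugin = DeepSpeedPlugin(hf_ds_config=ds_config, zero3_init_flag=False)
    accelerator = Accelerator(deepspeed_plugin=plugin)

    # With the fix, this loads normally instead of tripping over the
    # unresolved "auto" placeholders during ZeRO-3 init.
    model = AutoModelForCausalLM.from_pretrained("gpt2")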
