fix zero1 regression
inkcherry committed Jan 17, 2025
1 parent cdfb54c · commit 6d030c4
Showing 2 changed files with 19 additions and 1 deletion.
18 changes: 18 additions & 0 deletions deepspeed/runtime/tensor_parallel/config.py
@@ -54,6 +54,24 @@ class TPTrainingConfig(DeepSpeedConfigModel):
     """
 
     injection_policy_tuple: Optional[tuple] = None
+    # The following parameters are required by the autoTP parser.
+    ########################################
+    keep_module_on_host: bool = False
+    """
+    When loading checkpoints to model parameters, they are moved to the device. In very large models
+    this might fill the device and cause OOM. Setting this flag to true will keep checkpoints on the
+    host and not move them directly to the device (giving the option, for example, to quantize
+    checkpoint data before moving it to the device).
+    """
+
+    replace_with_kernel_inject: bool = Field(False, alias="kernel_inject")
+    """
+    Set to true to inject inference kernels for models such as Bert, GPT2,
+    GPT-Neo, and GPT-J. Otherwise, the injection_dict provides the names of two
+    linear layers as a tuple:
+    `(attention_output projection, transformer output projection)`
+    """
+    ########################################
 
 
 def get_tensor_parallel_config(ds_config):
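The two new config fields above are consumed by the autoTP parser rather than used directly in this file. Below is a minimal sketch of how they might surface in a user's ds_config, assuming they nest under the "tensor_parallel" section next to an autotp_size field (the nesting and the autotp_size key are assumptions for illustration, not confirmed by this diff):

# Hypothetical ds_config fragment (Python dict form); key placement is assumed.
ds_config = {
    "tensor_parallel": {
        "autotp_size": 4,             # assumed companion field enabling autoTP
        "keep_module_on_host": True,  # keep checkpoint weights on the host to avoid device OOM
        "kernel_inject": False,       # accepted via the pydantic alias for replace_with_kernel_inject
    },
}

Note that "kernel_inject" works as a key only because replace_with_kernel_inject declares it as a Field alias.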
2 changes: 1 addition & 1 deletion deepspeed/runtime/tensor_parallel/tp_manager.py
@@ -28,7 +28,7 @@ def __init__(self, model, tp_size, dtype):
 
         # Synchronize random number generator state across devices
         _rng_state = get_accelerator().get_rng_state().to(get_accelerator().current_device_name())
-        dist.broadcast(_rng_state, 0, self.tp_config.tp_group)
+        dist.broadcast(_rng_state, groups.get_tensor_model_parallel_src_rank(), self.tp_config.tp_group)
         get_accelerator().set_rng_state(_rng_state.cpu())
 
         # Apply injection policies
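The one-line change above is the substance of the fix: dist.broadcast takes a global source rank, and the hard-coded 0 is a member of only the first tensor-parallel group. Once ZeRO-1 data parallelism creates several TP groups, every group other than the first would try to broadcast from a rank outside itself. A minimal sketch of the source-rank arithmetic, assuming contiguous TP groups (this helper is illustrative only, not DeepSpeed's groups.get_tensor_model_parallel_src_rank):

def tensor_model_parallel_src_rank(global_rank: int, tp_size: int) -> int:
    # First global rank in the TP group that global_rank belongs to,
    # assuming groups are formed from contiguous ranks: [0..tp-1], [tp..2tp-1], ...
    return (global_rank // tp_size) * tp_size

# With world_size=4 and tp_size=2 the groups are [0, 1] and [2, 3]:
assert tensor_model_parallel_src_rank(0, 2) == 0
assert tensor_model_parallel_src_rank(1, 2) == 0
assert tensor_model_parallel_src_rank(2, 2) == 2
assert tensor_model_parallel_src_rank(3, 2) == 2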
