Throw warning if users are using an old PyTorch version that does not include DTensor strided sharding #507

Merged
merged 10 commits on Aug 13, 2024
21 changes: 9 additions & 12 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -29,6 +29,7 @@
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import logger
from torchtitan.parallelisms.parallel_dims import ParallelDims
from torchtitan.parallelisms.utils import check_strided_sharding_enabled


def parallelize_llama(
@@ -83,6 +84,7 @@ def parallelize_llama(
reduce_dtype=TORCH_DTYPE_MAP[
job_config.training.mixed_precision_reduce
],
tp_enabled=parallel_dims.tp_enabled,
pp_enabled=parallel_dims.pp_enabled,
)
else:
@@ -289,6 +291,7 @@ def apply_fsdp(
dp_mesh: DeviceMesh,
param_dtype: torch.dtype,
reduce_dtype: torch.dtype,
tp_enabled: bool,
pp_enabled: bool,
):
"""
@@ -297,6 +300,12 @@
mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

# TODO: remove this check once PyTorch 2.5 is released. We can safely assume
# that users won't use a nightly build which is older than 20240809 by then.
if tp_enabled:
# check if strided sharding is enabled, which is necessary for 2D/3D DCP
check_strided_sharding_enabled()

for layer_id, transformer_block in model.layers.items():
if pp_enabled:
# For PP, do not reshard after forward to avoid per-microbatch
@@ -313,18 +322,6 @@
)
fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)

if pp_enabled:
# TODO
# This PR https://github.com/pytorch/pytorch/pull/129519 added a safety check to avoid using 2D/3D DCP since
# without strided sharding, DCP can not safely support resharding for 2D/3D. However, for PP to work, even
# without resharding, we load a seed-checkpoint and need to disable the safety mechanism. This hack should be
# removed after strided sharding is landed in DCP.
for module in model.modules():
assert len(module._load_state_dict_pre_hooks) <= 1
module._load_state_dict_pre_hooks.clear()
assert len(module._state_dict_pre_hooks) <= 1
module._state_dict_pre_hooks.clear()

logger.info("Applied FSDP to the model")


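In case the diff fragments above are hard to follow, here is a minimal, self-contained sketch (not part of the PR) of the call path it introduces: parallelize_llama passes parallel_dims.tp_enabled into apply_fsdp, and apply_fsdp runs the strided-sharding check only when tensor parallelism is enabled. The helper import below is the real one added by this PR; the wrapper function is illustrative only.

# Illustrative only -- mirrors the wiring added in this PR, with the unrelated
# FSDP sharding details elided.
from torchtitan.parallelisms.utils import check_strided_sharding_enabled

def apply_fsdp_version_guard(tp_enabled: bool) -> None:
    # 2D/3D DCP only comes into play when FSDP is combined with TP (and/or PP),
    # so plain 1D FSDP runs skip the check entirely.
    if tp_enabled:
        # logs a warning if the installed PyTorch predates pytorch/pytorch#130760
        check_strided_sharding_enabled()

# e.g. inside parallelize_llama:
# apply_fsdp_version_guard(tp_enabled=parallel_dims.tp_enabled)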
30 changes: 30 additions & 0 deletions torchtitan/parallelisms/utils.py
@@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch

from torchtitan.logging import logger


def check_strided_sharding_enabled() -> None:
    # Correct 2D/3D DCP usage requires DTensor's strided sharding, added in
    # https://github.com/pytorch/pytorch/pull/130760. This function checks that the
    # user's PyTorch nightly build is newer than 2024-08-09, so that the PR is
    # included when 2D/3D DCP is used.
    if "git" in torch.__version__:  # pytorch is built from source
        # notify users to check that the commit hash is newer than 2024-08-09
        logger.warning(
            "Detected that PyTorch is built from source. Please make sure the PR "
            "(https://github.com/pytorch/pytorch/pull/130760) is included in PyTorch "
            "for correct 2D/3D DCP usage."
        )
    elif torch.__version__ < "2.5.0.dev20240809":
        # the nightly build of pytorch was built before 2024-08-09
        logger.warning(
            f"Detected that the PyTorch version {torch.__version__} is older than "
            "2.5.0.dev20240809. Please upgrade to a newer version to include the "
            "change made in https://github.com/pytorch/pytorch/pull/130760 for "
            "correct 2D/3D DCP usage."
        )
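
A quick note on why the plain string comparison in check_strided_sharding_enabled() is sufficient: PyTorch nightly version strings embed the build date (e.g. 2.5.0.dev20240815), so lexicographic comparison against "2.5.0.dev20240809" orders them by date, while source builds carry a "git" tag in the version string. The snippet below is not part of the PR; it only illustrates that behavior, with literal strings standing in for torch.__version__.

# Illustration only: literal strings stand in for torch.__version__.
nightly_new = "2.5.0.dev20240815+cu121"   # built after 2024-08-09 -> no warning
nightly_old = "2.5.0.dev20240805+cu121"   # built before 2024-08-09 -> warning
source_build = "2.4.0a0+git1234abc"       # built from source -> "please verify" warning

assert not nightly_new < "2.5.0.dev20240809"   # lexicographic order matches date order
assert nightly_old < "2.5.0.dev20240809"
assert "git" in source_build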