Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixing deepspeed multi-node launcher #514

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions src/accelerate/commands/config/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.

from ...utils import ComputeEnvironment, DistributedType, is_deepspeed_available, is_transformers_available
from ...utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS
from .config_args import ClusterConfig
from .config_utils import _ask_field, _convert_distributed_mode, _convert_yes_no_to_bool

Expand Down Expand Up @@ -144,6 +145,51 @@ def get_cluster_input():
"Please run `pip3 install transformers`."
)

if num_machines > 1:
launcher_query = "Which Type of launcher do you want to use "
for i, launcher in enumerate(DEEPSPEED_MULTINODE_LAUNCHERS):
launcher_query += f"[{i}] {launcher}, "
launcher_query = launcher_query[:-2] + ")? [0]: "
deepspeed_config["deepspeed_multinode_launcher"] = _ask_field(
pacman100 marked this conversation as resolved.
Show resolved Hide resolved
launcher_query,
lambda x: DEEPSPEED_MULTINODE_LAUNCHERS[int(x)],
default=DEEPSPEED_MULTINODE_LAUNCHERS[0],
)

if deepspeed_config["deepspeed_multinode_launcher"] != DEEPSPEED_MULTINODE_LAUNCHERS[1]:
deepspeed_config["deepspeed_hostfile"] = _ask_field(
"DeepSpeed configures multi-node compute resources with hostfile. "
"Each row is of the format `hostname slots=[num_gpus]`, e.g., `localhost slots=2`; "
"for more information please refer official [documentation]"
"(https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node). "
"Please specify the location of hostfile: ",
lambda x: str(x),
)

is_exclusion_filter = _ask_field(
"Do you want to specify exclusion filter string? [yes/NO]: ",
_convert_yes_no_to_bool,
default=False,
error_message="Please enter yes or no.",
)
if is_exclusion_filter:
deepspeed_config["deepspeed_exclusion_filter"] = _ask_field(
"DeepSpeed exclusion filter string: ",
lambda x: str(x),
)

is_inclusion_filter = _ask_field(
"Do you want to specify inclusion filter string? [yes/NO]: ",
_convert_yes_no_to_bool,
default=False,
error_message="Please enter yes or no.",
)
if is_inclusion_filter:
deepspeed_config["deepspeed_inclusion_filter"] = _ask_field(
"DeepSpeed inclusion filter string: ",
lambda x: str(x),
)

fsdp_config = {}
if distributed_type in [DistributedType.MULTI_GPU]:
use_fsdp = _ask_field(
Expand Down
83 changes: 69 additions & 14 deletions src/accelerate/commands/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
is_deepspeed_available,
is_sagemaker_available,
)
from accelerate.utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS
from accelerate.utils.dataclasses import SageMakerDistributedType


Expand Down Expand Up @@ -109,6 +110,30 @@ def launch_command_parser(subparsers=None):
help="Decides Whether (true|false) to save 16-bit model weights when using ZeRO Stage-3. "
"Only applicable with DeepSpeed ZeRO Stage-3.",
)
parser.add_argument(
"--deepspeed_hostfile",
default=None,
type=str,
help="DeepSpeed hostfile for configuring multi-node compute resources.",
)
parser.add_argument(
"--deepspeed_exclusion_filter",
default=None,
type=str,
help="DeepSpeed exclusion filter string when using mutli-node setup.",
)
parser.add_argument(
"--deepspeed_inclusion_filter",
default=None,
type=str,
help="DeepSpeed inclusion filter string when using mutli-node setup.",
)
parser.add_argument(
"--deepspeed_multinode_launcher",
default=None,
type=str,
help="DeepSpeed multi-node launcher to use.",
)
parser.add_argument(
"--use_fsdp",
default=False,
Expand Down Expand Up @@ -312,20 +337,42 @@ def deepspeed_launcher(args):
raise ImportError("DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source.")
cmd = ["deepspeed", "--no_local_rank"]
if args.num_machines > 1:
cmd.extend(
[
"--num_gpus",
str(args.num_processes // args.num_machines),
"--num_nodes",
str(args.num_machines),
"--node_rank",
str(args.machine_rank),
"--master_addr",
args.main_process_ip,
"--master_port",
str(args.main_process_port),
]
)
if args.deepspeed_multinode_launcher == DEEPSPEED_MULTINODE_LAUNCHERS[1]:
cmd = get_launch_prefix()
cmd.extend(
[
"--nproc_per_node",
str(args.num_processes // args.num_machines),
"--nnodes",
str(args.num_machines),
"--node_rank",
str(args.machine_rank),
"--master_addr",
args.main_process_ip,
"--master_port",
str(args.main_process_port),
]
)
else:
cmd.extend(
["--hostfile", str(args.deepspeed_hostfile), "--launcher", str(args.deepspeed_multinode_launcher)]
)
if args.deepspeed_exclusion_filter is not None:
cmd.extend(
[
"--exclude",
str(args.deepspeed_exclusion_filter),
]
)
elif args.deepspeed_inclusion_filter is not None:
cmd.extend(
[
"--include",
str(args.deepspeed_inclusion_filter),
]
)
else:
cmd.extend(["--num_gpus", str(args.num_processes // args.num_machines)])
else:
cmd.extend(["--num_gpus", str(args.num_processes)])

Expand All @@ -350,6 +397,7 @@ def deepspeed_launcher(args):
warnings.warn('--fp16 flag is deprecated. Use "--mixed_precision fp16" instead.', DeprecationWarning)
mixed_precision = "fp16"

current_env["PYTHONPATH"] = sys.executable
stas00 marked this conversation as resolved.
Show resolved Hide resolved
current_env["MIXED_PRECISION"] = str(mixed_precision)
current_env["USE_DEEPSPEED"] = "true"
current_env["DEEPSPEED_ZERO_STAGE"] = str(args.zero_stage)
Expand All @@ -361,6 +409,13 @@ def deepspeed_launcher(args):
current_env["DEEPSPEED_ZERO3_SAVE_16BIT_MODEL"] = str(args.zero3_save_16bit_model).lower()
current_env["DEEPSPEED_CONFIG_FILE"] = str(args.deepspeed_config_file).lower()

if args.num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]:
with open(".deepspeed_env", "a") as f:
for key, value in current_env.items():
if ";" in value or " " in value:
continue
f.write(f"{key}={value}\n")

process = subprocess.Popen(cmd, env=current_env)
process.wait()
if process.returncode != 0:
Expand Down
1 change: 1 addition & 0 deletions src/accelerate/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,6 @@
SAGEMAKER_PYTHON_VERSION = "py38"
SAGEMAKER_TRANSFORMERS_VERSION = "4.17.0"
SAGEMAKER_PARALLEL_EC2_INSTANCES = ["ml.p3.16xlarge", "ml.p3dn.24xlarge", "ml.p4dn.24xlarge"]
DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich"]

STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}