enable non-distributed training and MPS support #769

Open · wants to merge 5 commits into main
12 changes: 7 additions & 5 deletions olmo/checkpoint.py
@@ -52,6 +52,7 @@
from .exceptions import OLMoCheckpointError
from .optim import Optimizer, fix_optim_state_dict
from .safetensors_util import safetensors_file_to_state_dict
from .torch_util import SingleAccelerator as SINGLE
from .torch_util import (
barrier,
gc_cuda,
@@ -645,7 +646,7 @@ def save_checkpoint(
self._write_optim_dict(
optim_state_dict, checkpoint_dir, upload_to, save_overwrite=self.cfg.save_overwrite
)
elif isinstance(dist_model, DDP):
elif isinstance(dist_model, DDP) or isinstance(dist_model, SINGLE):
# _write_model_dict and _write_optim_dict only write checkpoints for rank 0
# First, get the model state dict from DDP wrapped model
model_state_dict = dist_model.module.state_dict()
@@ -660,7 +661,7 @@
)
else:
log.info(
"`FullCheckpointer.save_checkpoint` only supported for FSDP and DDP distributed strategies!"
"`FullCheckpointer.save_checkpoint` only supported for FSDP, DDP, and SINGLE distributed strategies!"
)

# Save trainer state.
@@ -757,7 +758,7 @@ def restore_checkpoint(
torch.cuda.empty_cache()
barrier()
del optim_state_dict_to_load
elif isinstance(dist_model, DDP):
elif isinstance(dist_model, DDP) or isinstance(dist_model, SINGLE):
# Load model state.
with torch.no_grad():
state_dict_to_load = load_state_dict(
@@ -773,11 +774,12 @@
optim.load_state_dict(optim_state_dict_to_load)

gc.collect()
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()
barrier()
else:
raise NotImplementedError(
"`FullCheckpointer.restore_checkpoint` only supported for FSDP and DDP distributed strategies!"
"`FullCheckpointer.restore_checkpoint` only supported for FSDP, DDP, and SINGLE distributed strategies!"
)

# Load other state.
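To make the new branch concrete, here is a minimal sketch (not the checkpointer's actual API; save_unsharded and path are illustrative names) of the unsharded save path this change enables: both DDP and the new SingleAccelerator wrapper expose the underlying model via .module, and the CUDA cache clear is guarded so the same code also runs on MPS- or CPU-only machines.

import torch

def save_unsharded(dist_model: torch.nn.Module, path: str) -> None:
    # DDP- and SingleAccelerator-wrapped models both keep the real model
    # under .module, so one branch can serve both wrappers.
    model_state_dict = dist_model.module.state_dict()
    torch.save(model_state_dict, path)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # skipped on MPS/CPU-only machines
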
4 changes: 4 additions & 0 deletions olmo/config.py
@@ -719,6 +719,10 @@ class DistributedStrategy(StrEnum):
Wrap OLMo in torch.distributed.fsdp.FullyShardedDataParallel to train across ranks.
"""

single = "single"
"""
Train on a single device, i.e., do not distribute training. Intended for development and debugging.
"""

class DDPGradSyncMode(StrEnum):
batch = "batch"
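For reference, the new member is selected like any other value of the enum; a small illustrative check (only DistributedStrategy.single itself comes from the diff above, and the comparison assumes the project's StrEnum subclasses str, as is conventional):

from olmo.config import DistributedStrategy

# In a training config this corresponds to distributed_strategy: single,
# the field that scripts/train.py branches on further down.
strategy = DistributedStrategy.single
assert strategy == "single"                      # StrEnum members compare as strings
assert strategy is not DistributedStrategy.fsdp  # distinct from the sharded strategies
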
8 changes: 8 additions & 0 deletions olmo/torch_util.py
@@ -156,3 +156,11 @@ def get_cumulative_document_lengths(doc_lens: torch.Tensor) -> torch.Tensor:
torch.cumsum(doc_lens.masked_select(doc_lens != 0), 0, dtype=torch.int32),
]
)

class SingleAccelerator(torch.nn.Module):
process_group = None
def __init__(self, module: torch.nn.Module):
super().__init__()
self.module = module
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
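
A quick usage sketch (not part of the PR) of what the wrapper provides: it mimics the two pieces of the DDP interface the trainer relies on, namely the .module attribute and a process_group that is simply None, while forward() delegates straight to the wrapped model.

import torch
from olmo.torch_util import SingleAccelerator

wrapped = SingleAccelerator(torch.nn.Linear(4, 2))
x = torch.randn(3, 4)
assert torch.equal(wrapped(x), wrapped.module(x))  # forward() just delegates
assert wrapped.process_group is None               # no collective communication
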
10 changes: 7 additions & 3 deletions olmo/train.py
Expand Up @@ -332,7 +332,8 @@ def trainer_state_dict(self) -> Dict[str, Any]:
"python": random.getstate(),
"numpy": np.random.get_state(),
"torch": torch.random.get_rng_state(),
"cuda": torch.cuda.get_rng_state(),
"cuda": torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
"mps": torch.mps.get_rng_state() if torch.mps.is_available() else None,
},
}

@@ -430,7 +431,10 @@ def restore_rng_state(self, rng_state: Dict[str, Any]) -> None:
random.setstate(rng_state["python"])
np.random.set_state(rng_state["numpy"])
torch.set_rng_state(rng_state["torch"])
torch.cuda.set_rng_state(rng_state["cuda"])
if rng_state.get("cuda", None) is not None:
torch.cuda.set_rng_state(rng_state["cuda"])
if rng_state.get("mps", None) is not None:
torch.mps.set_rng_state(rng_state["mps"])

def _save_checkpoint(
self, checkpointer: Checkpointer, checkpoint_type: CheckpointType
@@ -1247,7 +1251,7 @@ def on_trace_ready(p):
stop_at = min(stop_at, self.global_step + extra_steps)

# Maybe save sharded checkpoint.
if self.cfg.distributed_strategy != DistributedStrategy.ddp:
if self.cfg.distributed_strategy == DistributedStrategy.fsdp:
if save_checkpoints and (
cancel_initiated
or (
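Condensing the two RNG hunks above into one place, the trainer now captures and restores per-device RNG state conditionally; the same file also narrows the sharded-checkpoint branch from "not DDP" to "FSDP only", so DDP and single-device runs write unsharded checkpoints exclusively. The sketch below is illustrative only and uses torch.backends.mps.is_available(), whereas the PR calls torch.mps.is_available().

import random

import numpy as np
import torch

def capture_rng_state() -> dict:
    # Only the states for devices that exist are recorded; the rest stay None.
    return {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "torch": torch.random.get_rng_state(),
        "cuda": torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
        "mps": torch.mps.get_rng_state() if torch.backends.mps.is_available() else None,
    }

def restore_rng_state(state: dict) -> None:
    random.setstate(state["python"])
    np.random.set_state(state["numpy"])
    torch.set_rng_state(state["torch"])
    if state.get("cuda") is not None:
        torch.cuda.set_rng_state(state["cuda"])
    if state.get("mps") is not None:
        torch.mps.set_rng_state(state["mps"])
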
44 changes: 27 additions & 17 deletions scripts/train.py
Expand Up @@ -27,6 +27,7 @@
from olmo.exceptions import OLMoCliError, OLMoConfigurationError
from olmo.model import OLMo
from olmo.optim import BoltOnWarmupScheduler, build_optimizer, build_scheduler
from olmo.torch_util import SingleAccelerator as SINGLE
from olmo.torch_util import (
barrier,
get_default_device,
@@ -65,9 +66,14 @@ def main(cfg: TrainConfig) -> None:
barrier()

# Set CUDA device.
torch.cuda.set_device(f"cuda:{get_local_rank()}")
torch.cuda.empty_cache()
device = torch.device("cuda")
if torch.cuda.is_available():
torch.cuda.set_device(f"cuda:{get_local_rank()}")
torch.cuda.empty_cache()
device = torch.device("cuda")
elif torch.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")

# Fill some configuration options.
cfg.model.precision = cfg.precision
@@ -211,8 +217,9 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
param_init_fn=param_init_fn,
**hybrid_sharding_fsdp_kwargs,
)
elif cfg.distributed_strategy is None:
raise NotImplementedError("Single accelerator training not implemented yet!")
elif cfg.distributed_strategy == DistributedStrategy.single:
param_init_fn = None
dist_model = SINGLE(olmo_model.to(device))

# when param_init_fn is None, FSDP will call reset_parameters() automatically
if param_init_fn is not None or cfg.distributed_strategy == DistributedStrategy.ddp:
@@ -287,7 +294,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
cfg.reset_optimizer_state = False

if not cfg.dry_run and not cfg.no_pre_train_checkpoint and cfg.load_path is None:
if cfg.distributed_strategy == DistributedStrategy.ddp:
if cfg.distributed_strategy in [DistributedStrategy.ddp, DistributedStrategy.single]:
checkpoint_type = CheckpointType.unsharded

if cfg.save_interval_unsharded is None:
@@ -363,17 +370,20 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
print(f"failed to set multiprocessing start method: {e}")
log.info(f"Multiprocessing start method set to '{mp.get_start_method()}'")

# Set CUDA device.
torch.cuda.set_device(f"cuda:{get_local_rank()}")

# Initialize process group.
device_as_string = f"cuda:{get_local_rank()}"
torch.cuda.set_device(
device_as_string
) # Set this early to prevent GPU 0 from picking up a bunch of tensors it shouldn't have.
dist.init_process_group(
backend="nccl", timeout=timedelta(minutes=30), device_id=torch.device(device_as_string)
)
if torch.cuda.is_available():
# Set CUDA device.
torch.cuda.set_device(f"cuda:{get_local_rank()}")

# Initialize process group.
device_as_string = f"cuda:{get_local_rank()}"
torch.cuda.set_device(
device_as_string
) # Set this early to prevent GPU 0 from picking up a bunch of tensors it shouldn't have.
dist.init_process_group(
backend="nccl", timeout=timedelta(minutes=30), device_id=torch.device(device_as_string)
)
else:
dist.init_process_group(backend="gloo", timeout=timedelta(minutes=30))
log.info("Process group initialized")

prepare_cli_environment()
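Putting the scripts/train.py changes together, startup now follows this pattern: pick the best available device, then initialize a process group with a matching backend (NCCL for CUDA, gloo otherwise, which also covers single-process MPS and CPU runs). The helper names below are illustrative, and the sketch assumes the usual rendezvous environment variables (e.g. as set by torchrun) are present.

from datetime import timedelta

import torch
import torch.distributed as dist

def pick_device() -> torch.device:
    # Prefer CUDA, fall back to Apple's MPS backend, else CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

def init_process_group_for(device: torch.device) -> None:
    # NCCL needs GPUs; gloo covers CPU and single-process MPS runs.
    backend = "nccl" if device.type == "cuda" else "gloo"
    dist.init_process_group(backend=backend, timeout=timedelta(minutes=30))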