Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fp8] hotfix backward hook #6053

Merged
merged 2 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def forward(self, *args, **kwargs):
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs)
with self._wait_all_gather():
with self._hook_context():
return super().forward(*args, **kwargs)

def unwrap(self):
Expand All @@ -229,12 +229,8 @@ def _force_wait_all_gather(self):
for p in self.module.parameters():
wait_all_gather_handle(p)

def _wait_all_gather(self):
return (
ColoParamOpHookManager.use_hooks(*self.op_hooks)
if (self.overlap_allgather or self.use_fp8)
else nullcontext()
)
def _hook_context(self):
return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()


def get_param_info(optim: Optimizer):
Expand Down Expand Up @@ -317,7 +313,8 @@ def backward(self, loss: Tensor, *args, **kwargs):
"""

# Call the superclass backward method to compute gradients.
super().backward(loss, *args, **kwargs)
with self.model._hook_context():
super().backward(loss, *args, **kwargs)

if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
Expand Down Expand Up @@ -540,7 +537,8 @@ def backward(self, loss: Tensor, *args, **kwargs):
None
"""
# Call the superclass backward method to compute gradients.
super().backward(loss, *args, **kwargs)
with self.model._hook_context():
super().backward(loss, *args, **kwargs)

if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
Expand Down Expand Up @@ -683,6 +681,7 @@ def __init__(
pp_process_group: Optional[ProcessGroup] = None, # if using pp
forced_dtype: Optional[torch.dtype] = None,
overlap_allgather: bool = False,
fp8_communication: bool = False,
):
self.model = model
self.param_info = param_info
Expand Down Expand Up @@ -712,6 +711,8 @@ def __init__(
dp_process_group=dp_process_group,
forced_dtype=forced_dtype,
overlap_allgather=overlap_allgather,
fp8_communication=fp8_communication,
backward_context=model._hook_context,
)

def sync_dp_grads(self):
Expand Down Expand Up @@ -1206,6 +1207,7 @@ def __init__(
partition_grad=(self.zero_stage == 2),
forced_dtype=PRECISION_TORCH_TYPE[precision],
overlap_allgather=overlap_allgather,
fp8_communication=fp8_communication,
)

self.max_norm = max_norm
Expand Down Expand Up @@ -1371,7 +1373,7 @@ def execute_pipeline(
# so we disable it, performing manual reduction instead.
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()

with ctx, model._wait_all_gather():
with ctx, model._hook_context():
outputs = self.schedule.forward_backward_step(
model, data_iter, criterion, optimizer, return_loss, return_outputs
)
Expand Down
8 changes: 5 additions & 3 deletions colossalai/booster/plugin/low_level_zero_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,16 @@ def forward(self, *args, **kwargs):
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs)
ctx = ColoParamOpHookManager.use_hooks(*self.op_hooks) if self.overlap_allgather else nullcontext()
with ctx:
with self._hook_context():
return super().forward(*args, **kwargs)

def _force_wait_all_gather(self):
for p in self.module.parameters():
wait_all_gather_handle(p)

def _hook_context(self):
return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()


class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
Expand Down Expand Up @@ -520,7 +522,7 @@ def configure(

if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
optimizer, **zero_optim_kwargs, verbose=self.verbose
optimizer, **zero_optim_kwargs, verbose=self.verbose, backward_context=model._hook_context
)
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)
Expand Down
6 changes: 6 additions & 0 deletions colossalai/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# https://forums.developer.nvidia.com/t/how-many-streams-maximum-number-of-streams/6571/16
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"

import torch
import torch.distributed as dist

from colossalai.accelerator import get_accelerator
Expand Down Expand Up @@ -64,6 +65,11 @@ def launch(

set_seed(seed)

try:
torch._dynamo.config.optimize_ddp = world_size > 1
except AttributeError:
pass

if verbose:
logger = get_dist_logger()
logger.info(f"Distributed environment is initialized, world size: {dist.get_world_size()}", ranks=[0])
Expand Down
8 changes: 6 additions & 2 deletions colossalai/zero/low_level/low_level_optim.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import copy
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from functools import partial
from typing import Dict, Iterator, List, Optional, Tuple
from weakref import proxy
Expand Down Expand Up @@ -88,6 +88,7 @@ def __init__(
master_weights: bool = True, # master weights
overlap_allgather: bool = False,
fp8_communication: bool = False,
backward_context=None,
):
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)

Expand Down Expand Up @@ -130,6 +131,7 @@ def __init__(
self._reduce_bucket_size = reduce_bucket_size
self._communication_dtype = communication_dtype
self._fp8_communication = fp8_communication
self._backward_context = backward_context

# gradient clipping
self._clip_grad_norm = clip_grad_norm
Expand Down Expand Up @@ -429,7 +431,9 @@ def backward(self, loss, retain_graph=False):
if self.mixed_precision_mixin is not None:
loss = self.mixed_precision_mixin.pre_backward(loss)

loss.backward(retain_graph=retain_graph)
ctx = nullcontext() if self._backward_context is None else self._backward_context()
with ctx:
loss.backward(retain_graph=retain_graph)

if not self.require_grad_sync:
return
Expand Down
Loading