Update distopt API for coalesced NCCL calls (#6886)
* Update distopt API for coalesced NCCL calls

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update comment

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
timmoon10 and pre-commit-ci[bot] authored Jul 3, 2023
1 parent b0e5bf3 commit 0b6e4e6
Showing 2 changed files with 9 additions and 9 deletions.
6 changes: 3 additions & 3 deletions Dockerfile
@@ -45,11 +45,11 @@ RUN apt-get update && \
WORKDIR /workspace/

WORKDIR /tmp/
-# TODO: Remove once this Apex commit (2/24/23) is included in PyTorch
+# TODO: Remove once this Apex commit (5/12/23) is included in PyTorch
# container
RUN git clone https://github.com/NVIDIA/apex.git && \
    cd apex && \
-    git checkout 57057e2fcf1c084c0fcc818f55c0ff6ea1b24ae2 && \
+    git checkout 8b7a1ff183741dd8f9b87e7bafd04cfde99cea28 && \
    pip3 install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" --global-option="--distributed_adam" --global-option="--deprecated_fused_adam" ./

# uninstall stuff from base container
@@ -75,7 +75,7 @@ RUN for f in $(ls requirements*.txt); do pip3 install --disable-pip-version-chec
# install flash attention dependencies
RUN pip install flash-attn
# pinned triton version for flash-attention https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L3
-RUN pip install triton==2.0.0.dev20221202
+RUN pip install triton==2.0.0.dev20221202

# install k2, skip if installation fails
COPY scripts /tmp/nemo/scripts/
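The Apex pin above matters because the optimizer change below imports coalescing helpers that only newer Apex builds provide. A quick way to confirm the installed build has them is to attempt the import (an illustrative check, not part of this commit):

# Illustrative sanity check: the Apex commit pinned above should expose the
# coalescing helpers that nemo/core/optim/distributed_adam.py imports.
from apex.contrib.optimizers.distributed_fused_adam import (
    DistributedFusedAdam,
    _coalescing_manager,
    _coalescing_manager_append_work,
)
print("coalesced NCCL helpers are available")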
12 changes: 6 additions & 6 deletions nemo/core/optim/distributed_adam.py
@@ -19,6 +19,7 @@
from apex.contrib.optimizers.distributed_fused_adam import (
    DistributedFusedAdam,
    _coalescing_manager,
+    _coalescing_manager_append_work,
    _disable_pre_forward_hook,
)
from megatron.core import parallel_state
@@ -173,16 +174,15 @@ def _fp32_optim_grad_sync(self):
        for model_param, main_param in self._fp32_optim_main_params.items():
            if model_param.grad is not None:
                main_param.grad += model_param.grad.detach()
-        sync_requests = []
-        with _coalescing_manager(self.process_group, self.device, sync_requests):
+        with _coalescing_manager(self.process_group, self.device, async_ops=True) as cm:
            for main_param in self._fp32_optim_main_params.values():
-                sync_requests.append(
+                _coalescing_manager_append_work(
+                    cm,
                    torch.distributed.all_reduce(
                        main_param.grad, op=torch.distributed.ReduceOp.AVG, group=self.process_group, async_op=True,
-                    )
+                    ),
                )
-        for req in sync_requests:
-            req.wait()
+        cm.wait()
        self._fp32_optim_grad_sync_needed = False

def zero_grad(self, *args, **kwargs):
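For reference, the coalesced all-reduce pattern adopted above can be written as a standalone helper. This is a minimal sketch that mirrors the diff; the function name and the grads/group/device arguments are illustrative rather than part of the commit, and it assumes an Apex build that provides the helpers imported above.

import torch
from apex.contrib.optimizers.distributed_fused_adam import (
    _coalescing_manager,
    _coalescing_manager_append_work,
)


def allreduce_grads_coalesced(grads, group, device):
    # Launch every all-reduce inside one coalescing context so NCCL can batch
    # them into a single coalesced call.
    with _coalescing_manager(group, device, async_ops=True) as cm:
        for grad in grads:
            _coalescing_manager_append_work(
                cm,
                torch.distributed.all_reduce(
                    grad, op=torch.distributed.ReduceOp.AVG, group=group, async_op=True,
                ),
            )
    # A single wait on the manager replaces the old per-request wait loop.
    cm.wait()

Compared with the previous API, the explicit request list and the per-request wait() calls are replaced by passing async_ops=True and waiting once on the returned manager.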
