Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sync layer norm #271

Draft
wants to merge 38 commits into
base: thomas/test_different_layer_norm
Choose a base branch
from

Conversation

thomasw21
Copy link
Member

Force sync layer norms

@thomasw21 thomasw21 changed the base branch from main to thomas/test_different_layer_norm March 24, 2022 22:39
@stas00 stas00 mentioned this pull request Mar 25, 2022
# 2. test training from checkpoint: resume
# now do it again, this time resuming from the checkpoint
with CaptureStdout() as cs:
execute_subprocess_async(cmd, env=self.get_env())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so it crashes on resume:

Traceback (most recent call last):
  File "/home/stas/anaconda3/envs/py38-pt111/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper
    return f(*args, **kwargs)
  File "/mnt/nvme0/code/huggingface/Megatron-DeepSpeed-master-4/pretrain_gpt.py", line 245, in main
    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
  File "/mnt/nvme0/code/huggingface/Megatron-DeepSpeed-master-4/megatron/training.py", line 188, in pretrain
    iteration = train(forward_step_func,
  File "/mnt/nvme0/code/huggingface/Megatron-DeepSpeed-master-4/megatron/training.py", line 857, in train
    train_step(forward_step_func,
  File "/mnt/nvme0/code/huggingface/Megatron-DeepSpeed-master-4/megatron/training.py", line 441, in train_step
    loss = model[0].train_batch(data_iter=data_iterator)
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/runtime/pipe/engine.py", line 346, in train_batch
    self._exec_schedule(sched)
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/runtime/pipe/engine.py", line 1363, in _exec_schedule
    self._exec_instr(**cmd.kwargs)
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/runtime/pipe/engine.py", line 1149, in _exec_optimizer_step
    self._take_model_step(lr_kwargs)
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/runtime/engine.py", line 1787, in _take_model_step
    self.optimizer.step()
  File "/home/stas/anaconda3/envs/py38-pt111/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/runtime/bf16_optimizer.py", line 239, in step
    assert all_groups_norm > 0.
AssertionError

Comment on lines 88 to 94
tp_world_size = mpu.get_tensor_model_parallel_world_size()
# TODO: hack in order to synchronize all layer norms despite them being unsynched
weight = mpu.reduce_from_tensor_model_parallel_region(self.weight) / tp_world_size
bias = mpu.reduce_from_tensor_model_parallel_region(self.bias) / tp_world_size

return FusedLayerNormAffineFunction.apply(
input, self.weight, self.bias, self.normalized_shape,self.eps)
input, weight, bias, self.normalized_shape,self.eps)
Copy link
Member

@stas00 stas00 Mar 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tjruwase, this is the main workaround that does the all_reduce mean on layer norm's weight+bias that we want to put in until we can fix the fp32 weights.

Comment on lines 90 to 93
weight = torch.clone(self.weight)
bias = torch.clone(self.bias)
weight = mpu.reduce_from_tensor_model_parallel_region(weight) / tp_world_size
bias = mpu.reduce_from_tensor_model_parallel_region(bias) / tp_world_size
Copy link
Member Author

@thomasw21 thomasw21 Mar 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@stas00

Essentially the reduce is an in-place operator, which means at each forward pass, self.weight was updated with the sum of all the weights of all tp_ranks. We could try thinking of a better fix by doing a average reduce, but I'm scared back propagation doesn't play well with this in place logic.

New test fails with:

E               raise StopIteration
E           StopIteration

This is more expected since the previous run should have consumed all the tokens. Going to update #272 and restart the training.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we extend:

def _reduce(input_):
"""All-reduce the the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size()==1:
return input_
# All-reduce.
torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())

to support an optional ReduceOp.AVG

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is tricky. The reason why is this means that we need to implement custom backward function (since you compute the average, the gradient needs to be divided by the tp world size).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I don't think we save much compute by supporting that.

adammoody pushed a commit to adammoody/Megatron-DeepSpeed that referenced this pull request Dec 18, 2023
* Enable universal ckpting

* Update run scripts

* Address PR feedback

* Remove line

* Fix white lines

* Remove redudant changes

* Apply to gpt_model only

* Code cleanup

* Code cleanup

* Update training.py

Co-authored-by: Michael Wyatt <mrwyattii@gmail.com>

* Update training.py

Co-authored-by: Michael Wyatt <mrwyattii@gmail.com>

* Log loss_scale only valid for fp16

* Add README and bf16 scripts

* Visualization docsts

* Support older DS

* Handle uni_ckpt import error

* Revert changes

---------

Co-authored-by: Michael Wyatt <mrwyattii@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants