Sync layer norm #271

Draft
wants to merge 38 commits into
base: thomas/test_different_layer_norm
Changes from 11 commits
Commits
07ccb3d
Better
thomasw21 Mar 24, 2022
391ed48
Force synchronize the layer norms parameters across all TP
thomasw21 Mar 24, 2022
98d0e7c
import mpu
stas00 Mar 25, 2022
279a77e
use the bf16 branch for testing
stas00 Mar 25, 2022
87a9dba
`torch.testing.assert_equal` didn't make it (#273)
stas00 Mar 25, 2022
dbb5914
Merge remote-tracking branch 'origin/main' into thomas/fix_layer_norm
stas00 Mar 25, 2022
70f91f8
bf16 comms require pt-1.11
stas00 Mar 25, 2022
835a3e5
already part of the function
stas00 Mar 25, 2022
37795a9
reproduce the crashing on resume
stas00 Mar 25, 2022
3ec65f7
run just the test we want for now
stas00 Mar 25, 2022
8271d41
all_reduce is an in_place operation
thomasw21 Mar 25, 2022
b418b47
Make a test that TP reshaping works
thomasw21 Mar 25, 2022
4b7207b
Woops
thomasw21 Mar 25, 2022
3bc5824
Woops
thomasw21 Mar 25, 2022
05c99db
Woops
thomasw21 Mar 25, 2022
55e10c6
Woops
thomasw21 Mar 25, 2022
2ab8a3a
Woops
thomasw21 Mar 25, 2022
d357839
Woops
thomasw21 Mar 25, 2022
5fb231c
Woops
thomasw21 Mar 25, 2022
cc7ff45
Woops
thomasw21 Mar 25, 2022
7cdb1be
Woops
thomasw21 Mar 25, 2022
4574ec9
Fix load issue
thomasw21 Mar 25, 2022
04e89d1
Woops
thomasw21 Mar 25, 2022
e943100
Fix checkpoint path
thomasw21 Mar 25, 2022
09cead3
Test that force sync will allow TP changes
thomasw21 Mar 25, 2022
77abee6
Nit
thomasw21 Mar 25, 2022
64a62c8
Now that we have a force sync mechanism, let's try to reproduce
thomasw21 Mar 29, 2022
0b7afcc
Compare model_states_rank
thomasw21 Mar 29, 2022
ce01733
test
thomasw21 Mar 29, 2022
89ab0b7
Row column bias should be synchronized as well
thomasw21 Mar 29, 2022
42997b2
New list of matching embeddings
thomasw21 Mar 29, 2022
e0ef168
Figure out why state differs
thomasw21 Mar 29, 2022
1fc4fe8
Test for final weight
thomasw21 Mar 29, 2022
7ebbed1
Test that torch_rng_state
thomasw21 Mar 29, 2022
2c49216
Fix non matching torch_rng_state for tp_rank=0
thomasw21 Mar 30, 2022
007ecb4
Update test
thomasw21 Mar 31, 2022
c3844b5
I'm surprised one can apply inplace operation here
thomasw21 Mar 31, 2022
189f054
Test out the loss from the fp32 weights and optimizer states
thomasw21 Apr 4, 2022
4 changes: 3 additions & 1 deletion .github/workflows/main.yml
@@ -186,7 +186,9 @@ jobs:
           pip install pytest-timeout

       - name: Run tests
-        run: pytest --timeout=600 tests
+        # run: pytest --timeout=600 tests
+        # run just the test we want for now
+        run: pytest --timeout=600 tests/test_training.py::MegDSTestTraining::test_layer_norm_consistent_0_bf16

   stop-runner:
     name: Stop self-hosted EC2 runner
21 changes: 8 additions & 13 deletions megatron/model/fused_layer_norm.py
@@ -23,6 +23,7 @@
 from torch.nn.parameter import Parameter
 from torch.nn import init
 import importlib
+from megatron import mpu

 global fused_mix_prec_layer_norm_cuda
 fused_mix_prec_layer_norm_cuda = None
@@ -84,19 +85,13 @@ def reset_parameters(self):


     def forward(self, input):
-        weights = [torch.empty_like(self.weight) for tp in range(mpu.get_tensor_model_parallel_world_size())]
-        torch.distributed.all_gather(weights, self.weight, group=mpu.get_tensor_model_parallel_group())
-        biases = [torch.empty_like(self.bias) for tp in range(mpu.get_tensor_model_parallel_world_size())]
-        torch.distributed.all_gather(biases, self.bias, group=mpu.get_tensor_model_parallel_group())
-        if any(torch.any(weight != self.weight) for weight in weights):
-            if mpu.get_tensor_model_parallel_rank() == 0:
-                print("Weight sync failed")
-                print(weights)
-        if any(torch.any(bias != self.bias) for bias in biases):
-            if mpu.get_tensor_model_parallel_rank() == 0:
-                print("Bias sync failed")
-                print(biases)
+        tp_world_size = mpu.get_tensor_model_parallel_world_size()
+        # TODO: hack in order to synchronize all layer norms despite them being unsynched
+        weight = torch.clone(self.weight)
+        bias = torch.clone(self.bias)
+        weight = mpu.reduce_from_tensor_model_parallel_region(weight) / tp_world_size
+        bias = mpu.reduce_from_tensor_model_parallel_region(bias) / tp_world_size
Member Author

thomasw21 Mar 25, 2022

@stas00

Essentially, the reduce is an in-place operation, which means that at each forward pass self.weight was overwritten with the sum of the weights across all tp_ranks. We could think about a better fix using an average reduce, but I'm worried that backpropagation doesn't play well with this in-place logic.

New test fails with:

E               raise StopIteration
E           StopIteration

This is more expected since the previous run should have consumed all the tokens. Going to update #272 and restart the training.
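
To make the in-place pitfall described above concrete, here is a toy simulation in plain Python. Lists stand in for per-rank tensors, and `all_reduce_sum_` is a hypothetical stand-in for an in-place all-reduce (it is not a real PyTorch or Megatron function); `forward_in_place` and `forward_with_clone` are likewise made-up names. This is only a sketch of the behaviour discussed in this thread, not the code from this PR.

```python
def all_reduce_sum_(buffers):
    """Toy in-place all-reduce: overwrite every rank's buffer with the element-wise sum."""
    total = [sum(values) for values in zip(*buffers)]
    for buf in buffers:
        buf[:] = total  # in-place write, like the real collective


def forward_in_place(params):
    """Reduce directly on the parameter buffers, then average out of place."""
    all_reduce_sum_(params)  # side effect: the parameters now hold the sum
    tp_world_size = len(params)
    return [[v / tp_world_size for v in p] for p in params]


def forward_with_clone(params):
    """Copy the parameters first so the reduce only touches temporaries."""
    copies = [list(p) for p in params]
    all_reduce_sum_(copies)
    tp_world_size = len(params)
    return [[v / tp_world_size for v in c] for c in copies]


if __name__ == "__main__":
    params = [[1.0, 2.0], [3.0, 6.0]]  # two hypothetical TP ranks
    print(forward_with_clone(params), params)  # parameters are left untouched
    print(forward_in_place(params), params)    # parameters now hold the sum
    print(forward_in_place(params), params)    # and grow again on the next call
```

With two ranks, each call to `forward_in_place` doubles the stored parameters, which is why cloning before the reduce matters in this toy model.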

Member

Should we extend:

def _reduce(input_):
    """All-reduce the input tensor across model parallel group."""
    # Bypass the function if we are using only 1 GPU.
    if get_tensor_model_parallel_world_size()==1:
        return input_
    # All-reduce.
    torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())

to support an optional ReduceOp.AVG?

Member Author

I think this is tricky: it means we would need to implement a custom backward function (since you compute the average, the gradient needs to be divided by the tp world size).

Member Author

Also I don't think we save much compute by supporting that.
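
The gradient argument in this thread can be checked numerically. The sketch below uses plain Python scalars, one per hypothetical TP rank, and an arbitrary loss `y ** 2` chosen purely for illustration; `average`, `loss` and `numeric_grad` are made-up helpers, not Megatron or PyTorch functions. Under those assumptions, it suggests that when the forward pass averages the weights, each rank's weight receives the upstream gradient divided by the TP world size.

```python
def average(ws):
    """Forward pass of an averaging reduce over per-rank scalars."""
    return sum(ws) / len(ws)


def loss(y):
    """Arbitrary differentiable loss, used only for this check (dL/dy = 2 * y)."""
    return y ** 2


def numeric_grad(ws, i, eps=1e-6):
    """Central finite difference of loss(average(ws)) with respect to ws[i]."""
    up = list(ws)
    up[i] += eps
    down = list(ws)
    down[i] -= eps
    return (loss(average(up)) - loss(average(down))) / (2 * eps)


ws = [1.0, 3.0, 5.0, 7.0]            # one scalar weight per hypothetical TP rank
upstream = 2 * average(ws)           # analytic dL/dy at the averaged value
expected = upstream / len(ws)        # upstream gradient divided by the TP world size

if __name__ == "__main__":
    for i in range(len(ws)):
        print(i, numeric_grad(ws, i), expected)
```

A sum-reduce, by contrast, would pass the upstream gradient through unchanged, which is why switching to an average would need its own backward rule.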


         return FusedLayerNormAffineFunction.apply(
-            input, self.weight, self.bias, self.normalized_shape,self.eps)
+            input, weight, bias, self.normalized_shape,self.eps)

8 changes: 4 additions & 4 deletions megatron/testing_utils.py
@@ -232,9 +232,9 @@ def get_gpu_count():
         return 0

 def torch_assert_equal(actual, expected, **kwargs):
-    # assert_equal was added around pt-1.9, it does better checks - e.g will check dimensions match
-    if hasattr(torch.testing, "assert_equal"):
-        return torch.testing.assert_equal(actual, expected, **kwargs)
+    # assert_close was added around pt-1.9, it does better checks - e.g will check dimensions match
+    if hasattr(torch.testing, "assert_close"):
+        return torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0, **kwargs)
     else:
         return torch.allclose(actual, expected, rtol=0.0, atol=0.0)

@@ -886,4 +886,4 @@ def flatten_arguments(args):

     Example: {"arg1": "value1", "arg2": "value2"} -> ["IGNORED", "arg1", "value1", "arg2", "value2"]
     """
-    return ["IGNORED"] + [item for key_value in args.items() for item in key_value if item != ""]
+    return ["IGNORED"] + [item for key_value in args.items() for item in key_value if item != ""]
5 changes: 3 additions & 2 deletions requirements.txt
@@ -6,9 +6,10 @@ pybind11
 regex
 six
 tensorboard
-torch>=1.7
+torch>=1.11
 transformers
-DeepSpeed @ git+https://github.com/microsoft/DeepSpeed.git
+# for now using this branch for bf16 work
+DeepSpeed @ git+https://github.com/microsoft/DeepSpeed.git@olruwase/bf16-updates
 # versions from HF transformers
 black==21.4b0
 isort>=5.5.4
18 changes: 16 additions & 2 deletions tests/test_training.py
@@ -682,6 +682,8 @@ def test_layer_norm_consistent(self, variation):
             execute_subprocess_async(cmd, env=self.get_env())

         checkpoints = ["global_step10", "global_step20"]
+
+        # Check transformer layer norm
         keys_to_compare = ["input_layernorm.weight", "input_layernorm.bias", "post_attention_layernorm.weight", "post_attention_layernorm.bias"]
         files_to_compare = [[f"layer_{layer_id:02d}-model_{tp:02d}-model_states.pt" for tp in range(num_gpus)] for layer_id in [3,4]]
         for checkpoint in checkpoints:
@@ -691,8 +693,9 @@ def test_layer_norm_consistent(self, variation):
                     weights = [torch.load(os.path.join(checkpoint_path, file))[key] for file in files]
                     ref = weights[0]
                     for weight in weights[1:]:
-                        torch_assert_equal(ref, weight, rtol=0.0, atol=0.0, check_device=False)
+                        torch_assert_equal(ref, weight, check_device=False)

+        # Check embed layer norm
         keys_to_compare = ["word_embeddings.norm.weight"]
         files_to_compare = [[f"layer_{layer_id:02d}-model_{tp:02d}-model_states.pt" for tp in range(num_gpus)] for layer_id in [1]]
         for checkpoint in checkpoints:
@@ -702,4 +705,15 @@ def test_layer_norm_consistent(self, variation):
                     weights = [torch.load(os.path.join(checkpoint_path, file))[key] for file in files]
                     ref = weights[0]
                     for weight in weights[1:]:
-                        torch_assert_equal(ref, weight, rtol=0.0, atol=0.0, check_device=False)
+                        torch_assert_equal(ref, weight, check_device=False)
+
+        # 2. test training from checkpoint: resume
+        # now do it again, this time resuming from the checkpoint
+        with CaptureStdout() as cs:
+            execute_subprocess_async(cmd, env=self.get_env())
Member

so it crashes on resume:

Traceback (most recent call last):
  File "/home/stas/anaconda3/envs/py38-pt111/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper
    return f(*args, **kwargs)
  File "/mnt/nvme0/code/huggingface/Megatron-DeepSpeed-master-4/pretrain_gpt.py", line 245, in main
    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
  File "/mnt/nvme0/code/huggingface/Megatron-DeepSpeed-master-4/megatron/training.py", line 188, in pretrain
    iteration = train(forward_step_func,
  File "/mnt/nvme0/code/huggingface/Megatron-DeepSpeed-master-4/megatron/training.py", line 857, in train
    train_step(forward_step_func,
  File "/mnt/nvme0/code/huggingface/Megatron-DeepSpeed-master-4/megatron/training.py", line 441, in train_step
    loss = model[0].train_batch(data_iter=data_iterator)
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/runtime/pipe/engine.py", line 346, in train_batch
    self._exec_schedule(sched)
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/runtime/pipe/engine.py", line 1363, in _exec_schedule
    self._exec_instr(**cmd.kwargs)
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/runtime/pipe/engine.py", line 1149, in _exec_optimizer_step
    self._take_model_step(lr_kwargs)
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/runtime/engine.py", line 1787, in _take_model_step
    self.optimizer.step()
  File "/home/stas/anaconda3/envs/py38-pt111/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/nvme0/code/github/00optimize/deepspeed/deepspeed/runtime/bf16_optimizer.py", line 239, in step
    assert all_groups_norm > 0.
AssertionError


+        # test checkpoint loading
+        self.assertIn(f"successfully loaded checkpoint from {output_dir}/checkpoints", cs.out)
+
+        # test reports
+        self.assertIn("consumed samples", cs.out)
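
The per-rank comparison done in `test_layer_norm_consistent` can be sketched without PyTorch. In the example below, `shard_files` reuses the file-name pattern from the test, while `find_mismatches` and the in-memory `states_by_file` dictionary are hypothetical stand-ins for loading each file with `torch.load` and comparing tensors exactly. It is an illustration of the idea under those assumptions, not the test itself.

```python
def shard_files(layer_ids, num_gpus):
    """Build one list of per-TP-rank checkpoint file names for each layer id."""
    return [
        [f"layer_{layer_id:02d}-model_{tp:02d}-model_states.pt" for tp in range(num_gpus)]
        for layer_id in layer_ids
    ]


def find_mismatches(states_by_file, files_to_compare, keys_to_compare):
    """Return (key, file) pairs whose value differs from the first TP rank's value."""
    mismatches = []
    for key in keys_to_compare:
        for files in files_to_compare:
            ref = states_by_file[files[0]][key]  # the first TP rank is the reference
            for file in files[1:]:
                if states_by_file[file][key] != ref:
                    mismatches.append((key, file))
    return mismatches


if __name__ == "__main__":
    files = shard_files([3], num_gpus=2)
    # Plain lists stand in for the tensors that torch.load would return.
    states = {
        files[0][0]: {"input_layernorm.weight": [1.0, 1.0]},
        files[0][1]: {"input_layernorm.weight": [1.0, 1.5]},  # deliberately out of sync
    }
    print(find_mismatches(states, files, ["input_layernorm.weight"]))
```

An empty result would correspond to the layer norm parameters being identical across all TP ranks for the checked layers.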