Fused kernel compilation could get stuck #82

Closed
rhythmswing opened this issue Mar 14, 2021 · 17 comments
Labels: bug (Something isn't working), stale (No activity in 60 days on issue or PR)

Comments

@rhythmswing

rhythmswing commented Mar 14, 2021

Hi,

I've noticed that the program can get stuck at "using torch.float16 for parameters ...". It turns out the hang happens while compiling the fused kernels, and deleting megatron/fused_kernels/build seems to fix the problem. I'm not sure what causes this.
I'm posting this in the hope it is helpful.
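A minimal sketch of that workaround: wipe the cached kernel build before relaunching so the JIT compiler starts from a clean state. The path below is the one mentioned in this thread and may differ in your checkout.

```python
import shutil
from pathlib import Path

# Fused-kernel JIT build cache mentioned in this thread;
# adjust to wherever your Megatron-LM checkout lives.
build_dir = Path("megatron/fused_kernels/build")

if build_dir.exists():
    # Removing the directory also clears any leftover artifacts from an
    # interrupted build, so the next launch recompiles the kernels cleanly.
    shutil.rmtree(build_dir)
    print(f"Removed stale fused-kernel build cache: {build_dir}")
```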

@huangjundashuaige

huangjundashuaige commented Mar 23, 2021

Same problem, stuck here.

using world size: 1, data-parallel-size: 1, tensor-model-parallel size: 1, pipeline-model-parallel size: 1 
using torch.float16 for parameters ...
^CTraceback (most recent call last):
  File "pretrain_gpt.py", line 149, in <module>
    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
  File "/home/superbencher/Megatron-LM/megatron/training.py", line 87, in pretrain
    initialize_megatron(extra_args_provider=extra_args_provider,
  File "/home/superbencher/Megatron-LM/megatron/initialize.py", line 49, in initialize_megatron
    set_global_variables(extra_args_provider=extra_args_provider,
  File "/home/superbencher/Megatron-LM/megatron/global_vars.py", line 82, in set_global_variables
    args = _parse_args(extra_args_provider=extra_args_provider,
  File "/home/superbencher/Megatron-LM/megatron/global_vars.py", line 97, in _parse_args
    _GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
  File "/home/superbencher/Megatron-LM/megatron/arguments.py", line 188, in parse_args
    fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
  File "/home/superbencher/Megatron-LM/megatron/fused_kernels/__init__.py", line 60, in load_scaled_upper_triang_masked_softmax_fusion_kernel
    scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
  File "/home/superbencher/.local/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1079, in load
    return _jit_compile(
  File "/home/superbencher/.local/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1306, in _jit_compile
    baton.wait()
  File "/home/superbencher/.local/lib/python3.8/site-packages/torch/utils/file_baton.py", line 42, in wait
    time.sleep(self.wait_seconds)
KeyboardInterrupt
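For context, the traceback ends in FileBaton.wait(): the process is polling a lock file in the kernel build directory that another (possibly long-dead) build process created, so it sleeps forever. Below is a minimal sketch of clearing such a stale lock before launching; both the build path and the "lock" filename are assumptions about the torch.utils JIT mechanism, not something confirmed in this thread.

```python
import os

# Build directory used by the fused-kernel JIT compile (assumed path).
build_dir = "megatron/fused_kernels/build"
# torch's FileBaton polls a lock file while another build appears to be
# running; the filename here is an assumption about that mechanism.
lock_file = os.path.join(build_dir, "lock")

if os.path.exists(lock_file):
    os.remove(lock_file)
    print(f"Removed stale JIT-compile lock: {lock_file}")
```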

@huangjundashuaige

You can skip it by setting --no-scaled-masked-softmax-fusion.
I don't know how much this affects end-to-end performance.

@bugface

bugface commented Mar 23, 2021

Deleting Megatron-LM/megatron/fused_kernels/build/ and restarting works for me.

Chen-Chang pushed a commit to Chen-Chang/Megatron-LM that referenced this issue May 18, 2021
@bottergpt

bottergpt commented Dec 15, 2022

Same issue.
I got it running by removing megatron/fused_kernels/build as suggested by @bugface,
but I am wondering whether that is the right way to fix it.

@giacomocamposampiero

Experiencing the same issue here, although the observed behaviour differed across nodes of the cluster (not sure whether that was due to different software stacks or different GPUs).

Deleting megatron/fused_kernels/build did not work for me, and I only managed to solve the issue by dropping the fused kernels entirely, as suggested by @huangjundashuaige. To update that suggestion: the arguments that do this in the current version are --no-masked-softmax-fusion and --no-bias-dropout-fusion.

@jon-barker jon-barker added the bug Something isn't working label Jun 29, 2023
@jon-barker
Collaborator

Deleting /megatron/fused_kernels/build is recommended if you have upgraded CUDA versions or moved to different hardware. Those changes are not detected automatically, so the kernels are not rebuilt even when a rebuild is required.

We will address this soon by moving to the same prebuilt kernels that Apex ships, removing the need for this custom kernel build step. I'll close this issue when that happens.
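One way to automate that recommendation is to stamp the build directory with the environment it was built for and wipe it when the stamp no longer matches. This scheme is purely illustrative and not part of Megatron-LM; the paths and the stamp file name are assumptions.

```python
import json
import shutil
from pathlib import Path

import torch

# Hypothetical stamp-file scheme: record which PyTorch/CUDA combination the
# cached kernels were built against, and discard the cache when it changes.
build_dir = Path("megatron/fused_kernels/build")
stamp_file = build_dir / "env_stamp.json"
current_env = {"torch": torch.__version__, "cuda": torch.version.cuda}

if build_dir.exists():
    try:
        cached_env = json.loads(stamp_file.read_text())
    except (FileNotFoundError, json.JSONDecodeError):
        cached_env = None
    if cached_env != current_env:
        # Environment changed (e.g. CUDA upgrade): force a clean rebuild.
        shutil.rmtree(build_dir)

build_dir.mkdir(parents=True, exist_ok=True)
stamp_file.write_text(json.dumps(current_env))
```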

@SefaZeng

The fused_kernels compilation gets stuck when training on multiple nodes, but it works fine on a single node. Why?

@MachineGunLin

> The fused_kernels compilation gets stuck when training on multiple nodes, but it works fine on a single node. Why?

Same here. Did you solve this problem?

@ZhenYangIACAS

> The fused_kernels compilation gets stuck when training on multiple nodes, but it works fine on a single node. Why?

@SefaZeng Same problem, have you fixed it?

@saforem2

> The fused_kernels compilation gets stuck when training on multiple nodes, but it works fine on a single node. Why?

> Same here. Did you solve this problem?

+1, same issue here.

@mfdj2002

mfdj2002 commented Feb 9, 2024

> The fused_kernels compilation gets stuck when training on multiple nodes, but it works fine on a single node. Why?

Same here.

@mfdj2002

mfdj2002 commented Feb 9, 2024

> The fused_kernels compilation gets stuck when training on multiple nodes, but it works fine on a single node. Why?

Actually, it seems to be a problem with the PyTorch barrier; simply setting NCCL_P2P_DISABLE=1 worked for me.
Credit: https://discuss.pytorch.org/t/torch-distributed-barrier-doesnt-work-with-pytorch-2-0-and-backend-nccl/190232
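For reference, a minimal sketch of applying that workaround from inside the training script rather than on the launch command line; the launcher-provided environment variables (LOCAL_RANK, RANK, etc.) are assumptions about a torchrun-style setup.

```python
import os

import torch
import torch.distributed as dist


def init_distributed():
    # Disable NCCL peer-to-peer transport before the process group is created,
    # equivalent to `NCCL_P2P_DISABLE=1 torchrun ... pretrain_gpt.py ...`.
    os.environ.setdefault("NCCL_P2P_DISABLE", "1")

    # Assumes a torchrun-style launcher has set RANK, WORLD_SIZE, LOCAL_RANK,
    # MASTER_ADDR and MASTER_PORT (an assumption, not shown in the thread).
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))
    dist.init_process_group(backend="nccl")
    dist.barrier()  # the call that previously hung on the affected nodes
```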

@saforem2

saforem2 commented Feb 9, 2024

awesome to hear, will try this, thanks!

@kduxin

kduxin commented Mar 4, 2024

> The fused_kernels compilation gets stuck when training on multiple nodes, but it works fine on a single node. Why?

> Actually, it seems to be a problem with the PyTorch barrier; simply setting NCCL_P2P_DISABLE=1 worked for me. Credit: https://discuss.pytorch.org/t/torch-distributed-barrier-doesnt-work-with-pytorch-2-0-and-backend-nccl/190232

I met this problem on one of my nodes.
Running on that node alone (NNODE=1) did not work; applying NCCL_P2P_DISABLE=1 to that node solved it. This seems to be a hardware-related / BIOS setting issue.
Distributed training that excluded that node also worked for me.
