
HF Llama 3.2 1B slowness (training) #1506

Open · t-vi opened this issue Dec 2, 2024 · 9 comments

@t-vi (Collaborator) commented Dec 2, 2024

The following training repro (batch size 1, sequence length 2048) has thunder+nvfuser significantly slower than torch.compile.

import torch
from transformers.models.llama import LlamaForCausalLM, LlamaConfig

LLAMA_3_2_1B_CFG = {
    "architectures": ["LlamaForCausalLM"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 128000,
    "eos_token_id": 128001,
    "head_dim": 64,
    "hidden_act": "silu",
    "hidden_size": 2048,
    "initializer_range": 0.02,
    "intermediate_size": 8192,
    "max_position_embeddings": 131072,
    "mlp_bias": False,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 16,
    "num_key_value_heads": 8,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": {
        "factor": 32.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    "rope_theta": 500000.0,
    "tie_word_embeddings": True,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.45.0.dev0",
    "use_cache": True,
    "vocab_size": 128256,
    "_commit_hash": "4e20de362430cd3b72f300e6b0f18e50e7166e08",
}

args = dict(
    input_ids=torch.ones(1, 2048, dtype=torch.int64, device="cuda"),
    labels=torch.ones(1, 2048, dtype=torch.int64, device="cuda"),
)

config = LlamaConfig(**LLAMA_3_2_1B_CFG)

with torch.device("cuda"):
    model = LlamaForCausalLM(config).to(torch.bfloat16)


# eager baseline (also serves as warmup)
res = model(**args)
res.loss.backward()

import thunder
from thunder.transforms.cudagraph import CUDAGraphTransform  # imported but unused in this repro

# thunder with the default executor list (includes nvfuser)
jm = thunder.jit(model,
                #executors=('apex', 'cudnn', 'sdpa', 'torchcompile_cat', 'nvfuser'),
                )

res = jm(**args)
res.loss.backward()

# thunder without nvfuser, fusing via torch.compile instead
jm2 = thunder.jit(model,
                executors=('apex', 'cudnn', 'sdpa', 'torchcompile'),
                )

res = jm2(**args)
res.loss.backward()

# plain torch.compile for comparison
cm = torch.compile(model)
res = cm(**args)
res.loss.backward()

Timings:

%timeit res = jm(**args); res.loss.backward(); torch.cuda.synchronize()
%timeit res = jm2(**args); res.loss.backward(); torch.cuda.synchronize()
%timeit res = model(**args); res.loss.backward(); torch.cuda.synchronize()
%timeit res = cm(**args); res.loss.backward(); torch.cuda.synchronize()
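
Outside IPython the same comparison can be timed with CUDA events; a minimal sketch, assuming the jm, jm2, model, cm and args objects from the repro above (the bench and step helpers are illustrative, not part of the original script):

import torch

def bench(fn, iters=10, warmup=3):
    # Illustrative helper: average wall-clock time of fn() in ms, measured with CUDA events.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

def step(m):
    # One forward+backward training step, as in the %timeit lines above.
    res = m(**args)
    res.loss.backward()

for name, m in [("thunder default", jm), ("thunder no-nvfuser", jm2), ("eager", model), ("torch.compile", cm)]:
    print(f"{name}: {bench(lambda: step(m)):.1f} ms")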

This gives, on a Studio with an L40S (reordered):

Eager: 136ms
Thunder with default executors (including NVFuser): 125ms
Thunder with apex, cudnn, sdpa, torchcompile (no NVFuser): 117ms
Torch Compile: 105ms

cc @apaz-cli

t-vi added the high priority, performance, and huggingface (For supporting HF models) labels on Dec 2, 2024
@kevinstephano (Collaborator) commented:

I cannot get this example to run properly; I get the error below with TOT Thunder. Is this repro missing anything, or do I need a Thunder patch to run it, by chance?

/usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py:825: UserWarning: cuDNN SDPA backward got grad_output.strides() != output.strides(), attempting to materialize a grad_output with matching strides... (Triggered internally at
 /opt/pytorch/pytorch/aten/src/ATen/native/cudnn/MHA.cpp:674.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/attention.py:108: UserWarning: To use flash-attn v3, please use the following commands to install:
(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper"
(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
(3) mkdir -p $python_path/flashattn_hopper
(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py
  warnings.warn(
Traceback (most recent call last):
  File "/workspace/test_training_orig.py", line 58, in <module>
    res = jm(**args)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/module.py", line 80, in forward
    res = self._forward_fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 772, in wrapped
    return fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 822, in fn_                                                                                                                                                                     cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 754, in wrapped
    cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/langctxs.py", line 136, in _fn
    result = fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 236, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 529, in get_computation_and_inputs
    jit_results: TraceResults = thunder_general_jit(
  File "/opt/pytorch/lightning-thunder/thunder/core/jit_ext.py", line 1747, in thunder_general_jit
    process_recorded_modifications(ctx, epilogue_trace)
  File "/opt/pytorch/lightning-thunder/thunder/core/jit_ext.py", line 1622, in process_recorded_modifications
    assert isinstance(value.value, Proxy)
AssertionError

@riccardofelluga (Collaborator) commented:

@t-vi what version of nvFuser did you record the timings with? I cannot reproduce this slowdown on H100 with nvFuser at 0.2.23+git97544c3, Thunder at 60f3ee1ec536ee8d6fdef503af54525e0a3978a4 and torch at 2.6.0a0+gitf0f6144. After warmup I see:

%timeit res = jm(**args); res.loss.backward(); torch.cuda.synchronize()
39.2 ms ± 298 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit res = jm2(**args); res.loss.backward(); torch.cuda.synchronize()
39.8 ms ± 35.6 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit res = model(**args); res.loss.backward(); torch.cuda.synchronize()
53.6 ms ± 8.8 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit res = cm(**args); res.loss.backward(); torch.cuda.synchronize()
38 ms ± 4.07 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

@kevinstephano maybe try it in the 20241120 container; it seems to be working there.

@t-vi (Collaborator, Author) commented Dec 4, 2024

I used the latest pip versions (PyTorch 2.5.1) and an L40S.
I upgraded to nvfuser_cu121_torch25==0.2.23.dev20241201 (aka 0.2.23+gitc154e90) just to be sure; the timings remain the same.
Given that this is newer than your git version, has there been a performance regression, or was the pip release built with particularly slow flags?

@kevinstephano note that the HF transformers version I used for this is the same as last week (4.46.2).
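
To pin down exactly which builds are in play across setups, a minimal sketch (the distribution names lightning-thunder and nvfuser_cu121_torch25 are assumed to match the pip wheels mentioned above):

import importlib.metadata as md
import torch
import transformers

# Report the exact installed versions; the nvFuser wheel name differs per CUDA/torch build.
print("torch:", torch.__version__)
print("transformers:", transformers.__version__)
print("lightning-thunder:", md.version("lightning-thunder"))
print("nvfuser:", md.version("nvfuser_cu121_torch25"))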

@riccardofelluga (Collaborator) commented:

OK, I tested again: I created the setup from a new environment, installing the requirements for Thunder first and then adding the specific versions you mentioned above; in particular, I installed nvFuser from pip as nvfuser_cu121_torch25==0.2.23.dev20241201.

I still cannot replicate the timings: on an A6000 Ada, with Python 3.12.7, torch==2.5.1 and transformers==4.46.2, these are the results:

%timeit res = jm(**args); res.loss.backward(); torch.cuda.synchronize()
156 ms ± 1.16 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit res = jm2(**args); res.loss.backward(); torch.cuda.synchronize()
159 ms ± 1.15 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit res = model(**args); res.loss.backward(); torch.cuda.synchronize()
175 ms ± 1.52 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit res = cm(**args); res.loss.backward(); torch.cuda.synchronize()
162 ms ± 990 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

So jm2, which is without nvFuser, is slower than jm, which is the one with nvFuser.

@t-vi (Collaborator, Author) commented Dec 4, 2024

@riccardofelluga Hm, strange. I could see two potential reasons:

  • Could it be that your setup has a different CPU-to-GPU performance ratio than the L40S? (A quick profiler check is sketched below.)
  • 3.10 vs. 3.12: I think some details have changed the perf characteristics quite a bit, maybe we hit something lucky? I checked and get exactly the same numbers with 3.12, though.

Other ideas would be very welcome...
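
A minimal sketch for the first point, assuming the jm module and args from the repro above (illustrative, not from the original comment): profile one step and compare CPU-side time against CUDA kernel time.

import torch
from torch.profiler import profile, ProfilerActivity

# Profile one training step of the thunder-jitted module.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    res = jm(**args)
    res.loss.backward()
    torch.cuda.synchronize()

# Large self-CPU totals relative to CUDA time would point at a CPU-bound setup.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))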

@riccardofelluga (Collaborator) commented:

OK, I seem to be able to reproduce the numbers in Lightning Studios. While I am looking into it, is there a reason why you added "torchcompile" instead of "torchcompile_cat" in the list of executors?

In the snippet, the compared executor lists are:

  • [cudnn -> sdpa -> torchcompile_cat -> nvfuser -> torch -> python]
  • [cudnn -> sdpa -> torchcompile -> torch -> python]

@t-vi (Collaborator, Author) commented Dec 4, 2024

Yeah, the reason is that it might be faster. :) torchcompile_cat is deliberately narrow so that it leaves work to nvFuser, whereas torchcompile is "fuse whatever you can".
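
To see the difference concretely, one can print the final execution traces and check which executor claimed which regions; a sketch assuming the jm and jm2 modules from the repro and the thunder.last_traces / thunder.last_backward_traces helpers described in the Thunder docs:

import thunder

# With the default executors the fusion regions should be nvFuser fusions;
# with ('apex', 'cudnn', 'sdpa', 'torchcompile') they should be TorchCompile
# regions covering larger chunks of the graph.
print(thunder.last_traces(jm)[-1])             # forward trace, default executors
print(thunder.last_backward_traces(jm)[-1])    # backward trace, default executors
print(thunder.last_traces(jm2)[-1])            # forward trace, no nvFuser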

@kevinstephano (Collaborator) commented:

The numbers I saw on a DGX H100 were virtually identical to the ones reported by @riccardofelluga. I also measured on an L40.

DGX H100-80GB Results:

| Execution Type        | Wall Clock Time (ms) |
|-----------------------|----------------------|
| torch-eager           | 53.775               |
| Thunder-nvFuser       | 37.977               |
| Thunder-torch.compile | 38.979               |
| Thunder-torch         | 96.625               |
| torch.compile         | 37.963               |

L40 Results:

| Execution Type        | Wall Clock Time (ms) |
|-----------------------|----------------------|
| torch-eager           | 202.513              |
| Thunder-nvFuser       | 188.717              |
| Thunder-torch.compile | 195.734              |
| Thunder-torch         | 295.482              |
| torch.compile         | 186.761              |

@kevinstephano (Collaborator) commented:

The problem for training on the litgpt Studio L40s is the same as for inference. I will comment on the inference problem more thoroughly in #1467.
