
Fix Dtype Mismatch in torch.addmm within ops/fused_linear_cross_entropy.py in AMP training. #502

Merged (4 commits) on Dec 29, 2024

Conversation

DandinPower (Contributor)

Summary

This PR addresses a dtype mismatch error that I encountered while using PyTorch AMP to train a Llama3 model. After reviewing previous discussions, such as closed issue #305 and PR #318, conducting my own tests, and performing a complete analysis of the problem, I found that there is still a possibility of encountering a dtype mismatch if the bias is None during FLCE computation. The detailed observation and analysis of the issue can be found in issue #501.
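To make the failure mode concrete, here is a minimal sketch that mimics the dtypes AMP produces when bias is None (illustration only; the tensor names follow fused_linear_cross_entropy_forward, but this is not the Liger code itself). Under autocast the matmul yields BF16 logits while weight and grad_weight stay FP32, so the later addmm sees mixed dtypes:

import torch

# Mimic the dtypes in the FLCE forward under BF16 autocast with bias=None.
chunk_size, hidden, vocab = 8, 64, 128
_input_chunk = torch.randn(chunk_size, hidden)                       # FP32 activations
grad_weight = torch.zeros(vocab, hidden)                             # FP32, same dtype as weight
logits_chunk = torch.randn(chunk_size, vocab, dtype=torch.bfloat16)  # BF16, as autocast produces

try:
    torch.addmm(
        input=grad_weight,
        mat1=logits_chunk.t(),  # BF16
        mat2=_input_chunk,      # FP32
        out=grad_weight,
        alpha=1.0,
        beta=1.0,
    )
except RuntimeError as e:
    print(e)  # mismatched mat1/mat2 dtypes: the error this PR fixes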

This PR aims to:

  1. Enhance the test cases to reproduce the mismatch error.
  2. Resolve the bug by ensuring the correct dtype is used, without affecting the behavior in other scenarios.

Testing Done

  • Hardware Type: RTX-4090-24G
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence
$ make test
python -m pytest --disable-warnings test/ --ignore=test/convergence
========================= test session starts =========================
platform linux -- Python 3.10.12, pytest-8.3.4, pluggy-1.5.0
rootdir: /mnt/sda1/latest_liaw/open-source/Liger-Kernel
configfile: pyproject.toml
plugins: xdist-3.6.1, rerunfailures-15.0
collected 965 items

test/transformers/test_transformers.py::test_import_from_root PASSED             [ 99%]
test/triton/test_triton_monkey_patch.py::test_import_from_root PASSED            [ 99%]
test/triton/test_triton_monkey_patch.py::test_import_custom_cache_manager PASSED [100%]

========================= 750 passed, 215 skipped, 41 warnings in 32.40s =========================

$ make test-convergence
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models.py
==================================================== test session starts =====================================================
platform linux -- Python 3.10.12, pytest-8.3.4, pluggy-1.5.0
rootdir: /mnt/sda1/latest_liaw/open-source/Liger-Kernel
configfile: pyproject.toml
plugins: xdist-3.6.1, rerunfailures-15.0
collecting ... 
---------------------------------------------------- live log collection -----------------------------------------------------
INFO     datasets:config.py:54 PyTorch version 2.5.1 available.
collected 17 items

test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_gemma1.1-32-0.0001-dtype14-1e-08-0.0001-0.005-1e-05-0.005-1e-05] PASSED [ 88%]
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_gemma1.1-32-0.0001-dtype15-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [ 94%]
test/convergence/test_mini_models_with_logits.py::test_mini_model[mini_gemma2-32-0.0001-dtype16-1e-08-0.0001-0.005-1e-05-0.005-1e-05] PASSED [100%]

========================= 17 passed, 1 warning in 60.39s (0:01:00) =========================


$ make checkstyle
ruff check . --fix; ruff_check_status=$?; \
ruff format .; ruff_format_status=$?; \
if [ $ruff_check_status -ne 0 ] || [ $ruff_format_status -ne 0 ]; then \
        exit 1; \
fi
All checks passed!
124 files left unchanged

…x dtype mismatch bug in fused_linear_cross_entropy.
@@ -59,6 +59,10 @@ def fused_linear_cross_entropy_forward(
     logits_chunk = _input_chunk @ weight.t()  # chunk_size x V
     if bias is not None:
         logits_chunk = logits_chunk + bias
+    if logits_chunk.dtype != weight.dtype:
Collaborator

Is the dtype comparison necessary? According to the aten source code, I think .to(dtype) already does a dtype check implicitly.

Contributor Author

Yes, you are right! After reading the source code you provided, conducting some small experiments, and profiling the behavior of PyTorch, I confirmed that .to(dtype) performs a dtype check implicitly and skips the cast when the dtype already matches.

Here is an example of a bfloat16 tensor attempting .to(dtype=torch.float32):

[profiler screenshot]

And here is an example of a float32 tensor attempting .to(dtype=torch.float32):

[profiler screenshot]

I will fix this by removing the unnecessary comparison.
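For reference, a tiny sketch of this behavior (separate from the profiling above): Tensor.to returns the original tensor when the dtype already matches, and only allocates and casts when it differs.

import torch

x = torch.randn(4)             # float32
same = x.to(torch.float32)     # dtype already matches: the same tensor is returned, no cast
cast = x.to(torch.bfloat16)    # dtype differs: a new tensor is allocated and cast

print(same is x)               # True
print(cast.dtype)              # torch.bfloat16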

Contributor Author

Hi @Tcc0403, thanks for the review! I have merged the latest main branch and made a new commit to remove the unnecessary dtype comparison. I also re-ran make test, make checkstyle, and make test-convergence, and all tests have passed as before.

Tcc0403 (Collaborator) commented Dec 28, 2024

Just curious how it would affect speed and memory if we moved the type conversion into addmm(). https://github.com/linkedin/Liger-Kernel/pull/502/files#diff-d7a61f0e3b11dccec3ffab797033897bbc73d69abff5538f6ff11da850e38f63R120

if grad_weight is not None:
    torch.addmm(
        input=grad_weight,
        mat1=logits_chunk.t().to(_input_chunk.dtype),  # here
        mat2=_input_chunk,
        out=grad_weight,
        alpha=alpha,
        beta=1.0,
    )

Contributor Author

That looks interesting! I can conduct some experiments to profile the speed and memory.

DandinPower (Contributor Author) commented Dec 28, 2024

Hi @Tcc0403,
Interestingly, there are three scenarios here (sketched below):

  1. Convert right after logits_chunk is created, so logits_chunk stays FP32 the whole time (current implementation).
  2. Convert only right before addmm is invoked, so logits_chunk becomes FP32 just before the addmm (the scenario you asked about).
  3. If bias is not None, the addition keeps logits_chunk in FP32 the whole time, the same as scenario (1); converting at the beginning or right before addmm makes no difference.
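To make the dtypes in the three placements concrete, here is a minimal sketch (illustration only, using CPU autocast; the names loosely follow the FLCE forward, not the actual Liger code):

import torch

chunk_size, hidden, vocab = 8, 64, 128
_input_chunk = torch.randn(chunk_size, hidden)  # FP32 activations
weight = torch.randn(vocab, hidden)             # FP32 parameter
bias = torch.randn(vocab)                       # FP32 parameter

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    logits = _input_chunk @ weight.t()    # autocast runs the matmul in BF16
    scenario_1 = logits.to(weight.dtype)  # (1) cast immediately: FP32 for the rest of the chunk
    scenario_2 = logits                   # (2) stay BF16; cast only inside the later addmm
    scenario_3 = logits + bias            # (3) the bias addition promotes the result to FP32

print(scenario_1.dtype, scenario_2.dtype, scenario_3.dtype)  # torch.float32 torch.bfloat16 torch.float32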

Let's start with the latency part. TL;DR: there is only a negligible difference, and it is hard to measure because the margin of error can be larger than the actual difference. Now let me analyze why the difference is so small.

The first difference is in liger_cross_entropy_kernel. When passing logits_chunk into liger_cross_entropy_kernel, scenario (1) passes FP32 while scenario (2) passes BF16.

liger_cross_entropy_kernel[(n_rows,)](
    X_ptr=logits_chunk,
    X_stride=logits_chunk.stride(-2),
    Y_ptr=target_chunk,
    Y_stride=target_chunk.stride(-1),  # always 1
    loss_ptr=loss_1d_slice,
    z_loss_ptr=loss_1d_slice,  # dummy ptr, not used
    loss_stride=loss_1d_slice.stride(-1),  # always 1
    n_cols=V,
    n_non_ignore=n_non_ignore,
    ignore_index=ignore_index,
    lse_square_scale=lse_square_scale,
    label_smoothing=label_smoothing,
    reduction=reduction,
    softcap=softcap if softcap is not None else 0.0,
    RETURN_Z_LOSS=0,  # False
    HAS_SOFTCAPPING=True if softcap is not None else False,
    BLOCK_SIZE=BLOCK_SIZE,
    num_warps=32 if not is_hip() else 16,
)

But because inside the liger_cross_entropy_kernel, it uses:

X_block = tl.load(
    X_ptr + X_offsets,
    mask=X_offsets < n_cols,
    other=float("-inf"),
    # Ensure float32 precision for softmax calculation
).cast(tl.float32)

the actual computation time is the same. The only difference is the cast time, which is negligible compared to the actual computation.


The next place where the dtypes differ is:

grad_logits_chunk = logits_chunk * alpha  # chunk_size x V
grad_input[start_idx:end_idx] = grad_logits_chunk @ weight

  • grad_logits_chunk = logits_chunk * alpha is an aten::mul of a [chunk_size x vocab_size] tensor by a scalar.

    • In scenarios (1) and (3), logits_chunk is FP32, so it is (fp32 * scalar).
    • In scenario (2), logits_chunk is BF16, so it is (bf16 * scalar).
  • For grad_input[start_idx:end_idx] = grad_logits_chunk @ weight, PyTorch runs aten::matmul + aten::slice + aten::copy_. Since grad_input is always FP32, the difference is in the aten::matmul:

    • In scenarios (1) and (3), it is (fp32 @ fp32).
    • In scenario (2), it is (bf16 @ fp32).

Finally, the addmm step has no difference between scenarios (1), (2), and (3), since we ensure it runs in FP32.

So in conclusion, the latency difference between (1) and (2) comes from:

  1. Element-wise multiplication:

    • [chunk_size x vocab_size] * scalar
    • The latency difference is (FP32 * scalar) - (BF16 * scalar).
  2. Matrix multiplication:

    • [chunk_size x vocab_size] @ [vocab_size x hidden_size]
    • The latency difference is (FP32 @ FP32) - (BF16 @ FP32).
  3. Overall latency:
    The total latency difference is aggregated across all chunks:

    total_latency_difference = num_chunk * (per-chunk latency difference from the two operations above)
    

Confirming the difference

The results can be confirmed by the following experiment: setting B=8, T=128, H=4096, V=128256 to simulate the Llama3 8B LM head scenario with one forward pass of FLCE.

By comparing the end-to-end latency:

(1) 658077 us
(2) 660219 us
(3) 668929 us

We can see that scenarios (1) and (2) differ by a very small margin (within the margin of error). Scenario (3) uses a bias, so it performs extra addition operations and is therefore slightly slower than (1) and (2), which makes sense.

Because we know that any difference would come from aten::mul and aten::matmul, let's first check all operators via a pie chart. aten::mul takes only around 0.2% of the time, which is negligible whether it is (fp32 * scalar) or (bf16 * scalar):

[operator time pie chart]

The mul operation difference is negligible, so let’s move on to the aten::matmul part.

This is the matmul profile for (1):

[profiler screenshot]

And this is the matmul profile for (2):

[profiler screenshot]

Both scenarios have nearly identical total matmul time, which means that even with (bf16 @ fp32) vs. (fp32 @ fp32) inputs, the latency is the same.

I suspect the reason is that cuBLAS and PyTorch currently use FP32 to accumulate matrix multiplication results and then cast back to BF16, but I am not an expert on this part.

So in conclusion, casting logits_chunk at the beginning or only right before addmm has a negligible latency difference.


Now let's move on to the memory part. This part is more straightforward.

Let me show the memory snapshots first. The experiment setting follows the previous one.

For scenario (1):

[memory snapshot]

  • The peak is 6107.5 MiB.

For scenario (2):

[memory snapshot]

  • The peak is 6091.9 MiB.

I find that the difference between (1) and (2) is only 15.6 MiB. Let's see where the memory pattern differs.

This memory snapshot is for scenario (1):

[memory snapshot]

grad_logits_chunk = logits_chunk * alpha  # chunk_size x V

Because logits_chunk is FP32, grad_logits_chunk is FP32 as well, and the shape of those tensors is chunk_size x V. In my experiments, the size is 32 x 128256 x 4 bytes x 2 = 31.4 MiB.

For scenario (2):

[memory snapshot]

Because logits_chunk is BF16, grad_logits_chunk is BF16 too, and the shape of those tensors is chunk_size x V. In my experiments, the size is 32 x 128256 x 2 bytes x 2 = 15.7 MiB.

Therefore, scenarios (1) and (2) differ by 31.4 MiB - 15.7 MiB = 15.7 MiB, matching the profiling result of 15.6 MiB.

So in conclusion, the latency only differs in two places (aten::mul and aten::matmul):

grad_logits_chunk = logits_chunk * alpha  # chunk_size x V
grad_input[start_idx:end_idx] = grad_logits_chunk @ weight

and the profiling shows the difference between these two is negligible.

For the memory part, the difference in usage is chunk_size x vocab_size x 2 bytes x 2 tensors, which is typically only a few MiB but scales with chunk size. For example, if B=8 and T=1024, chunk_size would be 256, so the peak memory difference would grow to 125.25 MiB, which may start to matter.
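As a quick sanity check of that formula (the helper name below is just for illustration):

# Extra peak memory from keeping logits_chunk and grad_logits_chunk in FP32 (4 B/elem)
# instead of BF16 (2 B/elem): 2 extra bytes per element, for 2 tensors of shape chunk_size x V.
def extra_peak_mib(chunk_size, vocab_size):
    return chunk_size * vocab_size * 2 * 2 / 2**20

print(extra_peak_mib(32, 128256))   # ~15.7 MiB  (the B=8, T=128 setting above)
print(extra_peak_mib(256, 128256))  # ~125.25 MiB (B=8, T=1024)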

From my perspective, converting the dtype right before passing it into addmm is better since it reduces peak memory somewhat. One thing to note, however: in scenario (3), the logits and grad_logits are always FP32, just like in scenario (1). So if we pick scenario (2), the dtype and memory pattern will differ from scenario (3).

For alignment, so that the pattern is the same whether bias is None or not, we could keep converting at the beginning. That way logits_chunk is always FP32 regardless of bias. Since the memory difference is not huge and scales only slowly with batch size and sequence length, I think keeping the conversion at the beginning is an acceptable option.

But if we want to minimize peak memory, we could pick scenario (2), converting before addmm, and add a small change after the bias addition to achieve minimal memory usage while keeping the same behavior whether bias is None or not. For example:

logits_chunk = _input_chunk @ weight.t()  # chunk_size x V
autocast_dtype = logits_chunk.dtype
if bias is not None:
    logits_chunk = logits_chunk + bias
    logits_chunk = logits_chunk.to(autocast_dtype)

Or perhaps just keep the bias scenario in FP32 and only convert before addmm to minimize memory usage in the bias=None scenario.

I think all three options are acceptable. What do you think, or do you have another idea?

Tcc0403 (Collaborator) commented Dec 29, 2024

Exceptional analysis! I prefer the 2nd option, since the main concern before was that the logits_chunk.to(float) allocation might be surprisingly large for some models (#238). As for type promotion in the bias != None scenario, I think keeping it in FP32 is fine because the LM head doesn't use the bias term, as you mentioned in #501.

Super solid work! I'll merge it ASAP after your update. Amazing first contribution to Liger-Kernel. Looking forward to your future contributions.

DandinPower (Contributor Author)

Hi, thank you very much for your careful review and comments! I agree with your opinion, since logits_chunk scales with the vocab size, which can be very large in some current and future models. I have made the changes and committed them; all tests have passed.

Tcc0403 merged commit 174b191 into linkedin:main on Dec 29, 2024
austin362667 (Collaborator)

Hey @DandinPower, thank you so much! This work is super impressive.
Have you joined Liger's official Slack channel yet? Would love to see you there~

DandinPower (Contributor Author)

@austin362667 Thanks for the invitation! I’ve just joined the Slack channel.
