Fix Dtype Mismatch in torch.addmm within ops/fused_linear_cross_entropy.py in AMP training. #502
Conversation
…x dtype mismatch bug in fused_linear_cross_entropy.
@@ -59,6 +59,10 @@ def fused_linear_cross_entropy_forward(
logits_chunk = _input_chunk @ weight.t()  # chunk_size x V
if bias is not None:
    logits_chunk = logits_chunk + bias
if logits_chunk.dtype != weight.dtype:
Is the dtype comparison necessary? According to the ATen source code, I think .to(dtype) already does a dtype check implicitly.
Yes, you are right! After reading the source code you provided, conducting some small experiments, and profiling the behavior of PyTorch, I confirmed that .to(dtype) performs the dtype check implicitly and bypasses the cast when the dtype is already the same.

Here is an example of a bfloat16 tensor attempting .to(dtype=torch.float32):

[profiling output attached here]

And here is an example of a float32 tensor attempting .to(dtype=torch.float32):

[profiling output attached here]
I will fix this by removing the unnecessary comparison.
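As a minimal sketch of the no-op behavior described above (generic tensors, not the Liger-Kernel code), one can check whether .to() returns the very same tensor object:

import torch

x_bf16 = torch.randn(4, 4, dtype=torch.bfloat16)
x_fp32 = torch.randn(4, 4, dtype=torch.float32)

# A real cast allocates a new tensor...
print(x_bf16.to(torch.float32) is x_bf16)  # False
# ...while a same-dtype .to() returns the original tensor unchanged.
print(x_fp32.to(torch.float32) is x_fp32)  # True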
Hi @Tcc0403, thanks for the review! I have merged the latest main branch and made a new commit to remove the unnecessary dtype comparison. I also re-ran make test, make checkstyle, and make test-convergence, and all tests passed as before.
Just curious how it would affect the speed and memory if we move the type conversion into addmm(): https://github.com/linkedin/Liger-Kernel/pull/502/files#diff-d7a61f0e3b11dccec3ffab797033897bbc73d69abff5538f6ff11da850e38f63R120
if grad_weight is not None:
    torch.addmm(
        input=grad_weight,
        mat1=logits_chunk.t().to(_input_chunk.dtype),  # here
        mat2=_input_chunk,
        out=grad_weight,
        alpha=alpha,
        beta=1.0,
    )
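As background on why a conversion is needed at all, here is a standalone sketch (hypothetical shapes, not the actual Liger-Kernel tensors) showing that torch.addmm rejects operands with mismatched dtypes outside of autocast, and that casting mat1, as suggested above, makes the call valid:

import torch

grad_weight = torch.zeros(8, 16, dtype=torch.bfloat16)
logits_chunk = torch.randn(4, 8, dtype=torch.float32)    # e.g. promoted to fp32 upstream
_input_chunk = torch.randn(4, 16, dtype=torch.bfloat16)  # bf16 activations

try:
    # mat1 (fp32) and mat2 (bf16) disagree, so addmm raises a RuntimeError.
    torch.addmm(grad_weight, logits_chunk.t(), _input_chunk, out=grad_weight)
except RuntimeError as e:
    print("dtype mismatch:", e)

# Casting mat1 to the other operands' dtype makes the call valid.
torch.addmm(grad_weight, logits_chunk.t().to(_input_chunk.dtype), _input_chunk, out=grad_weight)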
That looks interesting! I can conduct some experiments to profile the speed and memory.
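For reference, one simple way to run such a comparison (a generic harness with made-up shapes, assuming a CUDA device; not the exact benchmark reported below) is to time the operation with CUDA events and read the allocator's peak statistics:

import torch

def profile_fn(fn, iters=100):
    # Warm up, then measure average latency (ms) and peak allocated memory (MiB).
    for _ in range(10):
        fn()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters, torch.cuda.max_memory_allocated() / 2**20

a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.float32)

ms, mib = profile_fn(lambda: a.to(torch.float32) @ b)
print(f"cast-then-matmul: {ms:.3f} ms, peak {mib:.1f} MiB")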
Hi @Tcc0403,
Let’s start with the latency part. TL;DR: there is only a negligible difference, and it is even hard to measure because the margin of error can be larger than the actual difference. Now let me analyze why the difference is so small.

The first difference is in the call to liger_cross_entropy_kernel:

liger_cross_entropy_kernel[(n_rows,)](
    X_ptr=logits_chunk,
    X_stride=logits_chunk.stride(-2),
    Y_ptr=target_chunk,
    Y_stride=target_chunk.stride(-1),  # always 1
    loss_ptr=loss_1d_slice,
    z_loss_ptr=loss_1d_slice,  # dummy ptr, not used
    loss_stride=loss_1d_slice.stride(-1),  # always 1
    n_cols=V,
    n_non_ignore=n_non_ignore,
    ignore_index=ignore_index,
    lse_square_scale=lse_square_scale,
    label_smoothing=label_smoothing,
    reduction=reduction,
    softcap=softcap if softcap is not None else 0.0,
    RETURN_Z_LOSS=0,  # False
    HAS_SOFTCAPPING=True if softcap is not None else False,
    BLOCK_SIZE=BLOCK_SIZE,
    num_warps=32 if not is_hip() else 16,
)

But because the kernel loads its input block and immediately casts it to float32,

X_block = tl.load(
    X_ptr + X_offsets,
    mask=X_offsets < n_cols,
    other=float("-inf"),
    # Ensure float32 precision for softmax calculation
).cast(tl.float32)

the actual computation time is the same regardless of the input dtype. The only difference is the cast time, which is negligible compared to the actual computation.

The next part that runs in a different dtype between the two scenarios is the gradient computation:

grad_logits_chunk = logits_chunk * alpha  # chunk_size x V
grad_input[start_idx:end_idx] = grad_logits_chunk @ weight

Finally, there is the torch.addmm call that accumulates grad_weight. So in conclusion, the latency difference between (1) and (2) comes down to these few operations running in a different dtype.
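To make the "cast time is negligible" point concrete, here is a toy timing sketch (illustrative sizes, assumes a CUDA device; separate from the measurements below):

import torch
from torch.utils import benchmark

# Illustrative chunk: 1024 rows, 32k vocab, bf16 (assumed sizes, not the real config).
x = torch.randn(1024, 32000, device="cuda", dtype=torch.bfloat16)
w = torch.randn(4096, 32000, device="cuda", dtype=torch.bfloat16)

cast = benchmark.Timer("x.to(torch.float32)", globals={"x": x, "torch": torch}).timeit(50)
mm = benchmark.Timer("x @ w.t()", globals={"x": x, "w": w}).timeit(50)
print(f"cast: {cast.mean * 1e3:.3f} ms vs matmul: {mm.mean * 1e3:.3f} ms")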
Confirming the difference

[experiment settings and end-to-end latency numbers were attached here as screenshots]

By comparing the end-to-end latency, we can see that scenarios (1) and (2) differ by a very small margin (it can be considered within the margin of error). Scenario (3), because it uses a bias, has more addition operations and is therefore slightly slower than (1) and (2), which makes sense.

Because any remaining difference should come from the matmul-related kernels, I profiled those as well.

[profiler traces for scenario (1) and scenario (2) were attached here]

Both scenarios have nearly identical totals. I suspect the reason is that cuBLAS and PyTorch currently use FP32 to accumulate matrix multiplication results and then cast back to BF16, but I am not an expert on this part. So in conclusion, casting earlier versus later makes no measurable latency difference.

Now let’s move on to the memory part. Let me show the memory snapshots first; the experiment setting follows the previous one.

For scenario (1): [memory snapshot attached here]

For scenario (2): [memory snapshot attached here]
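For anyone who wants to reproduce such snapshots, a sketch of the recording workflow (recent PyTorch exposes this through the private torch.cuda.memory helpers, so the exact names may vary between versions):

import torch

torch.cuda.memory._record_memory_history(max_entries=100000)

# ... run the forward/backward passes to be inspected (toy example below) ...
x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
(x @ w.t()).float().sum().backward()

# Dump a snapshot that can be opened at https://pytorch.org/memory_viz
torch.cuda.memory._dump_snapshot("flce_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)  # stop recording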
I find that the difference between (1) and (2) is only 15.6 MiB. Let’s see where the memory pattern differs. This part of the memory snapshot is for scenario (1), around the line

grad_logits_chunk = logits_chunk * alpha  # chunk_size x V

[zoomed-in snapshot for scenario (1) attached here]

Because the cast has not happened yet at this point in scenario (1), the intermediate allocated here is larger. For scenario (2):

[zoomed-in snapshot for scenario (2) attached here]

Because the cast has already happened earlier, the corresponding allocation is smaller. Therefore, scenarios (1) and (2) differ by roughly the size of this intermediate, which matches the 15.6 MiB above.

So in conclusion, the latency only differs in two places,

grad_logits_chunk = logits_chunk * alpha  # chunk_size x V
grad_input[start_idx:end_idx] = grad_logits_chunk @ weight

and the profiling shows the difference between these two is negligible. For the memory part, the difference in usage is only 15.6 MiB.

From my perspective, converting the dtype before passing it into addmm is acceptable, maybe for alignment, so that the pattern is the same whether or not a bias is used. But if we want to minimize the memory peak, I think we could pick scenario (2), which is converting logits_chunk back right after the bias addition:

logits_chunk = _input_chunk @ weight.t()  # chunk_size x V
autocast_dtype = logits_chunk.dtype
if bias is not None:
    logits_chunk = logits_chunk + bias
logits_chunk = logits_chunk.to(autocast_dtype)

Or perhaps just keep the bias path in FP32 and only convert right before addmm. I think all three options are acceptable. What do you think, or do you have another idea?
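As a rough, standalone way to see how cast placement shifts the peak (toy sizes on a CUDA device, not the 15.6 MiB measurement above; "cast-late" and "cast-early" only loosely mirror the scenarios discussed):

import torch

def peak_mib(fn):
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

logits = torch.randn(1024, 32000, device="cuda", dtype=torch.float32)  # promoted chunk
weight = torch.randn(32000, 4096, device="cuda", dtype=torch.bfloat16)

# cast-late: keep fp32 through the elementwise scale, cast just before the matmul
late = peak_mib(lambda: (logits * 2.0).to(torch.bfloat16) @ weight)
# cast-early: cast back to bf16 first, then scale and matmul
early = peak_mib(lambda: (logits.to(torch.bfloat16) * 2.0) @ weight)
print(f"peak cast-late: {late:.1f} MiB, peak cast-early: {early:.1f} MiB")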
Exceptional analysis! I prefer the 2nd option, since it addresses the main concern from before. Super solid work! I'll merge it asap after your update. Amazing first contribution to Liger-Kernel. Looking forward to your contributions in the future.
…ther intermediate values
Hi, thank you very much for your careful review and comments! I agree with your opinion.
Hey @DandinPower Thank you so much! This work is super impressive.
@austin362667 Thanks for the invitation! I’ve just joined the Slack channel.
Summary

This PR addresses a dtype mismatch error that I encountered while using PyTorch AMP to train a Llama3 model. After reviewing previous discussions, such as closed issue #305 and PR #318, conducting my own tests, and performing a complete analysis of the problem, I found that there is still a possibility of encountering a dtype mismatch if the bias is None during FLCE computation. The detailed observation and analysis of the issue can be found in issue #501.

This PR aims to ensure the correct dtype is used during FLCE computation, without affecting the behavior in other scenarios.
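As a minimal illustration of how this class of mismatch can arise under AMP (a generic sketch, not the Liger-Kernel code path): inside an autocast region a matmul runs in bf16 while parameters stay fp32, so an op executed outside autocast, for example in a custom autograd backward, can receive operands of different dtypes.

import torch

x = torch.randn(4, 16, device="cuda")        # fp32 activations
weight = torch.randn(8, 16, device="cuda")   # fp32 parameter

with torch.autocast("cuda", dtype=torch.bfloat16):
    logits = x @ weight.t()                  # autocast computes this in bf16

print(logits.dtype, weight.dtype)            # torch.bfloat16 torch.float32

try:
    # Outside autocast (e.g. in a custom backward), mixing these dtypes fails.
    _ = logits @ weight
except RuntimeError as e:
    print("dtype mismatch:", e)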
Testing Done

- make test to ensure correctness
- make checkstyle to ensure code style
- make test-convergence to ensure convergence