add unit tests for FSDP2 + torch.compile(transformer block) #321
Conversation
@@ -301,7 +303,7 @@ def __tensor_flatten__(self):
            ],
            {
                "mm_config": self._mm_config,
                "is_amax_initialized": is_amax_initialized,
`pre-commit run --all-files` complains about an undefined `is_amax_initialized` in trunk. Fixing it here so I can commit without bypassing the linter.
test/test_fsdp2/test_fsdp2_common.py
@@ -46,7 +47,10 @@ def check_parity_no_mp(
):
        precompute_float8_dynamic_scale_for_fsdp(model)

    test_cls.assertEqual(losses[0], losses[1])
    if compile_transformer_block:
        torch.testing.assert_close(losses[0], losses[1], atol=9.5e-2, rtol=9.5e-2)
borrowed 9.5e-2 from test_compile.py
https://github.com/pytorch-labs/float8_experimental/blob/main/test/test_compile.py#L62
This seems kind of high 🤔 I wonder how this value was determined. Can we instead have the ref also compile each transformer block (but without FSDP applied) for comparison?
Will try switching the ref model to Float8Linear + torch.compile.
After applying torch.compile to the ref_model, we can achieve atol/rtol=1e-4. I can dig more in follow-ups if we want to reach higher numeric parity like 1e-5.
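As a side note, a minimal standalone illustration (not from the test file; values are made up) of how the atol/rtol pair in torch.testing.assert_close bounds the allowed difference:

```python
import torch

ref = torch.tensor(2.3456)
test = torch.tensor(2.3466)

# assert_close passes when |test - ref| <= atol + rtol * |ref|.
# Here the difference is ~1e-3, which fits inside atol=1e-3, rtol=1e-3,
# but would fail with the tighter 1e-4/1e-4 pair discussed above.
torch.testing.assert_close(test, ref, atol=1e-3, rtol=1e-3)
```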
@@ -64,7 +64,9 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
     scales = torch.split(scale_tensor, 1)  # Replicate
     for scale, float8_linear in zip(scales, float8_linears):
-        float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
+        float8_linear.weight._local_tensor._precomputed_scale = (
+            scale._local_tensor.squeeze()
Make sure the tensor is like tensor(4674.8633) instead of tensor([[4674.8633]]); otherwise torch.compile errors out in guards: torch._C._dynamo.guards.assert_size_stride(new_inputs[3], (), ())
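For illustration, a small standalone sketch (made-up values, not the library code) of why the .squeeze() matters for the shape of the stored scale:

```python
import torch

# Stacked per-parameter scales, one row per Float8Linear (made-up values).
scale_tensor = torch.tensor([[4674.8633], [1234.5678]])

scales = torch.split(scale_tensor, 1)   # each chunk keeps shape (1, 1)
print(scales[0].shape)                  # torch.Size([1, 1])

# .squeeze() drops the size-1 dims, giving a true scalar tensor like tensor(4674.8633),
# which matches the shape dynamo guarded on in the first iteration.
print(scales[0].squeeze().shape)        # torch.Size([])
```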
Can we check the traces? I want to make sure there is no CPU sync point introduced from making this tensor a scalar tensor.
do we know the reasoning for why the current behavior is not supported with compile? This might not scale long term as we add other scaling granularities like rowwise or blockwise.
I mean when the scalar is used later downstream.
do we know the reasoning for why the current behavior is not supported with compile? This might not scale long term as we add other scaling granularities like rowwise or blockwise.

TL;DR: this is more of a bug in my implementation of precompute_float8_dynamic_scale_for_fsdp.

In the 1st iteration, self._precomputed_scale is None, so we calculate the scale through cast_to_float8_e4m3_dynamic (code), where the scale is a scalar tensor like tensor(4674.8633). Dynamo generates a guard assertion on tensor(4674.8633).size() and tensor(4674.8633).stride(), so it expects the same input shapes in the 2nd iteration.

In the 2nd iteration, after precompute_float8_dynamic_scale_for_fsdp, we have self._precomputed_scale=tensor([[4674.8633]]) because I only called torch.split(scale_tensor, 1) without .squeeze(). The guard assertion finds that .size() and .stride() changed and throws the error.

Does it make sense to say this is a bug in user code rather than a malfunction in dynamo?
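To make the guard behavior concrete, here is a minimal standalone sketch (not the library code) of how a size/stride change on an input trips dynamo's guards:

```python
import torch

@torch.compile(dynamic=False)
def scale_weight(w, scale):
    return w * scale

w = torch.randn(4, 4)

# 1st call: dynamo records guards on scale.size() == () and scale.stride() == ().
scale_weight(w, torch.tensor(4674.8633))

# 2nd call with a (1, 1)-shaped scale violates those guards. In this standalone
# sketch dynamo simply recompiles; inside the FSDP pre-all-gather path the same
# mismatch surfaced as the assert_size_stride(new_inputs[3], (), ()) error.
scale_weight(w, torch.tensor([[4674.8633]]))
```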
I mean when the scalar is used later downstream.

Ah, I see. I should be looking for cudaStreamSynchronize, right?
I think it would be like cudaDeviceSynchronize if I understand correctly (but basically you would see the CPU thread blocked).
I mean when the scalar is used later downstream.

_precomputed_scale will be used inside fsdp_pre_all_gather when calling the following code: https://github.com/pytorch-labs/float8_experimental/blob/main/float8_experimental/fsdp_utils.py#L167
float8_tensor = Float8Tensor.to_float8(
    self._tensor,
    self._precomputed_scale,
    torch.float8_e4m3fn,
    mm_config=self._mm_config,
)
I annotated the function with record_function("Float8Tensor.to_float8"). Here are the snapshots for the cpu thread and the cuda stream. In both cases, I do not see cudaStreamSynchronize, and the cuda stream stays ahead of the cpu thread. Any worries?
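For reference, a rough sketch of this kind of profiling check (the scaled-cast body is a stand-in for the real Float8Tensor.to_float8, and a CUDA device is assumed):

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function

x = torch.randn(1024, 1024, device="cuda")
scale = torch.tensor(4674.8633, device="cuda")  # scalar scale, as after .squeeze()

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with record_function("Float8Tensor.to_float8"):  # same annotation as above
        # Stand-in for the real float8 cast: scale, then convert to float8_e4m3fn.
        y = (x * scale).to(torch.float8_e4m3fn)

# Open the exported trace (e.g. in Perfetto / chrome://tracing) and check whether
# cudaStreamSynchronize / cudaDeviceSynchronize shows up under the annotated range.
prof.export_chrome_trace("to_float8_trace.json")
```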
sounds good! should be fine
@@ -4,8 +4,6 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Optional, Tuple
Fix a linter error from trunk.
@weifengpy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
        fully_shard(submodule)
    for layer_id, transformer_block in module.layers.named_children():
        if compile_transformer_block:
            transformer_block = torch.compile(transformer_block, dynamic=False)
is compiling the transformer block instead of the entire model related to this issue, or are we just trying to match torchtitan behavior?
Optionally, if possible, it would be good to compile the whole model here instead, as long as that can catch the issues relevant to us and keeps the more advanced "how to apply compile" logic localized to torchtitan.
is compiling the transformer block instead of the entire model related to this issue, or are we just trying to match torchtitan behavior?

This is just trying to match torchtitan's behavior. The .squeeze() is needed regardless of whether we compile transformer blocks or the whole model.

optionally, if possible, would be good to compile the whole model here instead as long as that can catch the issues relevant to us

I want to check at PR time that float8_experimental is compatible with torchtitan (thus compiling the transformer block). For float8_experimental, I am with you that it's also good to cover compiling the full model. For FSDP2, it should work. For FSDP+TP, I remember there is some problem compiling the full model. Will see if I can follow up.
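For context, a sketch of the per-block wrapping pattern the test mirrors (function name, ordering, and the FSDP2 import path are assumptions based on the snippet above, not the exact test code):

```python
import torch
from torch.distributed._composable.fsdp import fully_shard  # FSDP2 API path (assumed)

def apply_compile_and_fsdp(model: torch.nn.Module, compile_transformer_block: bool) -> torch.nn.Module:
    # Compile each transformer block (matching torchtitan), re-register it on the
    # parent container, then apply fully_shard per block and once on the root module.
    for layer_id, transformer_block in model.layers.named_children():
        if compile_transformer_block:
            transformer_block = torch.compile(transformer_block, dynamic=False)
            model.layers.register_module(layer_id, transformer_block)
        fully_shard(transformer_block)
    fully_shard(model)
    return model
```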
looks great, thanks for fixing this!
nice catch!
@weifengpy merged this pull request in 7f0d6bb.
Fixed my bug in float8_experimental. Now we can torch.compile transformer blocks with FSDP float8 all-gather: pytorch-labs/float8_experimental#321

Local test: `CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.enable_fsdp_float8_all_gather --training.precompute_float8_dynamic_scale_for_fsdp --training.compile`

Profiler traces: I can see the compiled region in the cpu thread and the float8 matmul `sm90_xmma_gemm_e4m3bf16...` in the cuda stream (screenshot: https://github.com/user-attachments/assets/0cf58dee-aae1-4582-a3f1-b8aa48b45129).
TorchTitan complains about FSDP2 + float8 + torch.compile(transformer block): there is a mismatch in the float8 scale, so the dynamo guard assertion torch._C._dynamo.guards.assert_size_stride(new_inputs[3], (), ()) fails.

- In cast_to_float8_e4m3_dynamic (code), the scale is a scalar tensor, e.g. tensor(4674.8633).
- In precompute_float8_dynamic_scale, the scale is NOT a scalar tensor, e.g. tensor([[4674.8633]]).
- Added .squeeze to make sure scales are always scalar tensors, so the dynamo guard assertion always holds true.
- Added a unit test so we can catch the issue at PR time.

TODO: add fp8 + torch.compile to CI in torchtitan