Conversation
Summary:

We are standardizing on `Float8Linear` as the only float8 linear object:
1. The stack ending with #300 moved all of the functionality of `Float8DynamicLinear` to `Float8Linear`. The default settings of `Float8Linear` are to use dynamic scaling.
2. This PR deletes `Float8DynamicLinear` from the codebase and patches the relevant callsites in fbsource.

Test Plan:
```
// all tests pass
./test_everything.sh
// also run all benchmarks and verify correctness
```

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 8ab483377124960fec2f133c0e27fbbaab204528
Pull Request resolved: #304
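For readers migrating call sites, here is a minimal sketch of what the swap looks like after this PR. It is an illustration rather than code from the PR: the `float8_linear_utils` import path and the exact `swap_linear_with_float8_linear` signature are assumptions based on the call shapes visible in the diffs below.

```python
import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# A toy model; the sizes are arbitrary.
model = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 16))

# Previously, callers chose between Float8Linear and Float8DynamicLinear here.
# After this PR, Float8Linear is the only float8 linear, and its defaults use
# dynamic scaling, so no extra config is needed to keep the old
# Float8DynamicLinear behavior.
# Note: the swap returns the converted model rather than mutating in place.
model = swap_linear_with_float8_linear(model, Float8Linear)
```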
@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
```diff
@@ -14,7 +14,7 @@
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.utils.benchmark as benchmark
-from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
```
How useful is this benchmark in general?
I haven't used it recently
```diff
-        # example: "x:del,w:del,dldy:dyn"
-        return f"x:{self.scaling_type_x.short_str()},w:{self.scaling_type_w.short_str()},dldy:{self.scaling_type_dL_dY.short_str()}"
+        # example: "x_del_w_del_dldy_dyn"
+        return f"x_{self.scaling_type_x.short_str()}_w_{self.scaling_type_w.short_str()}_dldy_{self.scaling_type_dL_dY.short_str()}"
```
Why the change, out of curiosity? I think the prior version might be a little more readable.
I should have reverted this. Will follow up in a future PR, if that's OK, to make landing this PR easier.
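For readers skimming the repr discussion above, here is a self-contained sketch of the pattern being debated. `TensorScalingType` below is a simplified stand-in for the real enum, reduced to the two values and the `short_str()` helper that the f-strings above rely on, and the `scaling_repr` function name is made up for illustration.

```python
import enum


class TensorScalingType(enum.Enum):
    # Simplified stand-in for the real enum used by Float8Linear.
    DELAYED = "delayed"
    DYNAMIC = "dynamic"

    def short_str(self) -> str:
        return "del" if self is TensorScalingType.DELAYED else "dyn"


def scaling_repr(x: TensorScalingType, w: TensorScalingType, dL_dY: TensorScalingType) -> str:
    # Old format: "x:del,w:del,dldy:dyn"; this PR switched to the underscore
    # format below, and the change is undone in the follow-up PR.
    return f"x_{x.short_str()}_w_{w.short_str()}_dldy_{dL_dY.short_str()}"


print(scaling_repr(TensorScalingType.DELAYED, TensorScalingType.DELAYED, TensorScalingType.DYNAMIC))
# prints: x_del_w_del_dldy_dyn
```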
```diff
@@ -48,8 +45,12 @@ def _test_compile_base(
     x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
     m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)

-    m_fp8 = get_float8_linear(
-        linear_type, m_ref, emulate, scaling_type_x, scaling_type_w, scaling_type_dL_dY
+    m_fp8 = Float8Linear.from_float(
```
Calling `swap_..` on an nn.Linear module returns a new model out of place. I think it's fine either way.
I agree, we can make the tests use that if we want in a future PR.
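To make the conversion above concrete, here is a sketch of turning a single `nn.Linear` into `Float8Linear` via `from_float`. The keyword names mirror the old `get_float8_linear(...)` arguments shown in the removed lines and are an assumption about `from_float`'s exact signature.

```python
import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear, TensorScalingType

m_ref = nn.Linear(16, 32, bias=True)

# Convert one module directly, in place of the deleted get_float8_linear() helper.
# emulate=True runs the float8 matmuls in emulation, so hardware without native
# float8 support can still exercise the code path; the keyword names here are
# assumptions for illustration.
m_fp8 = Float8Linear.from_float(
    m_ref,
    emulate=True,
    scaling_type_x=TensorScalingType.DYNAMIC,
    scaling_type_w=TensorScalingType.DYNAMIC,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
)
```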
"scaling_type_dL_dY": TensorScalingType.DYNAMIC, | ||
} | ||
# For now, just use Float8Linear with dynamic scaling, which is the | ||
# same behavior as Float8Linear. |
Float8Dynamic? But also, it's probably better to just say that this only supports dynamic scaling for all 3 tensors: x, w, dL_dY.
agreed, let me fix in a future PR to speed up landing this, since this is a minor point.
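For concreteness, the reviewer's point is that the config above pins all three tensors to dynamic scaling. A short sketch of the full kwargs dict, of which the diff only shows the last entry (the surrounding test plumbing is assumed, not shown here):

```python
from float8_experimental.float8_linear import TensorScalingType

# Only dynamic scaling is exercised here, for all three tensors: x, w, dL_dY.
# This also matches Float8Linear's defaults after this stack. The dict would be
# passed through to the swap helper / Float8Linear.from_float as **kwargs.
extra_kwargs = {
    "scaling_type_x": TensorScalingType.DYNAMIC,
    "scaling_type_w": TensorScalingType.DYNAMIC,
    "scaling_type_dL_dY": TensorScalingType.DYNAMIC,
}
```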
```diff
@@ -29,8 +27,7 @@ def check_parity_no_mp(
         for param in model.parameters():
             dist.all_reduce(param.grad)
             param.grad.div_(dist.get_world_size())
-        if module_cls is Float8Linear:
-            sync_float8_amax_and_scale_history(model)
+        # TODO(future): add amax syncing once delayed scaling is supported
```
was this just an unused code path?
yes
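For context on the removed branch, a sketch of the data-parallel step it belonged to. `sync_float8_amax_and_scale_history` is the helper named in the removed line; the wrapper function and its import path here are assumptions for illustration.

```python
import torch.distributed as dist
import torch.nn as nn

from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history


def average_grads(model: nn.Module, use_delayed_scaling: bool) -> None:
    # Average gradients across data-parallel ranks, as in check_parity_no_mp.
    for param in model.parameters():
        dist.all_reduce(param.grad)
        param.grad.div_(dist.get_world_size())
    # Delayed scaling keeps per-tensor amax/scale history that must be synced
    # across ranks each step; dynamic scaling has no history to sync, which is
    # why the test dropped the branch until delayed scaling is supported there.
    if use_delayed_scaling:
        sync_float8_amax_and_scale_history(model)
```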
```
            return swap_linear_with_float8_linear(module, Float8Linear, **kwargs)
        else:
            return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs)
    def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Module:
```
can we just remove this since this is the default?
agreed in principle, but ideally that would be a separate PR since it's only tangentially related
Burn it with fire!🔥
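For reference, after this PR the helper under discussion reduces to a thin one-line wrapper, which is why removing it entirely is on the table. This is a sketch of that shape, written as a free function rather than the test's method, not the exact code from the test.

```python
from typing import Any

import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear


def swap_linear_with_dynamic(module: nn.Module, **kwargs: Any) -> nn.Module:
    # With Float8DynamicLinear gone, there is no branch left to take:
    # Float8Linear with its default (dynamic) scaling is the only option.
    return swap_linear_with_float8_linear(module, Float8Linear, **kwargs)
```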
This pull request has been merged in 8e9623a.
Summary:

Addressing a couple of nits that slipped in #304:
* more defaults to dynamic
* undo repr change
* fix comment

Test Plan:
```
./test/test_everything.sh
```

Reviewers:
Subscribers:
Tasks:
Tags:

[ghstack-poisoned]
Stack from ghstack (oldest at bottom):
Summary:

We are standardizing on `Float8Linear` as the only float8 linear object:

1. The stack ending with "[9/x]: make dynamic scaling default in Float8Linear" (#300) moved all of the functionality of `Float8DynamicLinear` to `Float8Linear`. The default settings of `Float8Linear` are to use dynamic scaling.
2. This PR deletes `Float8DynamicLinear` from the codebase and patches the relevant callsites in fbsource.
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: D59342767