This repository was archived by the owner on Aug 7, 2024. It is now read-only.

delete Float8DynamicLinear #304

Closed · wants to merge 1 commit

Conversation

@vkuzo (Contributor) commented Jul 3, 2024:

Stack from ghstack (oldest at bottom):

Summary:

We are standardizing on `Float8Linear` as the only float8 linear object:

1. the stack ending with [9/x]: make dynamic scaling default in Float8Linear (#300) moved
   all of the functionality of `Float8DynamicLinear` to `Float8Linear`.
   The default settings of `Float8Linear` are to use dynamic scaling.
2. this PR deletes `Float8DynamicLinear` from the codebase and patches
   the relevant callsites in fbsource.

Test Plan:

```
// all tests pass
./test_everything.sh

// also run all benchmarks and verify correctness
```

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: D59342767
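To make the callsite patching concrete, here is a minimal sketch of what a migrated callsite might look like. It uses the `swap_linear_with_float8_linear` utility and the scaling-type kwargs that appear in the diffs below; the module paths and exact kwarg names are assumptions about the API at this commit, not a definitive recipe.

```python
import torch.nn as nn

# Assumed module paths; they may differ in your checkout of float8_experimental.
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

model = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 16))

# Before this PR, a callsite would pass Float8DynamicLinear as the module class.
# After this PR, Float8Linear is the only class and defaults to dynamic scaling;
# the scaling types can also be spelled out explicitly, as below.
model = swap_linear_with_float8_linear(
    model,
    Float8Linear,
    scaling_type_x=TensorScalingType.DYNAMIC,
    scaling_type_w=TensorScalingType.DYNAMIC,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
)
```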
vkuzo added a commit that referenced this pull request Jul 3, 2024

ghstack-source-id: 8ab483377124960fec2f133c0e27fbbaab204528
Pull Request resolved: #304
@facebook-github-bot added the CLA Signed label Jul 3, 2024
@vkuzo vkuzo requested review from bdhirsh, drisspg and weifengpy July 3, 2024 19:20
@vkuzo (Contributor, Author) commented Jul 3, 2024:

@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

```diff
@@ -14,7 +14,7 @@
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.utils.benchmark as benchmark
-from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
```
Reviewer (Contributor):

How useful is this benchmark in general?

vkuzo (Author):

I haven't used it recently

```diff
-# example: "x:del,w:del,dldy:dyn"
-return f"x:{self.scaling_type_x.short_str()},w:{self.scaling_type_w.short_str()},dldy:{self.scaling_type_dL_dY.short_str()}"
+# example: "x_del_w_del_dldy_dyn"
+return f"x_{self.scaling_type_x.short_str()}_w_{self.scaling_type_w.short_str()}_dldy_{self.scaling_type_dL_dY.short_str()}"
```
Reviewer (Contributor):

Why the change, out of curiosity? I think the prior version might be a little more readable.

vkuzo (Author):

I should have reverted this. Will follow up in a future PR if that's OK, to make landing this PR easier.
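For context on the readability question above, here is a tiny self-contained sketch of the two config-string formats. The enum below is only a stand-in for `TensorScalingType`; the `DELAYED` member and the exact `short_str()` outputs are assumptions based on the examples in the diff.

```python
from enum import Enum

class TensorScalingType(Enum):
    # Illustrative stand-in for the enum in float8_experimental.float8_linear.
    DELAYED = "delayed"
    DYNAMIC = "dynamic"

    def short_str(self) -> str:
        return "del" if self is TensorScalingType.DELAYED else "dyn"

x = w = TensorScalingType.DELAYED
dldy = TensorScalingType.DYNAMIC

# prior format (restored in the follow-up PR #308 below)
print(f"x:{x.short_str()},w:{w.short_str()},dldy:{dldy.short_str()}")   # x:del,w:del,dldy:dyn
# format introduced in this PR
print(f"x_{x.short_str()}_w_{w.short_str()}_dldy_{dldy.short_str()}")   # x_del_w_del_dldy_dyn
```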

```diff
@@ -48,8 +45,12 @@ def _test_compile_base(
     x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
     m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)

-    m_fp8 = get_float8_linear(
-        linear_type, m_ref, emulate, scaling_type_x, scaling_type_w, scaling_type_dL_dY
+    m_fp8 = Float8Linear.from_float(
```
Reviewer (Contributor):

Calling `swap_..` on an `nn.Linear` module returns a model out of place. I think it's fine either way.

vkuzo (Author):

I agree, we can make the tests use that if we want in a future PR.
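For reference, a hedged sketch of the `Float8Linear.from_float` path used in the diff above, assuming it accepts an `emulate` flag and the scaling-type kwargs that appear elsewhere in this PR; the exact signature may differ.

```python
import torch
import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear, TensorScalingType

linear_dtype = torch.bfloat16
m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)

# Replaces the old get_float8_linear(linear_type, m_ref, ...) helper.
# The kwarg names below are assumptions modeled on the surrounding diffs.
m_fp8 = Float8Linear.from_float(
    m_ref,
    emulate=False,
    scaling_type_x=TensorScalingType.DYNAMIC,
    scaling_type_w=TensorScalingType.DYNAMIC,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
)

x = torch.randn(4, 16, device="cuda", dtype=linear_dtype)
y = m_fp8(x)
```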

"scaling_type_dL_dY": TensorScalingType.DYNAMIC,
}
# For now, just use Float8Linear with dynamic scaling, which is the
# same behavior as Float8Linear.
Reviewer (Contributor):

`Float8Dynamic`? But also, it's probably better to just say it only supports dynamic scaling for all 3 tensors: x, w, dL_dY.

vkuzo (Author):

Agreed, let me fix this in a future PR to speed up landing, since this is a minor point.

```diff
@@ -29,8 +27,7 @@ def check_parity_no_mp(
     for param in model.parameters():
         dist.all_reduce(param.grad)
         param.grad.div_(dist.get_world_size())
-    if module_cls is Float8Linear:
-        sync_float8_amax_and_scale_history(model)
+    # TODO(future): add amax syncing once delayed scaling is supported
```
Reviewer (Contributor):

was this just an unused code path?

vkuzo (Author):

yes
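For context on what the removed branch did: with delayed scaling, amax and scale histories have to be synced once per iteration, which is what `sync_float8_amax_and_scale_history` handles. Below is a minimal sketch of that training-loop pattern, assuming the utility lives in `float8_experimental.float8_linear_utils`; dynamic scaling (the only mode exercised in this test after the change) does not need the call.

```python
from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history

def train_step(model, optimizer, batch, target, loss_fn):
    # `model` is assumed to contain Float8Linear modules configured for delayed scaling.
    optimizer.zero_grad()
    loss = loss_fn(model(batch), target)
    loss.backward()
    # Sync amaxes and scales across Float8Linear modules (and across ranks when
    # running distributed) before the optimizer step; dynamic scaling skips this.
    sync_float8_amax_and_scale_history(model)
    optimizer.step()
    return loss
```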

```
return swap_linear_with_float8_linear(module, Float8Linear, **kwargs)
else:
return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs)
def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Module:
```
Reviewer (Contributor):

Can we just remove this, since this is the default?

vkuzo (Author):

Agreed in principle, but ideally that would be a separate PR since it's only tangentially related.

@drisspg (Contributor) left a comment:

Burn it with fire!🔥

@facebook-github-bot

This pull request has been merged in 8e9623a.

vkuzo added a commit that referenced this pull request Jul 8, 2024
Summary:

Addressing a couple of nits that slipped in
#304

* more defaults to dynamic
* undo repr change
* fix comment

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

vkuzo added a commit that referenced this pull request Jul 8, 2024
ghstack-source-id: 49448ecbf3ad15087783f97dcdda278fe4f42d41
Pull Request resolved: #308
facebook-github-bot pushed a commit that referenced this pull request Jul 9, 2024
Summary:
Pull Request resolved: #308


Reviewed By: drisspg

Differential Revision: D59521233

fbshipit-source-id: 5f69855cc2d19c6057a230b0963185c4396dcd99
Labels: CLA Signed, Merged
3 participants