move float8 callsites to torchao.float8 (#492)
Summary:

The code from the `float8_experimental` repository moved to `torchao.float8` in pytorch/ao#551.

This PR updates `torchtitan` to use float8 from the new location.
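
For illustration (not part of this diff), a minimal sketch of converting a toy model with the relocated API; the toy model and the default `Float8LinearConfig()` construction are assumptions here, and float8 matmuls still require suitable hardware:

```python
# Minimal sketch, assuming torchao is installed and Float8LinearConfig's
# defaults are acceptable; the toy model is illustrative only.
import torch
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))

# Swap eligible nn.Linear modules for Float8Linear in place.
convert_to_float8_training(model, config=Float8LinearConfig())

# Float8 training is intended to run under torch.compile for performance.
model = torch.compile(model)
```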

Test Plan:

```
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

vkuzo authored Jul 30, 2024
1 parent 9cf4b2f commit b069f70
Showing 4 changed files with 14 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/integration_test_4gpu.yaml

```diff
@@ -39,6 +39,6 @@ jobs:
         python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
         python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
-        python -m pip install git+https://github.com/pytorch-labs/float8_experimental.git
+        USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
         mkdir artifacts-to-be-uploaded
         python ./test_runner.py artifacts-to-be-uploaded --ngpu 4
```
4 changes: 2 additions & 2 deletions torchtitan/config_manager.py

```diff
@@ -353,8 +353,8 @@ def __init__(self):
            action="store_true",
            help="""
            If true, swaps `torch.nn.Linear` with `Float8Linear`.
-            This feature requires you to install 'float8_experimental' which can be found
-            here: https://github.com/pytorch-labs/float8_experimental
+            This feature requires you to install 'torchao' which can be found
+            here: https://github.com/pytorch/ao
            """,
        )
        self.parser.add_argument(
```
16 changes: 8 additions & 8 deletions torchtitan/float8_linear.py

```diff
@@ -4,11 +4,11 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-# [Note] Getting the 'float8_experimental' package:
-# This script requires the 'float8_experimental' package to function correctly.
+# [Note] Getting the 'torchao' package:
+# This script requires the 'torchao' package to function correctly.
 # Please ensure you have this package installed from the appropriate repository.
-# You can obtain it from https://github.com/pytorch-labs/float8_experimental.
-# Either clone and run `pip install .` or run `pip install git+https://github.com/pytorch-labs/float8_experimental.git`
+# You can obtain it from https://github.com/pytorch/ao by following the
+# installation instructions.

 # Note: Performance
 # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
@@ -48,7 +48,7 @@ def maybe_build_fp8_linear(
         )
         return
     try:
-        from float8_experimental import (
+        from torchao.float8 import (
             CastConfig,
             convert_to_float8_training,
             Float8LinearConfig,
@@ -83,7 +83,7 @@ def maybe_build_fp8_linear(
         )
     except ImportError as exc:
         raise ImportError(
-            "float8_experimental is not installed. Please install it to use fp8 linear layers."
+            "torchao is not installed. Please install it to use fp8 linear layers."
         ) from exc


@@ -102,7 +102,7 @@ def maybe_precompute_fp8_dynamic_scale_for_fsdp(
             "Skipped precomputing fp8 scales because SM90 or later is not available",
         )
         return
-    from float8_experimental import precompute_float8_dynamic_scale_for_fsdp
+    from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp

     precompute_float8_dynamic_scale_for_fsdp(model)

@@ -121,7 +121,7 @@ def maybe_sync_float8_amax_and_scale_history(model: nn.Module, job_config: JobCo
     ):
         return

-    from float8_experimental import sync_float8_amax_and_scale_history
+    from torchao.float8 import sync_float8_amax_and_scale_history

     # TODO(future): see if precalculating the modules to sync over is going to
     # meaningfully help performance
```
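
For context, a hedged sketch of where the two helpers imported above typically sit in a training step; the surrounding loop and the `model`, `optimizer`, `data_loader`, and `loss_fn` names are assumptions, not torchtitan's actual training script:

```python
# Training-step sketch, assuming `model` has already been converted with
# convert_to_float8_training (and wrapped as needed) and that `optimizer`,
# `data_loader`, and `loss_fn` exist; only the two torchao.float8 calls
# are taken from the diff above.
from torchao.float8 import (
    precompute_float8_dynamic_scale_for_fsdp,
    sync_float8_amax_and_scale_history,
)

for batch, labels in data_loader:
    loss = loss_fn(model(batch), labels)
    loss.backward()

    # Delayed scaling only: sync amax/scale history across ranks before stepping.
    sync_float8_amax_and_scale_history(model)

    optimizer.step()
    optimizer.zero_grad()

    # Dynamic scaling with FSDP2 float8 all-gather: precompute the next
    # iteration's per-parameter scales in a single collective.
    precompute_float8_dynamic_scale_for_fsdp(model)
```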
4 changes: 3 additions & 1 deletion torchtitan/parallelisms/parallelize_llama.py

```diff
@@ -129,7 +129,9 @@ def get_tp_parallel_strategy_for_transformer_block(
         # TODO(future PR): once float8 configuration supports delayed
         # scaling, add a check here to enforce supported float8 all-gather
         # configurations
-        from float8_experimental.float8_tensor_parallel import (
+        # TODO(future PR): add the items below to __init__.py of torchao.float8,
+        # and import from there
+        from torchao.float8.float8_tensor_parallel import (
             Float8ColwiseParallel,
             Float8RowwiseParallel,
             PrepareFloat8ModuleInput,
```
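
For context, a rough sketch of using the float8-aware TP styles imported above as drop-in replacements for the standard `ColwiseParallel`/`RowwiseParallel` plan entries; the toy block, the 2-way mesh, and the convert-then-parallelize ordering are assumptions for illustration (launch under `torchrun`), not torchtitan's actual code:

```python
# Rough sketch, assuming a torchrun launch with 2 GPUs and that linears are
# converted to Float8Linear before the TP plan is applied; module names are
# illustrative, not torchtitan's actual layer paths.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module
from torchao.float8 import convert_to_float8_training
from torchao.float8.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
)

class ToyFFN(nn.Module):  # stand-in for a transformer feed-forward block
    def __init__(self, dim: int = 256):
        super().__init__()
        self.w1 = nn.Linear(dim, 4 * dim, bias=False)
        self.w2 = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        return self.w2(self.w1(x).relu())

tp_mesh = init_device_mesh("cuda", (2,))
block = ToyFFN().cuda()
convert_to_float8_training(block)  # swap nn.Linear -> Float8Linear first

# Column-parallel w1, row-parallel w2, with float8 casts handled by the styles.
parallelize_module(
    block,
    tp_mesh,
    {"w1": Float8ColwiseParallel(), "w2": Float8RowwiseParallel()},
)
```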
