
Commit d407246

Improve QAT nvfp4 numerics (#3050)

* Support NVFP4 dynamic per tensor scale

**Summary:** This commit adds an option for the existing `NVFP4InferenceConfig` to dynamically compute an appropriate fp32 per tensor scale to support two-level scaling according to the NVFP4 specification: https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/. While two-level scaling is supported in `NVFP4Tensor`, today there is no config API for users to call this. The existing `NVFP4InferenceConfig` only supports single-level scaling because including an explicit `per_tensor_scale` field would make serialization tricky. In the future, we should add an end-to-end calibration flow so users can compute an appropriate per tensor scale for the activations first, and then pass this to `NVFP4Tensor` as a static scale, similar to the proposal in #2572.

**Test Plan:**
```
pytest test/prototype/mx_formats/test_inference_workflow.py -k test_inference_workflow_nvfp4
pytest test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

Also did a quick benchmark before and after:
```
import copy
import time
import torch
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import NVFP4InferenceConfig

m_mx1 = torch.nn.Linear(64, 256, bias=True, dtype=torch.bfloat16, device="cuda")
m_mx2 = copy.deepcopy(m_mx1)
config1 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=False)
config2 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=True)
quantize_(m_mx1, config=config1)
quantize_(m_mx2, config=config2)
m_mx1 = torch.compile(m_mx1, fullgraph=True, backend="aot_eager")
m_mx2 = torch.compile(m_mx2, fullgraph=True, backend="aot_eager")

start = time.time()
for _ in range(1000):
    m_mx1(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))
print("No per_tensor_scale = ", time.time() - start, "seconds")

start = time.time()
for _ in range(1000):
    m_mx2(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))
print("With per_tensor_scale = ", time.time() - start, "seconds")
```

On a single B200:
```
No per_tensor_scale =  1.2855589389801025 seconds
With per_tensor_scale =  1.3009123802185059 seconds
```

* Improve QAT nvfp4 numerics

**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic.

**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

**End-to-end tests:** Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation

Wikitext results:
- With this PR, the QAT nvfp4 quantized model recovered roughly 15% of the perplexity gap between the float baseline and the quantized baseline (10.2281 vs 10.3681 word perplexity; float baseline 9.418)
- Without this PR, the QAT nvfp4 quantized model was about the same as the quantized baseline

```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
| | |none | 0|word_perplexity|↓ |9.418|± | N/A|

==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
| | |none | 0|word_perplexity|↓ |10.3681|± | N/A|

# QAT with this PR (quantized)
==> Llama3.2-3B_qat_bs512/eval_quantized.log <==
| | |none | 0|word_perplexity|↓ |10.2281|± | N/A|
```
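Not part of the commit: a minimal sketch of how the "prepare vs convert SQNR" above can be measured, assuming a Blackwell GPU and the torchao APIs exercised in the diffs below (`quantize_`, `QATConfig`, `NVFP4InferenceConfig`, `compute_error`); the toy model and shapes are illustrative, not from the PR.

```
import copy
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig
from torchao.quantization.utils import compute_error
from torchao.prototype.mx_formats import NVFP4InferenceConfig

# Toy bf16 model; NVFP4 needs the K dim divisible by the block size (16)
model = torch.nn.Sequential(
    torch.nn.Linear(64, 256, bias=True, dtype=torch.bfloat16, device="cuda")
)
example_input = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")

# PTQ reference: quantize a copy directly with the inference config
base_config = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=True)
ptq_model = copy.deepcopy(model)
quantize_(ptq_model, base_config)
out_baseline = ptq_model(example_input)

# QAT prepare: linears are swapped for fake-quantized linears
quantize_(model, QATConfig(base_config, step="prepare"))
prepare_sqnr = compute_error(model(example_input), out_baseline)

# QAT convert: fake-quantized linears become nn.Linear with NVFP4Tensor weights
quantize_(model, QATConfig(base_config, step="convert"))
convert_sqnr = compute_error(model(example_input), out_baseline)
print(f"prepare SQNR = {prepare_sqnr}, convert SQNR = {convert_sqnr}")
```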
1 parent cbd3adb commit d407246

7 files changed (+197 additions, -86 deletions)

test/quantization/test_qat.py

Lines changed: 8 additions & 3 deletions
@@ -1910,7 +1910,6 @@ def _test_quantize_api_against_ptq(
         quantize_(m, QATConfig(base_config, step="prepare"), filter_fn)
         out_prepared = m(*example_inputs)
         prepare_sqnr = compute_error(out_prepared, out_baseline)
-
         self.assertGreaterEqual(prepare_sqnr, target_prepare_sqnr)
 
         # compare convert
@@ -2088,21 +2087,27 @@ def test_quantize_api_nvfp4(self, use_per_tensor_scale: bool):
 
         self._test_quantize_api_against_ptq(
             NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale),
-            target_prepare_sqnr=12,
+            target_prepare_sqnr=float("inf"),
             target_convert_sqnr=float("inf"),
         )
 
+    @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     @parametrize("use_per_tensor_scale", [True, False])
     def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         """
         Test QAT with `NVFP4FakeQuantizeConfig`.
         """
+        from torchao.prototype.mx_formats import NVFP4InferenceConfig
         from torchao.prototype.qat import NVFP4FakeQuantizeConfig
 
         torch.manual_seed(self.SEED)
         m = M().cuda()
         baseline_model = copy.deepcopy(m)
+        quantize_(
+            baseline_model,
+            NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale),
+        )
         qat_config = QATConfig(
             activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
             weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
@@ -2116,7 +2121,7 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         out = m(*x)
         baseline_out = baseline_model(*x)
         sqnr = compute_error(out, baseline_out).item()
-        self.assertGreater(sqnr, 24)
+        self.assertGreaterEqual(sqnr, float("inf"))
 
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     @unittest.skipIf(
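For context, the `compute_error` used in these assertions is an SQNR in decibels, so `float("inf")` means the QAT output matches the PTQ baseline exactly. A minimal equivalent, assuming the usual torchao definition:

```
import torch

def sqnr(ref: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    # 20 * log10(||ref|| / ||ref - approx||); returns inf when approx == ref exactly
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - approx))
```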

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 11 additions & 36 deletions
@@ -771,37 +771,13 @@ def nvfp4_quantize(
         AssertionError: If input dtype is not supported, tensor size is not
             divisible by block_size, tensor is not contiguous, or block_size != 16
     """
-    return _nvfp4_quantize(data_hp, block_size, per_tensor_scale)
-
-
-class _Float8Round(torch.autograd.Function):
-    """
-    Cast a tensor to float8 and back to float32 with backward STE.
-    """
-
-    @staticmethod
-    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
-        return x.to(torch.float8_e4m3fn).to(torch.float32)
-
-    @staticmethod
-    def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
-        return gy
-
-
-def _nvfp4_quantize(
-    data_hp: torch.Tensor,
-    block_size: int = 16,
-    per_tensor_scale: Optional[torch.Tensor] = None,
-    skip_dtype_cast_and_packing: bool = False,
-) -> tuple[torch.Tensor, torch.Tensor]:
     assert data_hp.dtype in (torch.bfloat16, torch.float), (
         f"{data_hp.dtype} not supported"
     )
     assert data_hp.size(-1) % block_size == 0, "K dim must be divisible by block_size"
     assert data_hp.is_contiguous(), "Only support contiguous data for now"
     assert block_size == 16, "NVFP4 requires block_size=16"
 
-    orig_dtype = data_hp.dtype
     orig_shape = data_hp.shape
     # Convert to float32 early for consistent precision with Triton implementation
     data_hp = data_hp.float().reshape(orig_shape[0], -1, block_size)
@@ -813,8 +789,10 @@ def _nvfp4_quantize(
     out_scales = None
     if per_tensor_scale is None:
         # We are doing single level scaling
-        block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX)
-        block_scale_fp32 = _Float8Round.apply(block_scale_fp8)
+        block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX).to(
+            torch.float8_e4m3fn
+        )
+        block_scale_fp32 = block_scale_fp8.to(torch.float32)
         data_scaled = data_hp / block_scale_fp32.unsqueeze(-1)
         out_scales = block_scale_fp8
     else:
@@ -826,8 +804,8 @@
         scaled_block_scales = block_scale_fp32 / per_tensor_scale
         scaled_block_scales_fp8 = torch.clamp(
             scaled_block_scales, min=E4M3_EPS, max=F8E4M3_MAX
-        )
-        scaled_block_scales_fp32 = _Float8Round.apply(scaled_block_scales_fp8)
+        ).to(torch.float8_e4m3fn)
+        scaled_block_scales_fp32 = scaled_block_scales_fp8.to(torch.float32)
         # We "temporarily" dequant the scaled_block_scales_fp32 to get the per_tensor_scale
         # To apply to data
         total_scale = per_tensor_scale * scaled_block_scales_fp32
@@ -836,11 +814,8 @@
 
     data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
     data_scaled = data_scaled.view(orig_shape)
-    if skip_dtype_cast_and_packing:
-        return out_scales.to(torch.float32), data_scaled.to(orig_dtype)
-    else:
-        data_lp = f32_to_f4_unpacked(data_scaled)
-        # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
-        # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
-        data_lp = pack_uint4(data_lp)
-        return out_scales.to(torch.float8_e4m3fn), data_lp
+    data_lp = f32_to_f4_unpacked(data_scaled)
+    # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
+    # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
+    data_lp = pack_uint4(data_lp)
+    return out_scales, data_lp
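A standalone sketch (not library code) of the two-level scaling that the hunks above exercise: a per-block amax becomes an fp8 (e4m3) block scale, optionally divided by an fp32 per-tensor scale before the fp8 rounding, and the resulting total scale is what the data is divided by before the fp4 cast. The constants below are stand-ins for the `E4M3_EPS`, `F8E4M3_MAX`, and `F4_E2M1_MAX` values referenced in the diff, and the helper is illustrative only.

```
import torch

E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny   # stand-in for the diff's E4M3_EPS
F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
F4_E2M1_MAX = 6.0                                  # largest fp4 (e2m1) magnitude

def toy_nvfp4_total_scale(block_amax: torch.Tensor, per_tensor_scale=None):
    # First level: map each block's amax onto the fp4 range, then round to e4m3
    block_scale = block_amax / F4_E2M1_MAX
    if per_tensor_scale is None:
        fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX).to(torch.float8_e4m3fn)
        return fp8.to(torch.float32)
    # Second level: fold the fp32 per-tensor scale into the block scale before rounding
    fp8 = torch.clamp(block_scale / per_tensor_scale, min=E4M3_EPS, max=F8E4M3_MAX).to(
        torch.float8_e4m3fn
    )
    return per_tensor_scale * fp8.to(torch.float32)

# data_scaled = data / toy_nvfp4_total_scale(block_amax).unsqueeze(-1), then cast to fp4
```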

torchao/prototype/qat/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -3,10 +3,10 @@
 
 from .nvfp4 import (
     NVFP4FakeQuantizeConfig,
-    NVFP4FakeQuantizer,
+    NVFP4FakeQuantizedLinear,
 )
 
 __all__ = [
     "NVFP4FakeQuantizeConfig",
-    "NVFP4FakeQuantizer",
+    "NVFP4FakeQuantizedLinear",
 ]

torchao/prototype/qat/nvfp4.py

Lines changed: 152 additions & 34 deletions
@@ -1,15 +1,14 @@
 from dataclasses import dataclass
+from typing import Optional
 
 import torch
 
 from torchao.prototype.mx_formats.nvfp4_tensor import (
-    _nvfp4_quantize,
+    NVFP4Tensor,
+    _addmm_nvfp4_dispatch,
     per_tensor_amax_to_scale,
 )
-from torchao.quantization.qat import (
-    FakeQuantizeConfigBase,
-    FakeQuantizerBase,
-)
+from torchao.quantization.qat import FakeQuantizeConfigBase
 
 
 @dataclass
@@ -23,47 +22,166 @@ class NVFP4FakeQuantizeConfig(FakeQuantizeConfigBase):
     Args:
         use_per_tensor_scale (bool): Whether to use two-level per-tensor fp32 scaling
             after the initial fp8 (e4m3) block-wise scaling (default True)
+        use_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
+        use_triton_kernel (bool): Whether to use triton kernels during fake quantization
     """
 
     use_per_tensor_scale: bool = True
+    use_swizzled_scales: bool = False
+    use_triton_kernel: bool = False
+
+
+# TODO: support emulation on non-Blackwell GPUs
+class _NVFP4QuantizedForwardFakeQuantizedBackward(torch.autograd.Function):
+    """
+    Autograd function for NVFP4 quantization + addmm in low precision during forward,
+    and fake quantization in high precision during backward.
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor],
+        activation_config: NVFP4FakeQuantizeConfig,
+        weight_config: NVFP4FakeQuantizeConfig,
+    ) -> torch.Tensor:
+        # quantize input activations
+        if activation_config.use_per_tensor_scale:
+            tensor_amax = torch.max(torch.abs(_input))
+            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+        else:
+            per_tensor_scale = None
+        _input = NVFP4Tensor.to_nvfp4(
+            _input,
+            per_tensor_scale=per_tensor_scale,
+            is_swizzled_scales=activation_config.use_swizzled_scales,
+            use_triton_kernel=activation_config.use_triton_kernel,
+        )
+
+        # quantize weights
+        if weight_config.use_per_tensor_scale:
+            tensor_amax = torch.max(torch.abs(weight))
+            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+        else:
+            per_tensor_scale = None
+        weight = NVFP4Tensor.to_nvfp4(
+            weight,
+            per_tensor_scale=per_tensor_scale,
+            is_swizzled_scales=weight_config.use_swizzled_scales,
+            use_triton_kernel=False,
+        )
 
+        # Follow `NVFP4InferenceConfig`, always use traditional construction
+        # for weights and set `use_triton_kernel` afterwards
+        weight.use_triton_kernel = weight_config.use_triton_kernel
 
-class NVFP4FakeQuantizer(FakeQuantizerBase):
+        ctx.save_for_backward(_input, weight)
+
+        return _addmm_nvfp4_dispatch(
+            _input,
+            weight.t(),
+            None,  # aten_op, not used
+            bias,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+        _input, weight = ctx.saved_tensors
+        assert isinstance(_input, NVFP4Tensor)
+        assert isinstance(weight, NVFP4Tensor)
+        _input = _input.to_dtype(_input._orig_dtype)
+        weight = weight.to_dtype(weight._orig_dtype)
+        grad_input = torch.mm(grad_output, weight)
+        grad_weight = torch.mm(grad_output.t(), _input)
+        return grad_input, grad_weight, None, None, None
+
+
+class NVFP4FakeQuantizedLinear(torch.nn.Linear):
     """
-    (Prototype) Generic module for applying NVFP4 fake quantization to a tensor, as specified in the config.
+    Linear module for fake quantized NVFP4 weights and/or activations.
+
+    The forward pass follows quantization and addmm numerics in `NVFP4Tensor`
+    in lower precision exactly, while the backward pass uses dequantize
+    (fake quantized) values in high precision.
+
+    Currently this is only applicable on Blackwell and future generations.
+    See https://github.com/pytorch/ao/issues/3102 for more details.
+
+    Example usage::
+
+        from torchao.quantization import quantize_
+        from torchao.prototype.mx_formats import NVFP4InferenceConfig
+
+        base_config = NVFP4InferenceConfig()
+        quantize_(model, QATConfig(base_config, step="prepare"))
+        # Model contains `NVFP4FakeQuantizedLinear` now
+
+        train_loop(model)
+        quantize_(model, QATConfig(base_config, step="convert"))
+        # Model contains `nn.Linear` with `NVFP4Tensor` weights now
     """
 
-    def __init__(self, config: NVFP4FakeQuantizeConfig):
-        super().__init__()
-        torch._C._log_api_usage_once("torchao.quantization.qat.NVFP4FakeQuantizer")
-        self.config = config
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = False,
+        activation_config: Optional[NVFP4FakeQuantizeConfig] = None,
+        weight_config: Optional[NVFP4FakeQuantizeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(
+            in_features,
+            out_features,
+            bias,
+            *args,
+            **kwargs,
+        )
+        if weight_config is None:
+            raise ValueError("Must specify `weight_config`")
+        if activation_config is None:
+            raise ValueError("Weight only NVFP4 QAT not supported yet")
+        self.activation_config = activation_config
+        self.weight_config = weight_config
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        block_size = 16
-        original_shape = x.shape
         if x.dim() == 3:
+            batch_size = x.shape[0]
            x = x.view(-1, x.shape[-1])
-        if self.config.use_per_tensor_scale:
-            tensor_amax = torch.max(torch.abs(x))
-            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
         else:
-            per_tensor_scale = None
+            batch_size = None
+        fq = _NVFP4QuantizedForwardFakeQuantizedBackward.apply(
+            x, self.weight, self.bias, self.activation_config, self.weight_config
+        )
+        assert fq.dtype == x.dtype
+        if batch_size is not None:
+            return fq.view(batch_size, -1, fq.shape[-1])
+        else:
+            return fq
 
-        # quantize
-        scale, q = _nvfp4_quantize(
-            x,
-            block_size=block_size,
-            per_tensor_scale=per_tensor_scale,
-            skip_dtype_cast_and_packing=True,
+    @classmethod
+    def from_linear(
+        cls,
+        mod: torch.nn.Linear,
+        activation_config: Optional[NVFP4FakeQuantizeConfig] = None,
+        weight_config: Optional[NVFP4FakeQuantizeConfig] = None,
+    ):
+        new_linear = NVFP4FakeQuantizedLinear(
+            mod.in_features,
+            mod.out_features,
+            mod.bias is not None,
+            activation_config=activation_config,
+            weight_config=weight_config,
+            device=mod.weight.device,
+            dtype=mod.weight.dtype,
         )
-        if self.config.use_per_tensor_scale:
-            scale = scale * per_tensor_scale
-        assert q.dtype == x.dtype
-        assert scale.dtype == torch.float32
-
-        # dequantize
-        M, K = q.shape[0], q.shape[1]
-        q = q.view(M, K // block_size, block_size)
-        scale = scale.view(M, K // block_size, 1)
-        dq = q * scale
-        return dq.view(original_shape).to(x.dtype)
+        # In distributed training, the model may be instantiated
+        # on the meta device, in which case there is no need to
+        # copy the weights, and doing so will result in an error
+        if mod.weight.device != torch.device("meta"):
+            new_linear.weight = mod.weight
+            new_linear.bias = mod.bias
+        return new_linear
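Beyond the `quantize_` flow shown in the class docstring above, a small hedged usage sketch of swapping a single linear directly through `from_linear` (shapes and device are illustrative; the NVFP4 forward currently needs a Blackwell GPU per the note in the diff):

```
import torch
from torchao.prototype.qat import NVFP4FakeQuantizeConfig, NVFP4FakeQuantizedLinear

linear = torch.nn.Linear(64, 256, bias=True, dtype=torch.bfloat16, device="cuda")
fq_linear = NVFP4FakeQuantizedLinear.from_linear(
    linear,
    activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale=True),
    weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale=True),
)

x = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
y = fq_linear(x)      # NVFP4-quantized forward via _addmm_nvfp4_dispatch
y.sum().backward()    # backward uses dequantized values in high precision
```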

torchao/quantization/qat/api.py

Lines changed: 18 additions & 1 deletion
@@ -208,7 +208,24 @@ def _qat_config_transform(
         act_config = config.activation_config
         weight_config = config.weight_config
         if isinstance(module, torch.nn.Linear):
-            return FakeQuantizedLinear.from_linear(module, act_config, weight_config)
+            # TODO: rewrite this using a registration API so
+            # specific quantization schemes do not leak here
+            from torchao.prototype.qat import (
+                NVFP4FakeQuantizeConfig,
+                NVFP4FakeQuantizedLinear,
+            )
+
+            if isinstance(weight_config, NVFP4FakeQuantizeConfig):
+                assert act_config is None or isinstance(
+                    act_config, NVFP4FakeQuantizeConfig
+                )
+                return NVFP4FakeQuantizedLinear.from_linear(
+                    module, act_config, weight_config
+                )
+            else:
+                return FakeQuantizedLinear.from_linear(
+                    module, act_config, weight_config
+                )
         elif isinstance(module, torch.nn.Embedding):
             if act_config is not None:
                 raise ValueError(

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 6 additions & 2 deletions
@@ -444,12 +444,16 @@ def _infer_fake_quantize_configs(
     elif isinstance(base_config, NVFP4InferenceConfig):
         if NVFP4MMConfig.DYNAMIC:
             act_config = NVFP4FakeQuantizeConfig(
-                use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale
+                use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale,
+                use_swizzled_scales=False,
+                use_triton_kernel=False,
             )
         else:
             act_config = None
         weight_config = NVFP4FakeQuantizeConfig(
-            use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale
+            use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale,
+            use_swizzled_scales=True,
+            use_triton_kernel=base_config.use_triton_kernel,
         )
     elif isinstance(base_config, Int8DynamicActivationIntxWeightConfig):
         assert base_config.version >= 2, "Only version 2+ is supported"
