* Support NVFP4 dynamic per tensor scale
**Summary:** This commit adds an option for the existing
`NVFP4InferenceConfig` to dynamically compute an appropriate
fp32 per tensor scale to support the two level scaling
according to the NVFP4 specification:
https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/.
While two level scaling is supported in `NVFP4Tensor`, today
there is no config API for users to enable it. The existing
`NVFP4InferenceConfig` only supports single level scaling
because including an explicit `per_tensor_scale` field would
make serialization tricky.
In the future, we should add an end-to-end calibration flow
so users can compute an appropriate per tensor scale for the
activations first, and then pass this to `NVFP4Tensor` as a
static scale, similar to the proposal in #2572.
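For context, the dynamic per tensor scale is conceptually just an
amax-based fp32 scale chosen so that the second-level (per-block)
float8 scales stay in range. Below is a minimal sketch of that idea;
the helper name and constants are illustrative, not the exact torchao
implementation:
```
import torch

F4_E2M1_MAX = 6.0    # largest magnitude representable in FP4 E2M1
F8_E4M3_MAX = 448.0  # largest magnitude representable in float8_e4m3

def dynamic_per_tensor_scale(x: torch.Tensor) -> torch.Tensor:
    # Single fp32 scale for the whole tensor, chosen so that the
    # per-block float8 scales remain representable after dividing by it.
    amax = x.abs().amax().to(torch.float32)
    return amax / (F4_E2M1_MAX * F8_E4M3_MAX)
```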
**Test Plan:**
```
pytest test/prototype/mx_formats/test_inference_workflow.py -k test_inference_workflow_nvfp4
pytest test/quantization/test_qat.py -k test_quantize_api_nvfp4
```
Also did a quick benchmark before and after:
```
import copy
import time
import torch
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import NVFP4InferenceConfig
m_mx1 = torch.nn.Linear(64, 256, bias=True, dtype=torch.bfloat16, device="cuda")
m_mx2 = copy.deepcopy(m_mx1)
config1 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=False)
config2 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=True)
quantize_(m_mx1, config=config1)
quantize_(m_mx2, config=config2)
m_mx1 = torch.compile(m_mx1, fullgraph=True, backend="aot_eager")
m_mx2 = torch.compile(m_mx2, fullgraph=True, backend="aot_eager")
start = time.time()
for _ in range(1000):
    m_mx1(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))
print("No per_tensor_scale = ", time.time() - start, "seconds")
start = time.time()
for _ in range(1000):
    m_mx2(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))
print("With per_tensor_scale = ", time.time() - start, "seconds")
```
On a single B200:
```
No per_tensor_scale = 1.2855589389801025 seconds
With per_tensor_scale = 1.3009123802185059 seconds
```
[ghstack-poisoned]
* Improve QAT nvfp4 numerics
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
Details TBD.
**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```
End-to-end tests TBD.
[ghstack-poisoned]
* Update base for Update on "Improve QAT nvfp4 numerics"
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:
1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to the original
dtype (e.g. bf16), since that loses some fidelity relative to fp32
3. Fake round blockwise scales to float8 (see the sketch after this list)
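To make item 3 concrete, fake rounding a blockwise scale to float8 is
just a round trip through `float8_e4m3fn`; a minimal sketch (the helper
name is illustrative, not torchao's API):
```
import torch

def fake_round_scale_to_float8(scale: torch.Tensor) -> torch.Tensor:
    # Round-trip through float8_e4m3 so the fake-quantized path sees
    # the same lossy blockwise scales that the real PTQ flow would use.
    return scale.to(torch.float8_e4m3fn).to(torch.float32)
```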
**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```
End-to-end tests TBD.
[ghstack-poisoned]
* Update base for Update on "Improve QAT nvfp4 numerics"
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:
1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked` (see the sketch after this list)
2. Do not cast intermediate fake quantized values to the original
dtype (e.g. bf16), since that loses some fidelity relative to fp32
3. Fake round blockwise scales to float8
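To make item 1 concrete, simulating the `f32_to_f4_unpacked` /
`f4_unpacked_to_f32` round trip amounts to snapping each (block-scaled)
value onto the FP4 E2M1 grid. The sketch below is a hypothetical
illustration, not the actual torchao kernels, and it breaks ties by
nearest-value argmin rather than round-half-to-even:
```
import torch

# The eight non-negative values representable in FP4 E2M1.
_F4_E2M1_POS = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fake_quantize_to_f4_e2m1(x: torch.Tensor) -> torch.Tensor:
    # Snap each element to its nearest representable FP4 value.
    grid = torch.cat([-_F4_E2M1_POS.flip(0), _F4_E2M1_POS]).to(x.device, x.dtype)
    idx = (x.unsqueeze(-1) - grid).abs().argmin(dim=-1)
    return grid[idx]
```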
**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```
End-to-end tests TBD.
[ghstack-poisoned]
* Update base for Update on "Improve QAT nvfp4 numerics"
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.
**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```
End-to-end tests TBD.
[ghstack-poisoned]
* Update base for Update on "Improve QAT nvfp4 numerics"
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic.
**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```
**End-to-end tests:**
Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation
Wikitext:
- With this PR, the QAT nvfp4 quantized model recovered about 15% of the quantized baseline's perplexity degradation relative to the float baseline
- Without this PR, the QAT nvfp4 quantized model was about the same as the quantized baseline
```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
| | |none | 0|word_perplexity|↓ |9.418|± | N/A|
==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
| | |none | 0|word_perplexity|↓ |10.3681|± | N/A|
# QAT with this PR (quantized)
==> unsloth_model_lora_qat_int4_output/lm_eval_quantized.log <==
| | |none | 0|word_perplexity|↓ |10.2281|± | N/A|
```
[ghstack-poisoned]
* Update base for Update on "Improve QAT nvfp4 numerics"
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.
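Conceptually, the new linear class keeps the NVFP4 fake quantization of
both the activation and the weight, plus the matmul itself, in a single
forward. The sketch below is hypothetical and heavily simplified (the
class and helper names are made up, and the real implementation matches
the PTQ kernels exactly rather than this reference math):
```
import torch
import torch.nn.functional as F

_F4_MAX = 6.0
_F4_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def _nvfp4_fake_quantize(t: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    # Hypothetical stand-in: blockwise float8 scales followed by rounding
    # onto the E2M1 grid. Assumes numel is divisible by block_size.
    orig_shape, tf = t.shape, t.reshape(-1, block_size).float()
    scale = tf.abs().amax(dim=-1, keepdim=True) / _F4_MAX
    scale = scale.to(torch.float8_e4m3fn).to(torch.float32).clamp(min=1e-12)
    grid = torch.cat([-_F4_GRID.flip(0), _F4_GRID]).to(tf.device)
    q = grid[(tf / scale).unsqueeze(-1).sub(grid).abs().argmin(dim=-1)]
    return (q * scale).reshape(orig_shape)

class FakeQuantizedNVFP4Linear(torch.nn.Linear):
    """Hypothetical sketch: quantization and mm logic live in one module."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Keep intermediates in fp32 (no cast back to bf16 mid-flow),
        # then cast only the output back to the activation dtype.
        x_fq = _nvfp4_fake_quantize(x)
        w_fq = _nvfp4_fake_quantize(self.weight)
        bias = self.bias.float() if self.bias is not None else None
        return F.linear(x_fq, w_fq, bias).to(x.dtype)
```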
**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```
**End-to-end tests:**
Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation
Wikitext:
- With this PR, the QAT nvfp4 quantized model recovered about 15% of
the quantized baseline's perplexity degradation relative to the float baseline
- Without this PR, the QAT nvfp4 quantized model was about the
same as the quantized baseline
```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
| | |none | 0|word_perplexity|↓ |9.418|± | N/A|
==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
| | |none | 0|word_perplexity|↓ |10.3681|± | N/A|
# QAT with this PR (quantized)
==> Llama3.2-3B_qat_bs512/eval_quantized.log <==
| | |none | 0|word_perplexity|↓ |10.2281|± | N/A|
```
[ghstack-poisoned]