
Commit 04ea5f8

Merge branch 'main' into deprecate-old-apis
2 parents fbb2f2b + c96f2dd commit 04ea5f8

32 files changed: +1086 −201 lines

.github/scripts/validate_binaries.sh

File mode changed from 100644 to 100755.

benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py

Lines changed: 52 additions & 24 deletions
@@ -7,7 +7,7 @@
 #
 # To run these benchmarks, use the following command:
 #
-# torchrun --nproc-per-node=8 --local-ranks-filter=0 benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
+# torchrun --nproc-per-node=4 --local-ranks-filter=0 benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
 #
 #######################################################################
 import argparse
@@ -24,6 +24,7 @@
     all_to_all_single,
     all_to_all_single_autograd,
 )
+from torch.nn import functional as F
 from tqdm import tqdm

 from benchmarks.utils import profile_fn
@@ -66,33 +67,53 @@ def get_configs() -> List[ExperimentConfig]:
     return configs


-# Copy/paste a2a impls added in https://github.com/pytorch/torchtitan/pull/1765
-def default_a2a_dispatch(
+def default_a2a_fwd_bwd(
     routed_input: torch.Tensor,
+    labels: torch.Tensor,
     output_splits_list: list[int],
     input_splits_list: list[int],
     device_mesh: DeviceMesh,
 ):
-    """
-    Default implementation of all-to-all dispatch. Incurs device-to-host sync.
-
-    Returns:
-        routed_input: the local tokens after all-to-all dispatch
-        input_splits: the input splits for all-to-all dispatch
-        output_splits: the output splits for all-to-all dispatch
-        num_tokens_per_expert_group: the number of tokens per EP rank after all-to-all dispatch
-    """
-    # perform all-to-all
     routed_input = all_to_all_single_autograd(
         routed_input,
         output_splits_list,
         input_splits_list,
         device_mesh.get_group(),
     )
     routed_input = torch.ops._c10d_functional.wait_tensor(routed_input)
+
+    loss = F.mse_loss(routed_input, labels)
+    loss.backward()
+
+    torch.cuda.synchronize()
     return routed_input


+def mxfp8_a2a_fwd_bwd(
+    routed_input: torch.Tensor,
+    labels: torch.Tensor,
+    output_splits_list: list[int],
+    input_splits_list: list[int],
+    device_mesh: DeviceMesh,
+):
+    routed_input = to_mxfp8_a2a_dequant(
+        routed_input,
+        output_splits_list,
+        input_splits_list,
+        device_mesh.get_group(),
+    )
+
+    loss = F.mse_loss(routed_input, labels)
+    loss.backward()
+    torch.cuda.synchronize()
+    return routed_input
+
+
+# Compile target funcs
+default_a2a_sync_compiled = torch.compile(default_a2a_fwd_bwd)
+mxfp8_a2a_sync_compiled = torch.compile(mxfp8_a2a_fwd_bwd)
+
+
 def run_experiment(
     config: ExperimentConfig, args: argparse.Namespace
 ) -> ExperimentResult:
@@ -101,8 +122,9 @@ def run_experiment(
         (batch_size * seq_len, dim),
         dtype=torch.bfloat16,
         device=device,
+        requires_grad=True,
     )
-    ref_x = x.detach().clone()
+    ref_x = x.detach().clone().requires_grad_(True)

     # Set up device mesh
     mesh = init_device_mesh("cuda", (dist.get_world_size(),))
@@ -121,24 +143,27 @@ def warmup(func_no_args):
     )
     input_splits_list, output_splits_list = get_split_lists(input_splits, mesh)

-    # Compile target funcs
-    default_a2a_dispatch_c = torch.compile(default_a2a_dispatch)
-    to_mxfp8_a2a_dequant_c = torch.compile(to_mxfp8_a2a_dequant)
+    # Generate labels
+    labels_shape = (sum(output_splits_list), dim)
+    labels = x.new_ones(*labels_shape)

     # Bench default a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
     warmup(
-        lambda: default_a2a_dispatch_c(
-            ref_x, output_splits_list, input_splits_list, mesh
+        lambda: default_a2a_sync_compiled(
+            ref_x, labels, output_splits_list, input_splits_list, mesh
         )
     )
     start_sec = time.perf_counter()
-    default_a2a_dispatch_c(ref_x, output_splits_list, input_splits_list, mesh)
+    default_a2a_sync_compiled(
+        ref_x, labels, output_splits_list, input_splits_list, mesh
+    )
     end_sec = time.perf_counter()
     bf16_ms = (end_sec - start_sec) * 1e3
     if args.profile:
         profile_fn(
-            default_a2a_dispatch_c,
+            default_a2a_sync_compiled,
             ref_x,
+            labels,
             output_splits_list,
             input_splits_list,
             mesh,
@@ -148,16 +173,19 @@ def warmup(func_no_args):

     # Bench mxfp8 sync a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
     warmup(
-        lambda: to_mxfp8_a2a_dequant_c(x, output_splits_list, input_splits_list, mesh)
+        lambda: mxfp8_a2a_sync_compiled(
+            x, labels, output_splits_list, input_splits_list, mesh
+        )
     )
     start_sec = time.perf_counter()
-    to_mxfp8_a2a_dequant_c(x, output_splits_list, input_splits_list, mesh)
+    mxfp8_a2a_sync_compiled(x, labels, output_splits_list, input_splits_list, mesh)
     end_sec = time.perf_counter()
     mxfp8_ms = (end_sec - start_sec) * 1e3
     if args.profile:
         profile_fn(
-            to_mxfp8_a2a_dequant_c,
+            mxfp8_a2a_sync_compiled,
             x,
+            labels,
             output_splits_list,
             input_splits_list,
             mesh,
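Note on the benchmark change above: each all-to-all variant is now wrapped in a function that runs the dispatch, an MSE loss against fixed labels, and a backward pass; the function is compiled once, warmed up, and then timed over a single invocation. A minimal sketch of that measurement pattern, with a toy stand-in for the dispatch (illustrative only, not part of the commit):

```python
# Minimal sketch of the warmup-then-time pattern used above (toy stand-in,
# not the actual all-to-all dispatch).
import time

import torch
import torch.nn.functional as F


def toy_fwd_bwd(x: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    out = x * 2.0  # stand-in for the all-to-all dispatch + dequant
    loss = F.mse_loss(out, labels)
    loss.backward()
    torch.cuda.synchronize()  # ensure backward kernels finish before timing stops
    return out


toy_compiled = torch.compile(toy_fwd_bwd)

x = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16, requires_grad=True)
labels = x.new_ones(x.shape)

toy_compiled(x, labels)  # warmup: triggers compilation, excluded from the measurement

start_sec = time.perf_counter()
toy_compiled(x, labels)
end_sec = time.perf_counter()
print(f"fwd+bwd: {(end_sec - start_sec) * 1e3:.3f} ms")
```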
Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+# Testing compatibility
+# We know that torchao .so files built using PyTorch 2.8.0 are not ABI compatible with PyTorch 2.9+ (see #2919).
+# If the version of torch is not compatible with the version of torchao,
+# we expect loading of the .so files to be skipped: a warning should be logged but no error raised.
+
+PREV_TORCH_VERSION="2.8.0"
+PREV_TORCHAO_VERSION="0.13.0"
+
+# Function to check torchao import with configurable expectations
+check_torchao_import() {
+    local expect_warning="$1"
+    local warning_text="$2"
+    local extra_env="${3:-}"
+
+    if [ -n "$extra_env" ]; then
+        output=$(env "$extra_env" python -c "import torchao" 2>&1)
+    else
+        output=$(python -c "import torchao" 2>&1)
+    fi
+    exit_code=$?
+
+    if [ $exit_code -ne 0 ]; then
+        echo "ERROR: Failed to import torchao"
+        echo "Output: $output"
+        exit 1
+    fi
+
+    warning_found=false
+    if [ -n "$warning_text" ] && echo "$output" | grep -i "$warning_text" > /dev/null; then
+        echo "Output: $output"
+        warning_found=true
+    fi
+
+    if [ "$expect_warning" != "$warning_found" ]; then
+        echo "FAILURE: expect_warning is $expect_warning but warning_found is $warning_found with message $output"
+        exit 1
+    fi
+}
+
+## prev torch version, prev torchao version
+# Uninstall torch
+pip uninstall -y torch
+# Uninstall torchao
+pip uninstall -y torchao
+# Install prev compatible version of torch
+pip install torch==${PREV_TORCH_VERSION}
+# Install prev compatible version of torchao
+pip install torchao==${PREV_TORCHAO_VERSION}
+# Should import successfully without warning
+check_torchao_import "false" ""
+
+## current torch, current torchao
+# Uninstall torch
+pip uninstall -y torch
+# Uninstall torchao
+pip uninstall -y torchao
+# Install specific compatible version of torch (nightly 2.9.0dev)
+pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu129
+# Build torchao from source
+python setup.py develop
+# Should import successfully without warning
+check_torchao_import "false" ""
+
+## prev torch, torchao from source (do not rebuild), env var = True
+# Uninstall torch
+pip uninstall -y torch
+# Install incompatible version of torch
+pip install torch==${PREV_TORCH_VERSION}
+# Should import with warning because optional env var is set to true
+check_torchao_import "true" "Skipping import of cpp extensions due to incompatible torch version" "TORCHAO_SKIP_LOADING_SO_FILES=1"
+
+## current torch, prev torchao
+# Uninstall torch
+pip uninstall -y torch
+# Uninstall torchao
+pip uninstall -y torchao
+# Install non-ABI stable torch version
+pip install torch==2.9.0
+# Install incompatible torchao
+pip install torchao==${PREV_TORCHAO_VERSION}
+# Should import with specific warning
+check_torchao_import "true" "Skipping import of cpp extensions due to incompatible torch version"
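The script above exercises torchao's import-time guard: when the installed torch does not match the torch the .so files were built against, the C++ extensions are skipped and a warning is logged rather than raising. A rough sketch of that kind of guard, for illustration only; the constant, version comparison, and loader call are assumptions, not torchao's actual internals, while the env var name and warning text come from the script:

```python
# Illustrative sketch (not torchao's actual implementation) of skipping compiled
# extensions with a warning when the installed torch version is incompatible.
import logging
import os

import torch

logger = logging.getLogger(__name__)

# Hypothetical constant recorded when the .so files were built.
_TORCH_VERSION_AT_BUILD_TIME = "2.8.0"


def _load_cpp_extensions() -> None:
    installed = torch.__version__.split("+")[0]
    skip_requested = os.environ.get("TORCHAO_SKIP_LOADING_SO_FILES", "0") == "1"
    # Assumed rule: major.minor must match the build-time torch.
    incompatible = not installed.startswith(_TORCH_VERSION_AT_BUILD_TIME.rsplit(".", 1)[0])

    if skip_requested or incompatible:
        logger.warning(
            "Skipping import of cpp extensions due to incompatible torch version: "
            f"built against {_TORCH_VERSION_AT_BUILD_TIME}, found {installed}"
        )
        return

    # The real package would load the compiled libraries here, e.g. via
    # torch.ops.load_library(...).
```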

test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 13 additions & 0 deletions
@@ -218,3 +218,16 @@ def test_narrow_similar_to_vllm(self):
             gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
         )
         self._test_narrow_similar_to_vllm(config)
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(
+        not torch_version_at_least("2.8.0"),
+        reason="torch.compile requires PyTorch 2.8+",
+    )
+    def test_nvfp4_quantize_3d_param_similar_to_vllm(self):
+        config = NVFP4InferenceConfig(
+            mm_config=NVFP4MMConfig.WEIGHT_ONLY,
+            use_triton_kernel=False,
+            use_dynamic_per_tensor_scale=False,
+        )
+        self._test_quantize_3d_param_similar_to_vllm(config)
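The new test reuses the shared helper to quantize a 3D parameter (stacked expert weights, as in vLLM's fused MoE layers) with a weight-only NVFP4 config. For reference, a hedged sketch of applying the same config to an ordinary module via quantize_; the import path and toy module are assumptions for illustration, and the 3D-parameter case itself goes through the _test_quantize_3d_param_similar_to_vllm helper instead:

```python
# Illustrative sketch: applying the weight-only NVFP4 config with quantize_.
# Import path and toy module are assumptions, not taken from this commit.
import torch
from torchao.prototype.mx_formats import NVFP4InferenceConfig, NVFP4MMConfig
from torchao.quantization import quantize_

model = torch.nn.Sequential(torch.nn.Linear(256, 512, bias=False)).cuda().bfloat16()

config = NVFP4InferenceConfig(
    mm_config=NVFP4MMConfig.WEIGHT_ONLY,  # quantize weights only, activations stay bf16
    use_triton_kernel=False,
    use_dynamic_per_tensor_scale=False,
)
quantize_(model, config)  # replaces the Linear weight with an NVFP4-quantized tensor

x = torch.randn(8, 256, device="cuda", dtype=torch.bfloat16)
y = model(x)  # the quantized weight is handled inside the overridden mm op
```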

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 44 additions & 16 deletions
@@ -42,6 +42,7 @@
         (torch.float32, (64, 128), False),
         (torch.bfloat16, (128, 256), False),
         (torch.bfloat16, (64, 128), True),
+        (torch.bfloat16, (1, 32, 64), False),
     ],
 )
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -83,14 +84,20 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
         f"Dtype mismatch: {x.dtype} vs {x_reconstructed.dtype}"
     )

-    x_nvfp4_t = x_nvfp4.t()
+    if len(x.shape) == 2:
+        x_nvfp4_t = x_nvfp4.t()
+        x_t = x.t()
+    else:
+        x_nvfp4_t = x_nvfp4.transpose(-2, -1)
+        x_t = x.transpose(-2, -1)
+
     x_reconstructed_t = x_nvfp4_t.to_dtype(dtype)
-    assert_sqnr_gt_threshold(x.t(), x_reconstructed_t, 8.0)
+    assert_sqnr_gt_threshold(x_t, x_reconstructed_t, 8.0)

-    assert x.t().shape == x_reconstructed_t.shape, (
+    assert x_t.shape == x_reconstructed_t.shape, (
         f"Transpose shape mismatch: {x.t().shape} vs {x_reconstructed_t.shape}"
     )
-    assert x.t().dtype == x_reconstructed_t.dtype, (
+    assert x_t.dtype == x_reconstructed_t.dtype, (
         f"Transpose dtype mismatch: {x.t().dtype} vs {x_reconstructed_t.dtype}"
     )
@@ -103,6 +110,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
         (16, 32),
         (64, 128),
         (384, 128),
+        (1, 32, 64),
     ],
 )
 @pytest.mark.skipif(
@@ -115,8 +123,7 @@ def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape):
     that the _is_swizzled_scales flag is set correctly.
     """

-    M, K = shape
-    data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+    data = torch.randn(*shape, device="cuda", dtype=torch.bfloat16)

     tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=is_swizzled_scales)
     assert tensor._is_swizzled_scales == is_swizzled_scales
@@ -536,36 +543,43 @@ def test_nvfp4_to_copy():
 @pytest.mark.parametrize("use_triton_kernel", [False, True])
 @pytest.mark.parametrize("is_swizzled_scales", [False, True])
 @pytest.mark.parametrize(
-    "mk",
+    "shape",
     (
         (128, 64),
         (128 + 16, 64),
         (128, 64 + 16),
         (128 + 16, 64 + 16),
+        (1, 128, 64),
     ),
 )
 def test_scale_shape_matches_qdata(
-    transpose, use_triton_kernel, is_swizzled_scales, mk
+    transpose, use_triton_kernel, is_swizzled_scales, shape
 ):
     if use_triton_kernel and not is_sm_at_least_100():
         pytest.skip("CUDA capability >= 10.0 required for nvfp4 triton kernel")
     if use_triton_kernel and not is_swizzled_scales:
         pytest.skip("triton kernel requires swizzled scales")

-    M, K = mk
-
     block_size = 16

-    x_hp = torch.randn(M, K, device="cuda")
+    x_hp = torch.randn(*shape, device="cuda")
     x = NVFP4Tensor.to_nvfp4(
         x_hp, is_swizzled_scales=is_swizzled_scales, use_triton_kernel=use_triton_kernel
     )

-    m_dim, k_dim = 0, 1
-    if transpose:
-        x_hp = x_hp.t()
-        x = x.t()
-        m_dim, k_dim = 1, 0
+    if len(shape) == 2:
+        m_dim, k_dim = 0, 1
+        if transpose:
+            x_hp = x_hp.t()
+            x = x.t()
+            m_dim, k_dim = 1, 0
+    else:
+        assert len(shape) == 3, "unsupported"
+        m_dim, k_dim = 1, 2
+        if transpose:
+            x_hp = x_hp.transpose(-2, -1)
+            x = x.transpose(-2, -1)
+            m_dim, k_dim = 2, 1

     orig_m = x_hp.shape[m_dim]
     expected_padded_m = orig_m
@@ -587,3 +601,17 @@ def test_scale_shape_matches_qdata(
     assert expected_padded_k == actual_padded_k, (
         f"incompatible padded shape for dim {k_dim}: {expected_padded_k}, {actual_padded_k=}, {x.shape}, {x._scale_e4m3.shape}"
     )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
+)
+@pytest.mark.parametrize("dims", ((1, 2), (2, 1), (-1, -2), (-2, -1)))
+@pytest.mark.parametrize("is_swizzled_scales", [True, False])
+def test_3d_transpose(dims, is_swizzled_scales):
+    x_hp = torch.randn(2, 128, 256, device="cuda")
+    x_nvfp4 = NVFP4Tensor.to_nvfp4(x_hp, is_swizzled_scales=is_swizzled_scales)
+    x_hp_t = x_hp.transpose(dims[0], dims[1])
+    x_nvfp4_t = x_nvfp4.transpose(dims[0], dims[1])
+    assert x_hp_t.shape == x_nvfp4_t.shape
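The scale-shape test above now handles both 2D (M, K) and 3D (B, M, K) inputs by choosing which axes play the M and K roles, with the leading batch dim excluded from scaling. A small standalone sketch of that dim-selection rule (illustrative helper, not part of the test file):

```python
# Small sketch of the dim-selection rule the updated test uses for 2D vs. 3D inputs.
from typing import Tuple

import torch


def mk_dims(shape: Tuple[int, ...], transpose: bool) -> Tuple[int, int]:
    if len(shape) == 2:
        # Plain (M, K) matrix; transposing swaps the two roles.
        return (1, 0) if transpose else (0, 1)
    assert len(shape) == 3, "only 2D and 3D inputs are handled"
    # (B, M, K): the leading batch dim never participates in scaling.
    return (2, 1) if transpose else (1, 2)


x = torch.randn(1, 128, 64)
m_dim, k_dim = mk_dims(x.shape, transpose=True)
x_t = x.transpose(-2, -1)
print(x_t.shape[m_dim], x_t.shape[k_dim])  # 128 64 -> the original M and K
```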

test/prototype/test_awq.py

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@

 from torchao.prototype.awq import AWQConfig, AWQStep
 from torchao.quantization import Int4WeightOnlyConfig, quantize_
-from torchao.utils import _is_fbgemm_genai_gpu_available, torch_version_at_least
+from torchao.utils import _is_fbgemm_gpu_genai_available, torch_version_at_least


 class ToyLinearModel(torch.nn.Module):
@@ -46,7 +46,7 @@ def forward(self, x):
 devices = ["cpu"]
 if (
     torch.cuda.is_available()
-    and _is_fbgemm_genai_gpu_available()
+    and _is_fbgemm_gpu_genai_available()
     and torch_version_at_least("2.6.0")
 ):
     devices.append("cuda")

0 commit comments
