Skip to content

Commit 9d99c55

Browse files
committed
enable select for NVFP4Tensor
Summary: This is useful for vLLM 2d -> 3d MoE weight surgery.

Test Plan: unit tests:
```
pytest test/prototype/mx_formats/ -s
```
Also, after this PR we can run a Qwen 1.5 MoE model quantized with nvfp4 in vLLM, with vllm-project/vllm#25480.

Reviewers:
Subscribers:
Tasks:
Tags:
ghstack-source-id: db79143
ghstack-comment-id: 3361104836
Pull-Request: #3117
1 parent 7542737 commit 9d99c55

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,25 @@ def nvfp4_view_op(func, types, args, kwargs):
634634
)
635635

636636

637+
@implements([aten.select.int])
def nvfp4_select(func, types, args, kwargs):
    """Handle ``aten.select.int`` on an NVFP4Tensor by selecting along dim 0.

    The same ``index`` is applied to the packed data (``qdata``) and to the
    block scales (``_scale_e4m3``); every other attribute of the source
    tensor is forwarded unchanged to the newly constructed tensor. Per the
    originating change, this supports vLLM's 2d -> 3d MoE weight surgery.

    Args:
        func: the aten op being dispatched.
        types: unused here; part of the dispatch signature.
        args: ``(tensor, dim, index)`` as passed to ``aten.select.int``.
        kwargs: forwarded to ``return_and_correct_aliasing``.

    Returns:
        The result of ``return_and_correct_aliasing`` for the new tensor.

    Raises:
        AssertionError: if ``dim`` is not 0, or if ``qdata`` and
            ``_scale_e4m3`` do not have the same number of dimensions.
    """
    src, dim, index = args
    assert dim == 0, f"NVFP4Tensor aten.select.int with {dim=} is not yet supported"

    # Indexing data and scales with one index only lines up when both
    # tensors have the same rank.
    data_rank = len(src.qdata.shape)
    scale_rank = len(src._scale_e4m3.shape)
    assert data_rank == scale_rank, "unsupported"

    selected_qdata = src.qdata[index]
    selected_scale = src._scale_e4m3[index]

    # Rebuild with the same class so subclasses are preserved; everything
    # except the data and scales is carried over as-is.
    result = src.__class__(
        selected_qdata,
        selected_scale,
        src._block_size,
        src._orig_dtype,
        src._per_tensor_scale,
        src._act_per_tensor_scale,
        src._is_swizzled_scales,
        src.use_triton_kernel,
        src.act_quant_kwargs,
    )
    return return_and_correct_aliasing(func, args, kwargs, result)
654+
655+
637656
def _addmm_nvfp4_dispatch(
638657
a: NVFP4Tensor, b: NVFP4Tensor, aten_op, bias: Optional[torch.Tensor] = None
639658
) -> torch.Tensor:

torchao/testing/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -626,8 +626,9 @@ def _test_narrow_similar_to_vllm(self, config: AOBaseConfig):
626626
)
627627

628628
def _test_quantize_3d_param_similar_to_vllm(self, config: AOBaseConfig):
629-
# this happens when vLLM loads empty MoE weights and quantizes
630-
# them
629+
# this happens when vLLM loads empty MoE weights, quantizes
630+
# them, and stitches 2d params from the checkpoint into a 3d param
631+
# in memory
631632

632633
dtype = torch.bfloat16
633634
with torch.device("meta"):
@@ -636,6 +637,7 @@ def _test_quantize_3d_param_similar_to_vllm(self, config: AOBaseConfig):
636637
torch.randn(60, 2816, 2048, device="cuda", dtype=dtype)
637638
)
638639
quantize_(l, config)
640+
_w_slice = l.weight[0]
639641

640642

641643
common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)

0 commit comments

Comments
 (0)