Skip to content

Commit

Permalink
[Kernel] Remove scaled_fp8_quant kernel padding footgun (vllm-project…
Browse files Browse the repository at this point in the history
  • Loading branch information
tlrmchlsmth authored and kylesayrs committed Aug 17, 2024
1 parent cba8aa9 commit 2d43e96
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 14 deletions.
2 changes: 1 addition & 1 deletion tests/quantization/test_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def per_tensor_dequantize(tensor, inv_scale, dtype):
assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))

# Padding
y, _ = ops.scaled_fp8_quant(x, inv_scale, batch_dim_padding=17)
y, _ = ops.scaled_fp8_quant(x, inv_scale, num_token_padding=17)
assert y.shape[0] == 17
assert torch.allclose(
ref_y,
Expand Down
24 changes: 13 additions & 11 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
batch_dim_padding: Optional[int] = None,
num_token_padding: Optional[int] = None,
scale_ub: Optional[torch.Tensor] = None,
use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
Expand All @@ -317,15 +317,15 @@ def scaled_fp8_quant(
This function supports both static and dynamic quantization: If you
provide the scale, it will use static scaling and if you omit it,
the scale will be determined dynamically. The function also allows
optional padding of the output tensor for downstream kernels that
optional padding of the output tensors for downstream kernels that
will benefit from padding.
Args:
input: The input tensor to be quantized to FP8
scale: Optional scaling factor for the FP8 quantization
scale_ub: Optional upper bound for scaling factor in dynamic
per token case
batch_dim_padding: If specified, pad the first dimension
num_token_padding: If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case.
Expand All @@ -334,16 +334,16 @@ def scaled_fp8_quant(
Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
scaling factor.
"""
if batch_dim_padding:
shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:])
output = torch.empty(shape,
device=input.device,
dtype=torch.float8_e4m3fn)
else:
output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
# This code assumes batch_dim and num_tokens are flattened
assert (input.ndim == 2)
shape = input.shape
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
output = torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn)

if scale is None:
if use_per_token_if_dynamic:
scale = torch.empty((input.numel() // input.shape[-1], 1),
scale = torch.empty((shape[0], 1),
device=input.device,
dtype=torch.float32)
torch.ops._C.dynamic_per_token_scaled_fp8_quant(
Expand All @@ -352,6 +352,8 @@ def scaled_fp8_quant(
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else:
# num_token_padding not implemented for this case
assert (scale.numel() == 1 or num_token_padding is None)
torch.ops._C.static_scaled_fp8_quant(output, input, scale)

return output, scale
Expand Down
5 changes: 3 additions & 2 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def apply_fp8_linear(
qinput, x_scale = ops.scaled_fp8_quant(
input,
input_scale,
batch_dim_padding=17,
num_token_padding=17,
use_per_token_if_dynamic=use_per_token_if_dynamic)

per_tensor_weights = (weight_scale.numel() == 1)
Expand Down Expand Up @@ -177,8 +177,9 @@ def apply_fp8_linear(
output, _ = torch._scaled_mm(qinput,
weight,
out_dtype=torch.float32)
# Unpad (undo batch_dim_padding)
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input.shape[0])
x_scale = torch.narrow(x_scale, 0, 0, input.shape[0])

# DQ
# C = sw * sx * (X * W) + bias
Expand Down

0 comments on commit 2d43e96

Please sign in to comment.