Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add hardware check to fp8 quant #1314

Merged
merged 6 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions test/dtypes/test_affine_quantized_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,17 +134,20 @@ def test_fp8_linear_variants(
compute_error(output_original, output_quantized) > 20
), f"Quantization error is too high got a SQNR of {error}"

@unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
def test_invalid_granularity(self):
with pytest.raises(ValueError, match="Invalid granularity specification"):
float8_dynamic_activation_float8_weight(granularity="invalid")

@unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
def test_mismatched_granularity(self):
with pytest.raises(
ValueError,
match="Different granularities for activation and weight are not supported",
):
float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))

@unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
def test_unsupported_granularity(self):
class UnsupportedGranularity:
pass
Expand Down
8 changes: 8 additions & 0 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
is_MI300,
)

from .autoquant import AutoQuantizableLinearWeight, autoquant
Expand Down Expand Up @@ -82,6 +83,7 @@
from .utils import _get_per_token_block_size

logger = logging.getLogger(__name__)
is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
jerryzh168 marked this conversation as resolved.
Show resolved Hide resolved

__all__ = [
"swap_conv2d_1x1_to_linear",
Expand Down Expand Up @@ -939,6 +941,9 @@ def float8_dynamic_activation_float8_weight(
mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.

"""
assert (
is_cuda_8_9 or is_MI300()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if granularity is PerTensor then it is sm89 if it is PerRow then it is currenlty sm90 or higher

), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
if mm_config is None:
mm_config = Float8MMConfig(use_fast_accum=True)

Expand Down Expand Up @@ -993,6 +998,9 @@ def float8_static_activation_float8_weight(
weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m
mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
"""
assert (
is_cuda_8_9 or is_MI300()
), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
if mm_config is None:
mm_config = Float8MMConfig(use_fast_accum=True)

Expand Down
19 changes: 19 additions & 0 deletions torchao/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"TORCH_VERSION_AFTER_2_3",
"TORCH_VERSION_AFTER_2_4",
"TORCH_VERSION_AFTER_2_5",
"is_MI300",
]


Expand Down Expand Up @@ -586,6 +587,24 @@ def _torch_version_at_least(min_version):
return is_fbcode() or version("torch") >= min_version


def is_MI300():
if torch.cuda.is_available() and torch.version.hip:
mxArchName = ["gfx940", "gfx941", "gfx942"]
archName = torch.cuda.get_device_properties().gcnArchName
for arch in mxArchName:
if arch in archName:
return True
return False


def is_cuda_8_9():
return (
torch.cuda.is_available()
and torch.version.cuda
and torch.cuda.get_device_capability() >= (8, 9)
)


TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev")
TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev")
TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev")
Expand Down
Loading