From 88b6ba17ca41dbd84710f4c0bb18554557b7dcc2 Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Tue, 19 Nov 2024 13:54:44 -0800
Subject: [PATCH 1/4] Add hardware check to fp8 quant

---
 torchao/quantization/quant_api.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 0e6ebdc7e0..4f6908162b 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -82,6 +82,7 @@
 from .utils import _get_per_token_block_size
 
 logger = logging.getLogger(__name__)
+is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 __all__ = [
     "swap_conv2d_1x1_to_linear",
@@ -939,6 +940,9 @@ def float8_dynamic_activation_float8_weight(
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
 
     """
+    assert (
+        is_cuda_8_9
+    ), "Float8 dynamic activation quantization is only supported on CUDA 8.9 and above"
     if mm_config is None:
         mm_config = Float8MMConfig(use_fast_accum=True)
@@ -993,6 +997,9 @@ def float8_static_activation_float8_weight(
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
     """
+    assert (
+        is_cuda_8_9
+    ), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
     if mm_config is None:
         mm_config = Float8MMConfig(use_fast_accum=True)

From 2423c1d1c8e491315951f80885d1d24e3e315b22 Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Thu, 21 Nov 2024 16:05:58 -0800
Subject: [PATCH 2/4] MI300 check

Summary:

Test Plan: Tested on AMD Instinct MI300X

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchao/quantization/quant_api.py |  5 +++--
 torchao/utils.py                  | 11 +++++++++++
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 4f6908162b..491b7b2286 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -51,6 +51,7 @@
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
+    is_MI300,
 )
 
 from .autoquant import AutoQuantizableLinearWeight, autoquant
@@ -941,7 +942,7 @@ def float8_dynamic_activation_float8_weight(
 
     """
     assert (
-        is_cuda_8_9
+        is_cuda_8_9 or is_MI300
     ), "Float8 dynamic activation quantization is only supported on CUDA 8.9 and above"
     if mm_config is None:
         mm_config = Float8MMConfig(use_fast_accum=True)
@@ -998,7 +999,7 @@ def float8_static_activation_float8_weight(
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
     """
     assert (
-        is_cuda_8_9
+        is_cuda_8_9 or is_MI300
     ), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
     if mm_config is None:
         mm_config = Float8MMConfig(use_fast_accum=True)

diff --git a/torchao/utils.py b/torchao/utils.py
index e474824135..3839943ccf 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -31,6 +31,7 @@
     "TORCH_VERSION_AFTER_2_3",
     "TORCH_VERSION_AFTER_2_4",
     "TORCH_VERSION_AFTER_2_5",
+    "is_MI300",
 ]
 
 
@@ -586,6 +587,16 @@ def _torch_version_at_least(min_version):
     return is_fbcode() or version("torch") >= min_version
 
 
+def is_MI300():
+    if torch.cuda.is_available() and torch.version.hip:
+        mxArchName = ["gfx940", "gfx941", "gfx942"]
+        archName = torch.cuda.get_device_properties().gcnArchName
+        for arch in mxArchName:
+            if arch in archName:
+                return True
+    return False
+
+
 TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev")
 TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev")
 TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev")

From b5945d3ef3a6d07ceca32bf0f682fccf1a51cdce Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Fri, 22 Nov 2024 12:38:09 -0800
Subject: [PATCH 3/4] Test fixes

---
 test/dtypes/test_affine_quantized_float.py | 3 +++
 torchao/quantization/quant_api.py          | 6 +++---
 torchao/utils.py                           | 8 ++++++++
 3 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py
index 1e2ff29796..74c130dc5e 100644
--- a/test/dtypes/test_affine_quantized_float.py
+++ b/test/dtypes/test_affine_quantized_float.py
@@ -134,10 +134,12 @@ def test_fp8_linear_variants(
             compute_error(output_original, output_quantized) > 20
         ), f"Quantization error is too high got a SQNR of {error}"
 
+    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
     def test_invalid_granularity(self):
         with pytest.raises(ValueError, match="Invalid granularity specification"):
             float8_dynamic_activation_float8_weight(granularity="invalid")
 
+    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
     def test_mismatched_granularity(self):
         with pytest.raises(
             ValueError,
@@ -145,6 +147,7 @@ def test_mismatched_granularity(self):
         ):
             float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))
 
+    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
     def test_unsupported_granularity(self):
         class UnsupportedGranularity:
             pass

diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 491b7b2286..a8e9fa493b 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -942,8 +942,8 @@ def float8_dynamic_activation_float8_weight(
 
     """
     assert (
-        is_cuda_8_9 or is_MI300
-    ), "Float8 dynamic activation quantization is only supported on CUDA 8.9 and above"
+        is_cuda_8_9 or is_MI300()
+    ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
     if mm_config is None:
         mm_config = Float8MMConfig(use_fast_accum=True)
@@ -999,7 +999,7 @@ def float8_static_activation_float8_weight(
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
     """
     assert (
-        is_cuda_8_9 or is_MI300
+        is_cuda_8_9 or is_MI300()
     ), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
     if mm_config is None:
         mm_config = Float8MMConfig(use_fast_accum=True)

diff --git a/torchao/utils.py b/torchao/utils.py
index 3839943ccf..d2f2d11d7a 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -597,6 +597,14 @@ def is_MI300():
     return False
 
 
+def is_cuda_8_9():
+    return (
+        torch.cuda.is_available()
+        and torch.version.cuda
+        and torch.cuda.get_device_capability() >= (8, 9)
+    )
+
+
 TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev")
 TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev")
 TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev")

From 436d3aae0a0dd705f203ebde8738482495ac8127 Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Mon, 25 Nov 2024 11:05:02 -0800
Subject: [PATCH 4/4] Granularity validation

---
 torchao/quantization/quant_api.py | 28 ++++++++++++++++++++++------
 torchao/utils.py                  | 12 +++++++++++-
 2 files changed, 33 insertions(+), 7 deletions(-)

diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index a8e9fa493b..c730ec9046 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -52,6 +52,8 @@
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
     is_MI300,
+    is_sm_89,
+    is_sm_90,
 )
 
 from .autoquant import AutoQuantizableLinearWeight, autoquant
@@ -83,7 +85,6 @@
 from .utils import _get_per_token_block_size
 
 logger = logging.getLogger(__name__)
-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 __all__ = [
     "swap_conv2d_1x1_to_linear",
@@ -829,10 +830,11 @@ def _normalize_granularity(
         Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
     ],
 ) -> Tuple[_fp8_granularities, _fp8_granularities]:
+    processed_granularity = None
     if granularity is None:
-        return (PerTensor(), PerTensor())
+        processed_granularity = (PerTensor(), PerTensor())
     elif isinstance(granularity, (PerTensor, PerRow)):
-        return (granularity, granularity)
+        processed_granularity = (granularity, granularity)
     elif isinstance(granularity, tuple) and len(granularity) == 2:
         if not (
             isinstance(granularity[0], (PerTensor, PerRow))
@@ -845,11 +847,25 @@ def _normalize_granularity(
             raise ValueError(
                 f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported."
             )
-        return granularity
+        processed_granularity = granularity
     else:
         raise ValueError(
             f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported."
         )
+    # Validate granularity with supported Hardware
+    for _granularity in processed_granularity:
+        if isinstance(_granularity, PerTensor):
+            assert (
+                is_sm_89() or is_MI300()
+            ), "PerTensor quantization only works for CUDA>=8.9 and MI300+"
+        elif isinstance(_granularity, PerRow):
+            assert (
+                is_sm_90() or is_MI300()
+            ), "PerRow quantization only works for CUDA>=9.0 and MI300+"
+        else:
+            raise ValueError(f"Invalid granularity type: {_granularity}")
+
+    return processed_granularity
 
 
 def _input_activation_quant_func_fp8(
@@ -942,7 +958,7 @@ def float8_dynamic_activation_float8_weight(
 
     """
     assert (
-        is_cuda_8_9 or is_MI300()
+        is_sm_89() or is_MI300()
     ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
     if mm_config is None:
         mm_config = Float8MMConfig(use_fast_accum=True)
@@ -999,7 +1015,7 @@ def float8_static_activation_float8_weight(
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
     """
     assert (
-        is_cuda_8_9 or is_MI300()
+        is_sm_89() or is_MI300()
     ), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
     if mm_config is None:
         mm_config = Float8MMConfig(use_fast_accum=True)

diff --git a/torchao/utils.py b/torchao/utils.py
index d2f2d11d7a..2813f0b0b4 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -32,6 +32,8 @@
     "TORCH_VERSION_AFTER_2_4",
     "TORCH_VERSION_AFTER_2_5",
     "is_MI300",
+    "is_sm_89",
+    "is_sm_90",
 ]
 
 
@@ -597,7 +599,7 @@ def is_MI300():
     return False
 
 
-def is_cuda_8_9():
+def is_sm_89():
     return (
         torch.cuda.is_available()
         and torch.version.cuda
@@ -605,6 +607,14 @@
     )
 
 
+def is_sm_90():
+    return (
+        torch.cuda.is_available()
+        and torch.version.cuda
+        and torch.cuda.get_device_capability() >= (9, 0)
+    )
+
+
 TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev")
 TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev")
 TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev")
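A minimal usage sketch of the hardware gates this series adds, assuming the
torchao layout these diffs touch: the quantize_ entry point, and PerTensor /
PerRow living in torchao.quantization.granularity. Import paths may differ in
other releases, so treat this as an illustration rather than a reference.

    import torch
    import torch.nn as nn

    from torchao.quantization import quantize_
    from torchao.quantization.granularity import PerRow, PerTensor
    from torchao.quantization.quant_api import float8_dynamic_activation_float8_weight
    from torchao.utils import is_MI300, is_sm_89, is_sm_90

    # A bf16 module on GPU; the float8 path is typically run with bf16 weights.
    model = nn.Sequential(nn.Linear(1024, 1024)).to("cuda", torch.bfloat16)

    # Mirror the checks _normalize_granularity now enforces:
    # PerRow needs SM 9.0+ (e.g. H100) or MI300; PerTensor needs SM 8.9+ or MI300.
    if is_sm_90() or is_MI300():
        quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))
    elif is_sm_89():
        quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
    else:
        print("float8 quantization unsupported on this GPU; keeping bf16 weights")

With a guard like this the asserts introduced above never fire: the caller
only requests a granularity the device supports, and older GPUs fall through
with the model left unquantized.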