diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py
index 1e2ff29796..74c130dc5e 100644
--- a/test/dtypes/test_affine_quantized_float.py
+++ b/test/dtypes/test_affine_quantized_float.py
@@ -134,10 +134,12 @@ def test_fp8_linear_variants(
             compute_error(output_original, output_quantized) > 20
         ), f"Quantization error is too high got a SQNR of {error}"
 
+    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
     def test_invalid_granularity(self):
         with pytest.raises(ValueError, match="Invalid granularity specification"):
             float8_dynamic_activation_float8_weight(granularity="invalid")
 
+    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
     def test_mismatched_granularity(self):
         with pytest.raises(
             ValueError,
@@ -145,6 +147,7 @@ def test_mismatched_granularity(self):
         ):
             float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))
 
+    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
     def test_unsupported_granularity(self):
         class UnsupportedGranularity:
             pass
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 0e6ebdc7e0..c730ec9046 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -51,6 +51,9 @@
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
+    is_MI300,
+    is_sm_89,
+    is_sm_90,
 )
 
 from .autoquant import AutoQuantizableLinearWeight, autoquant
@@ -827,10 +830,11 @@ def _normalize_granularity(
         Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
     ],
 ) -> Tuple[_fp8_granularities, _fp8_granularities]:
+    processed_granularity = None
     if granularity is None:
-        return (PerTensor(), PerTensor())
+        processed_granularity = (PerTensor(), PerTensor())
     elif isinstance(granularity, (PerTensor, PerRow)):
-        return (granularity, granularity)
+        processed_granularity = (granularity, granularity)
     elif isinstance(granularity, tuple) and len(granularity) == 2:
         if not (
             isinstance(granularity[0], (PerTensor, PerRow))
@@ -843,11 +847,25 @@ def _normalize_granularity(
             raise ValueError(
                 f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported."
             )
-        return granularity
+        processed_granularity = granularity
     else:
         raise ValueError(
             f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported."
         )
+    # Validate granularity against supported hardware
+    for _granularity in processed_granularity:
+        if isinstance(_granularity, PerTensor):
+            assert (
+                is_sm_89() or is_MI300()
+            ), "PerTensor quantization only works for CUDA>=8.9 and MI300+"
+        elif isinstance(_granularity, PerRow):
+            assert (
+                is_sm_90() or is_MI300()
+            ), "PerRow quantization only works for CUDA>=9.0 and MI300+"
+        else:
+            raise ValueError(f"Invalid granularity type: {_granularity}")
+
+    return processed_granularity
 
 
 def _input_activation_quant_func_fp8(
@@ -939,6 +957,9 @@ def float8_dynamic_activation_float8_weight(
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
 
     """
+    assert (
+        is_sm_89() or is_MI300()
+    ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
     if mm_config is None:
         mm_config = Float8MMConfig(use_fast_accum=True)
 
@@ -993,6 +1014,9 @@ def float8_static_activation_float8_weight(
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
     """
+    assert (
+        is_sm_89() or is_MI300()
+    ), "Float8 static activation quantization is only supported on CUDA>=8.9 and MI300+"
     if mm_config is None:
         mm_config = Float8MMConfig(use_fast_accum=True)
 
diff --git a/torchao/utils.py b/torchao/utils.py
index e474824135..2813f0b0b4 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -31,6 +31,9 @@
     "TORCH_VERSION_AFTER_2_3",
     "TORCH_VERSION_AFTER_2_4",
     "TORCH_VERSION_AFTER_2_5",
+    "is_MI300",
+    "is_sm_89",
+    "is_sm_90",
 ]
 
 
@@ -586,6 +589,32 @@ def _torch_version_at_least(min_version):
     return is_fbcode() or version("torch") >= min_version
 
 
+def is_MI300():
+    if torch.cuda.is_available() and torch.version.hip:
+        mxArchName = ["gfx940", "gfx941", "gfx942"]
+        archName = torch.cuda.get_device_properties().gcnArchName
+        for arch in mxArchName:
+            if arch in archName:
+                return True
+    return False
+
+
+def is_sm_89():
+    return (
+        torch.cuda.is_available()
+        and torch.version.cuda
+        and torch.cuda.get_device_capability() >= (8, 9)
+    )
+
+
+def is_sm_90():
+    return (
+        torch.cuda.is_available()
+        and torch.version.cuda
+        and torch.cuda.get_device_capability() >= (9, 0)
+    )
+
+
 TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev")
 TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev")
 TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev")
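
As a quick sanity check of the new hardware gates, the helpers this diff adds to torchao/utils.py can be probed directly on a target machine. A minimal sketch (not part of the patch; it assumes only the is_MI300/is_sm_89/is_sm_90 helpers defined above and standard torch APIs):

import torch

from torchao.utils import is_MI300, is_sm_89, is_sm_90

# Mirrors the asserts in _normalize_granularity: PerTensor float8 needs
# SM >= 8.9 (Ada or newer) or an MI300-class ROCm GPU; PerRow needs
# SM >= 9.0 (Hopper) on the CUDA side, or MI300.
print("CUDA available:", torch.cuda.is_available())
print("PerTensor float8 supported:", is_sm_89() or is_MI300())
print("PerRow float8 supported:", is_sm_90() or is_MI300())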