diff --git a/core/util/trt_util.cpp b/core/util/trt_util.cpp index 3ca5780603..13acc546f5 100644 --- a/core/util/trt_util.cpp +++ b/core/util/trt_util.cpp @@ -295,7 +295,8 @@ const std::unordered_map& get_at_trt_type_ma {at::kLong, nvinfer1::DataType::kINT64}, {at::kChar, nvinfer1::DataType::kINT8}, {at::kByte, nvinfer1::DataType::kINT8}, - {at::kBool, nvinfer1::DataType::kBOOL}}; + {at::kBool, nvinfer1::DataType::kBOOL}, + {at::kBFloat16, nvinfer1::DataType::kBF16}}; return at_trt_type_map; } @@ -307,7 +308,7 @@ const std::unordered_map& get_trt_at_type_ma {nvinfer1::DataType::kINT64, at::kLong}, {nvinfer1::DataType::kINT8, at::kChar}, {nvinfer1::DataType::kBOOL, at::kBool}, - }; + {nvinfer1::DataType::kBF16, at::kBFloat16}}; return trt_at_type_map; } } // namespace diff --git a/core/util/trt_util.h b/core/util/trt_util.h index da6653bef3..f3df533d8b 100644 --- a/core/util/trt_util.h +++ b/core/util/trt_util.h @@ -55,6 +55,8 @@ inline std::ostream& operator<<(std::ostream& stream, const nvinfer1::DataType& return stream << "Int32"; case nvinfer1::DataType::kINT64: return stream << "Int64"; + case nvinfer1::DataType::kBF16: + return stream << "BFloat16"; case nvinfer1::DataType::kBOOL: return stream << "Bool"; default: diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py index 8eeb55cf36..32587a38f2 100644 --- a/py/torch_tensorrt/_enums.py +++ b/py/torch_tensorrt/_enums.py @@ -24,9 +24,9 @@ class dtype(Enum): f32 = auto() f64 = auto() b = auto() - # TODO: Enable FP8 and BF16 + bf16 = auto() + # TODO: Enable FP8 # f8 = auto() - # bf16 = auto() uint8 = u8 int8 = i8 @@ -52,8 +52,7 @@ class dtype(Enum): # float8 = f8 # fp8 = f8 - # TODO: Enable when BF16 is enabled - # bfloat16 = bf16 + bfloat16 = bf16 @staticmethod def _is_np_obj(t: Any) -> bool: @@ -88,6 +87,8 @@ def _from( return dtype.f64 elif t == torch.bool: return dtype.b + elif t == torch.bfloat16: + return dtype.bf16 elif use_default: logging.warning( f"Given dtype that does not have direct mapping to Torch-TensorRT supported types ({t}), defaulting to torch_tensorrt.dtype.float" @@ -95,7 +96,7 @@ def _from( return dtype.float else: raise TypeError( - f"Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: {t}" + f"Provided an unsupported data type as a data type for translation (support: bool, int, long, half, float, bfloat16), got: {t}" ) elif isinstance(t, trt.DataType): if t == trt.uint8: @@ -112,9 +113,11 @@ def _from( return dtype.f32 elif t == trt.bool: return dtype.b + elif t == trt.bf16: + return dtype.bf16 else: raise TypeError( - f"Provided an unsupported data type as an input data type (support: bool, int32, half, float), got: {t}" + f"Provided an unsupported data type as a data type for translation (support: bool, int, half, float, bfloat16), got: {t}" ) elif dtype._is_np_obj(t): @@ -141,7 +144,7 @@ def _from( return dtype.float else: raise TypeError( - "Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: " + "Provided an unsupported data type as an input data type (support: bool, int, long, half, float, bfloat16), got: " + str(t) ) @@ -215,6 +218,8 @@ def to( return torch.double elif self == dtype.b: return torch.bool + elif self == dtype.bf16: + return torch.bfloat16 elif use_default: logging.warning( f"Given dtype that does not have direct mapping to torch ({self}), defaulting to torch.float" @@ -238,6 +243,8 @@ def to( return trt.DataType.FLOAT elif self == dtype.b: return trt.DataType.BOOL + elif self == dtype.bf16: + return trt.DataType.BF16 elif use_default: return trt.DataType.FLOAT else: diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 97430137c0..7931dc865c 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -26,7 +26,7 @@ REQUIRE_FULL_COMPILATION = False DRYRUN = False HARDWARE_COMPATIBLE = False -SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.i8} +SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.i8, dtype.bf16} def default_device() -> Device: diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 59d2c5d6c0..56e7e069c5 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -242,6 +242,9 @@ def _populate_trt_builder_config( if dtype.int8 in self.compilation_settings.enabled_precisions: builder_config.set_flag(trt.BuilderFlag.INT8) + if dtype.bfloat16 in self.compilation_settings.enabled_precisions: + builder_config.set_flag(trt.BuilderFlag.BF16) + if self.compilation_settings.sparse_weights: builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index a24a203cab..1c03f1d924 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -4,7 +4,6 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload import numpy as np -import tensorrt as trt import torch import torch_tensorrt.dynamo.conversion.impl as impl from torch import SymBool, SymFloat, SymInt @@ -22,6 +21,8 @@ ) from torch_tensorrt.fx.types import TRTDataType, TRTTensor +import tensorrt as trt + _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -545,6 +546,9 @@ def to_numpy( elif isinstance(value, torch.Tensor): if value.is_quantized: value = value.dequantize() + elif value.dtype == torch.bfloat16: + # TODO: Remove when numpy has a BF16 type + value = value.to(torch.float) output = value.cpu().detach().contiguous().numpy() diff --git a/tests/py/dynamo/models/test_dtype_support.py b/tests/py/dynamo/models/test_dtype_support.py index e88c85de75..cc576715ff 100644 --- a/tests/py/dynamo/models/test_dtype_support.py +++ b/tests/py/dynamo/models/test_dtype_support.py @@ -176,3 +176,85 @@ def forward(self, x): DECIMALS_OF_AGREEMENT, msg=f"Torch outputs and TRT outputs don't match close enough.", ) + + +class TestBF16Support(TestCase): + @unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, + "Torch-TensorRT Runtime is not available", + ) + def test_bf16_cpp(self): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) + self.relu = torch.nn.ReLU() + + def forward(self, x): + out = self.conv(x) + out = self.relu(out) + return out + + in_tensor = torch.randn((1, 3, 224, 224), device="cuda", dtype=torch.bfloat16) + mod = MyModule().to(torch.device("cuda")).to(torch.bfloat16) + + exp_mod = torch.export.export(mod, (in_tensor,)) + trt_mod = torch_tensorrt.dynamo.compile( + exp_mod, + inputs=[in_tensor], + pass_through_build_failures=True, + enabled_precisions={torch.float, torch.bfloat16, torch.half}, + min_block_size=1, + use_python_runtime=False, + ) + + torch_model_results = mod(in_tensor) + optimized_model_results = trt_mod(in_tensor) + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"Torch outputs and TRT outputs don't match close enough.", + ) + + def test_bf16_py(self): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) + self.relu = torch.nn.ReLU() + + def forward(self, x): + out = self.conv(x) + out = self.relu(out) + return out + + in_tensor = torch.randn((1, 3, 224, 224), device="cuda", dtype=torch.bfloat16) + mod = MyModule().to(torch.device("cuda")).to(torch.bfloat16) + + exp_mod = torch.export.export(mod, (in_tensor,)) + trt_mod = torch_tensorrt.dynamo.compile( + exp_mod, + inputs=[in_tensor], + pass_through_build_failures=True, + enabled_precisions={torch.float, torch.bfloat16, torch.half}, + min_block_size=1, + use_python_runtime=True, + ) + + torch_model_results = mod(in_tensor) + optimized_model_results = trt_mod(in_tensor) + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"Torch outputs and TRT outputs don't match close enough.", + )