feat(//py/torch_tensorrt/dynamo): Support for BF16 #2833

Merged: 1 commit, May 16, 2024
5 changes: 3 additions & 2 deletions core/util/trt_util.cpp
@@ -295,7 +295,8 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_map(
{at::kLong, nvinfer1::DataType::kINT64},
{at::kChar, nvinfer1::DataType::kINT8},
{at::kByte, nvinfer1::DataType::kINT8},
-{at::kBool, nvinfer1::DataType::kBOOL}};
+{at::kBool, nvinfer1::DataType::kBOOL},
+{at::kBFloat16, nvinfer1::DataType::kBF16}};
return at_trt_type_map;
}

@@ -307,7 +308,7 @@ const std::unordered_map<nvinfer1::DataType, at::ScalarType>& get_trt_at_type_map(
{nvinfer1::DataType::kINT64, at::kLong},
{nvinfer1::DataType::kINT8, at::kChar},
{nvinfer1::DataType::kBOOL, at::kBool},
-};
+{nvinfer1::DataType::kBF16, at::kBFloat16}};
return trt_at_type_map;
}
} // namespace
2 changes: 2 additions & 0 deletions core/util/trt_util.h
@@ -55,6 +55,8 @@ inline std::ostream& operator<<(std::ostream& stream, const nvinfer1::DataType&
return stream << "Int32";
case nvinfer1::DataType::kINT64:
return stream << "Int64";
+case nvinfer1::DataType::kBF16:
+return stream << "BFloat16";
case nvinfer1::DataType::kBOOL:
return stream << "Bool";
default:
21 changes: 14 additions & 7 deletions py/torch_tensorrt/_enums.py
@@ -24,9 +24,9 @@ class dtype(Enum):
f32 = auto()
f64 = auto()
b = auto()
-# TODO: Enable FP8 and BF16
+bf16 = auto()
+# TODO: Enable FP8
# f8 = auto()
-# bf16 = auto()

uint8 = u8
int8 = i8
@@ -52,8 +52,7 @@ class dtype(Enum):
# float8 = f8
# fp8 = f8

-# TODO: Enable when BF16 is enabled
-# bfloat16 = bf16
+bfloat16 = bf16

@staticmethod
def _is_np_obj(t: Any) -> bool:
@@ -88,14 +87,16 @@ def _from(
return dtype.f64
elif t == torch.bool:
return dtype.b
+elif t == torch.bfloat16:
+return dtype.bf16
elif use_default:
logging.warning(
f"Given dtype that does not have direct mapping to Torch-TensorRT supported types ({t}), defaulting to torch_tensorrt.dtype.float"
)
return dtype.float
else:
raise TypeError(
f"Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: {t}"
f"Provided an unsupported data type as a data type for translation (support: bool, int, long, half, float, bfloat16), got: {t}"
)
elif isinstance(t, trt.DataType):
if t == trt.uint8:
Expand All @@ -112,9 +113,11 @@ def _from(
return dtype.f32
elif t == trt.bool:
return dtype.b
+elif t == trt.bf16:
Review comment (Contributor), on the line above: the tensorrt module doesn't have a bf16 attribute:
Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorrt
>>> tensorrt.__version__
'10.0.1'
>>> hasattr(tensorrt, "bf16")
False
>>> hasattr(tensorrt, "bfloat16")
True

+return dtype.bf16
else:
raise TypeError(
f"Provided an unsupported data type as an input data type (support: bool, int32, half, float), got: {t}"
f"Provided an unsupported data type as a data type for translation (support: bool, int, half, float, bfloat16), got: {t}"
)

elif dtype._is_np_obj(t):
@@ -141,7 +144,7 @@
return dtype.float
else:
raise TypeError(
"Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: "
"Provided an unsupported data type as an input data type (support: bool, int, long, half, float, bfloat16), got: "
+ str(t)
)

@@ -215,6 +218,8 @@ def to(
return torch.double
elif self == dtype.b:
return torch.bool
+elif self == dtype.bf16:
+return torch.bfloat16
elif use_default:
logging.warning(
f"Given dtype that does not have direct mapping to torch ({self}), defaulting to torch.float"
@@ -238,6 +243,8 @@
return trt.DataType.FLOAT
elif self == dtype.b:
return trt.DataType.BOOL
+elif self == dtype.bf16:
+return trt.DataType.BF16
elif use_default:
return trt.DataType.FLOAT
else:
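Note on the review comment above: since recent tensorrt wheels expose bfloat16 but not bf16 as a module attribute, the `t == trt.bf16` comparison could raise AttributeError. A minimal defensive sketch (hypothetical, not code from this PR) that resolves the enum under either name:

import tensorrt as trt

# Resolve the BF16 enum under either public name; the reviewer's REPL above
# shows TensorRT 10.0.1 has `bfloat16` but not `bf16`.
TRT_BF16 = getattr(trt, "bf16", None)
if TRT_BF16 is None:
    TRT_BF16 = trt.bfloat16

# ...then compare `t == TRT_BF16` instead of `t == trt.bf16`.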
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/_defaults.py
@@ -26,7 +26,7 @@
REQUIRE_FULL_COMPILATION = False
DRYRUN = False
HARDWARE_COMPATIBLE = False
-SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.i8}
+SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.i8, dtype.bf16}


def default_device() -> Device:
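With dtype.bf16 added to SUPPORTED_KERNEL_PRECISIONS, a user-supplied enabled_precisions set containing torch.bfloat16 now counts as supported. An illustrative sketch of the effect (simplified; validate_precisions is not the library's actual validation code):

import torch
from torch_tensorrt._enums import dtype

SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.i8, dtype.bf16}

def validate_precisions(enabled_precisions):
    # Normalize torch/TensorRT/Python types to torch_tensorrt.dtype first.
    normalized = {dtype._from(p) for p in enabled_precisions}
    unsupported = normalized - SUPPORTED_KERNEL_PRECISIONS
    if unsupported:
        raise ValueError(f"Unsupported kernel precisions: {unsupported}")

validate_precisions({torch.float, torch.bfloat16})  # passes after this change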
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -242,6 +242,9 @@ def _populate_trt_builder_config(
if dtype.int8 in self.compilation_settings.enabled_precisions:
builder_config.set_flag(trt.BuilderFlag.INT8)

+if dtype.bfloat16 in self.compilation_settings.enabled_precisions:
+builder_config.set_flag(trt.BuilderFlag.BF16)

if self.compilation_settings.sparse_weights:
builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)

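For reference, trt.BuilderFlag.BF16 is the standard TensorRT enum (available since TensorRT 9.0), and the same flag can be set on a raw builder config outside the interpreter. A minimal standalone sketch:

import tensorrt as trt

# Build an empty config and opt BF16 kernels into tactic selection,
# mirroring the interpreter change above.
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.BF16)

assert config.get_flag(trt.BuilderFlag.BF16)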
6 changes: 5 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -4,7 +4,6 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload

import numpy as np
-import tensorrt as trt
import torch
import torch_tensorrt.dynamo.conversion.impl as impl
from torch import SymBool, SymFloat, SymInt
@@ -22,6 +21,8 @@
)
from torch_tensorrt.fx.types import TRTDataType, TRTTensor

+import tensorrt as trt

_LOGGER: logging.Logger = logging.getLogger(__name__)


@@ -545,6 +546,9 @@ def to_numpy(
elif isinstance(value, torch.Tensor):
if value.is_quantized:
value = value.dequantize()
+elif value.dtype == torch.bfloat16:
+# TODO: Remove when numpy has a BF16 type
+value = value.to(torch.float)

output = value.cpu().detach().contiguous().numpy()

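The to_numpy change works around the fact that NumPy has no native bfloat16 dtype, so Tensor.numpy() raises a TypeError on BF16 tensors. A quick demonstration of the failure and the upcast (float32 can represent every bfloat16 value exactly, so the conversion is lossless):

import torch

t = torch.randn(4, dtype=torch.bfloat16)
try:
    t.numpy()                    # raises: NumPy has no bfloat16 dtype
except TypeError as e:
    print(f"direct conversion fails: {e}")

arr = t.to(torch.float).numpy()  # upcast first, as to_numpy now does
print(arr.dtype)                 # float32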
82 changes: 82 additions & 0 deletions tests/py/dynamo/models/test_dtype_support.py
@@ -176,3 +176,85 @@ def forward(self, x):
DECIMALS_OF_AGREEMENT,
msg=f"Torch outputs and TRT outputs don't match close enough.",
)


class TestBF16Support(TestCase):
@unittest.skipIf(
not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime,
"Torch-TensorRT Runtime is not available",
)
def test_bf16_cpp(self):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
self.relu = torch.nn.ReLU()

def forward(self, x):
out = self.conv(x)
out = self.relu(out)
return out

in_tensor = torch.randn((1, 3, 224, 224), device="cuda", dtype=torch.bfloat16)
mod = MyModule().to(torch.device("cuda")).to(torch.bfloat16)

exp_mod = torch.export.export(mod, (in_tensor,))
trt_mod = torch_tensorrt.dynamo.compile(
exp_mod,
inputs=[in_tensor],
pass_through_build_failures=True,
enabled_precisions={torch.float, torch.bfloat16, torch.half},
min_block_size=1,
use_python_runtime=False,
)

torch_model_results = mod(in_tensor)
optimized_model_results = trt_mod(in_tensor)

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
msg=f"Torch outputs and TRT outputs don't match close enough.",
)

def test_bf16_py(self):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
self.relu = torch.nn.ReLU()

def forward(self, x):
out = self.conv(x)
out = self.relu(out)
return out

in_tensor = torch.randn((1, 3, 224, 224), device="cuda", dtype=torch.bfloat16)
mod = MyModule().to(torch.device("cuda")).to(torch.bfloat16)

exp_mod = torch.export.export(mod, (in_tensor,))
trt_mod = torch_tensorrt.dynamo.compile(
exp_mod,
inputs=[in_tensor],
pass_through_build_failures=True,
enabled_precisions={torch.float, torch.bfloat16, torch.half},
min_block_size=1,
use_python_runtime=True,
)

torch_model_results = mod(in_tensor)
optimized_model_results = trt_mod(in_tensor)

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
msg=f"Torch outputs and TRT outputs don't match close enough.",
)
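For completeness, the new precision can also be exercised through the public torch_tensorrt.compile entry point rather than torch_tensorrt.dynamo.compile. A minimal usage sketch (assumes a CUDA device; native BF16 kernels require an Ampere-or-newer GPU):

import torch
import torch_tensorrt

model = (
    torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), torch.nn.ReLU())
    .cuda()
    .to(torch.bfloat16)
    .eval()
)
x = torch.randn(1, 3, 224, 224, device="cuda", dtype=torch.bfloat16)

# enabled_precisions mirrors the tests above; min_block_size=1 forces even
# this tiny graph through TensorRT.
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=[x],
    enabled_precisions={torch.float, torch.bfloat16},
    min_block_size=1,
)
print(trt_model(x).dtype)  # expected: torch.bfloat16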