feat(//py/torch_tensorrt/dynamo): Support for BF16
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
narendasan committed May 14, 2024
1 parent 39b6818 commit c6cf790
Showing 7 changed files with 109 additions and 11 deletions.
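
Before the per-file diff: the change threads BF16 through the C++ type maps, the Python dtype enum, the TensorRT builder configuration, and the to_numpy helper. A minimal usage sketch, mirroring the new tests at the bottom of this diff (the module and input shape are illustrative, not part of the commit):

import torch
import torch_tensorrt

# Any small CUDA module cast to bfloat16 serves for illustration.
model = (
    torch.nn.Sequential(
        torch.nn.Conv2d(3, 16, 3, stride=1, bias=True),
        torch.nn.ReLU(),
    )
    .to("cuda")
    .to(torch.bfloat16)
)
x = torch.randn((1, 3, 224, 224), device="cuda", dtype=torch.bfloat16)

exported = torch.export.export(model, (x,))
trt_model = torch_tensorrt.dynamo.compile(
    exported,
    inputs=[x],
    # torch.bfloat16 is now accepted here and maps to the new BF16 enum value.
    enabled_precisions={torch.float, torch.half, torch.bfloat16},
    min_block_size=1,
)
print(trt_model(x).dtype)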
5 changes: 3 additions & 2 deletions core/util/trt_util.cpp
@@ -295,7 +295,8 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
{at::kLong, nvinfer1::DataType::kINT64},
{at::kChar, nvinfer1::DataType::kINT8},
{at::kByte, nvinfer1::DataType::kINT8},
{at::kBool, nvinfer1::DataType::kBOOL}};
{at::kBool, nvinfer1::DataType::kBOOL},
{at::kBFloat16, nvinfer1::DataType::kBF16}};
return at_trt_type_map;
}

@@ -307,7 +308,7 @@ const std::unordered_map<nvinfer1::DataType, at::ScalarType>& get_trt_at_type_ma
{nvinfer1::DataType::kINT64, at::kLong},
{nvinfer1::DataType::kINT8, at::kChar},
{nvinfer1::DataType::kBOOL, at::kBool},
};
{nvinfer1::DataType::kBF16, at::kBFloat16}};
return trt_at_type_map;
}
} // namespace
2 changes: 2 additions & 0 deletions core/util/trt_util.h
@@ -55,6 +55,8 @@ inline std::ostream& operator<<(std::ostream& stream, const nvinfer1::DataType&
return stream << "Int32";
case nvinfer1::DataType::kINT64:
return stream << "Int64";
case nvinfer1::DataType::kBF16:
return stream << "BFloat16";
case nvinfer1::DataType::kBOOL:
return stream << "Bool";
default:
21 changes: 14 additions & 7 deletions py/torch_tensorrt/_enums.py
@@ -24,9 +24,9 @@ class dtype(Enum):
f32 = auto()
f64 = auto()
b = auto()
# TODO: Enable FP8 and BF16
bf16 = auto()
# TODO: Enable FP8
# f8 = auto()
# bf16 = auto()

uint8 = u8
int8 = i8
@@ -52,8 +52,7 @@ class dtype(Enum):
# float8 = f8
# fp8 = f8

# TODO: Enable when BF16 is enabled
# bfloat16 = bf16
bfloat16 = bf16

@staticmethod
def _is_np_obj(t: Any) -> bool:
@@ -88,14 +87,16 @@ def _from(
return dtype.f64
elif t == torch.bool:
return dtype.b
elif t == torch.bfloat16:
return dtype.bf16
elif use_default:
logging.warning(
f"Given dtype that does not have direct mapping to Torch-TensorRT supported types ({t}), defaulting to torch_tensorrt.dtype.float"
)
return dtype.float
else:
raise TypeError(
f"Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: {t}"
f"Provided an unsupported data type as a data type for translation (support: bool, int, long, half, float, bfloat16), got: {t}"
)
elif isinstance(t, trt.DataType):
if t == trt.uint8:
@@ -112,9 +113,11 @@ def _from(
return dtype.f32
elif t == trt.bool:
return dtype.b
elif t == trt.bf16:
return dtype.bf16
else:
raise TypeError(
f"Provided an unsupported data type as an input data type (support: bool, int32, half, float), got: {t}"
f"Provided an unsupported data type as a data type for translation (support: bool, int, half, float, bfloat16), got: {t}"
)

elif dtype._is_np_obj(t):
@@ -141,7 +144,7 @@ def _from(
return dtype.float
else:
raise TypeError(
"Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: "
"Provided an unsupported data type as an input data type (support: bool, int, long, half, float, bfloat16), got: "
+ str(t)
)

@@ -215,6 +218,8 @@ def to(
return torch.double
elif self == dtype.b:
return torch.bool
elif self == dtype.bf16:
return torch.bfloat16
elif use_default:
logging.warning(
f"Given dtype that does not have direct mapping to torch ({self}), defaulting to torch.float"
@@ -238,6 +243,8 @@ def to(
return trt.DataType.FLOAT
elif self == dtype.b:
return trt.DataType.BOOL
elif self == dtype.bf16:
return trt.DataType.BF16
elif use_default:
return trt.DataType.FLOAT
else:
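
With the enum changes above, bfloat16 now round-trips between torch, TensorRT, and torch_tensorrt. A short sketch exercising _from and to as defined in this file (calling to() with the target type class is an assumption about its signature, based on how the enum is used elsewhere):

import torch
import tensorrt as trt
from torch_tensorrt import dtype

d = dtype._from(torch.bfloat16)   # -> dtype.bf16, also exposed as dtype.bfloat16
assert d == dtype.bfloat16
assert d.to(torch.dtype) == torch.bfloat16       # back to a torch dtype
assert d.to(trt.DataType) == trt.DataType.BF16   # and to the TensorRT enum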
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/_defaults.py
@@ -26,7 +26,7 @@
REQUIRE_FULL_COMPILATION = False
DRYRUN = False
HARDWARE_COMPATIBLE = False
SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.i8}
SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.i8, dtype.bf16}


def default_device() -> Device:
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -242,6 +242,9 @@ def _populate_trt_builder_config(
if dtype.int8 in self.compilation_settings.enabled_precisions:
builder_config.set_flag(trt.BuilderFlag.INT8)

if dtype.bfloat16 in self.compilation_settings.enabled_precisions:
builder_config.set_flag(trt.BuilderFlag.BF16)

if self.compilation_settings.sparse_weights:
builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)

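
The interpreter change maps the new precision onto TensorRT's builder flag. For reference, the bare TensorRT equivalent of what this flag enables looks roughly like this (a standalone sketch, not code from the commit):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
config = builder.create_builder_config()
# What _populate_trt_builder_config now does when torch.bfloat16
# is present in enabled_precisions:
config.set_flag(trt.BuilderFlag.BF16)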
5 changes: 4 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -4,7 +4,6 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload

import numpy as np
import tensorrt as trt
import torch
import torch_tensorrt.dynamo.conversion.impl as impl
from torch import SymBool, SymFloat, SymInt
@@ -22,6 +21,8 @@
)
from torch_tensorrt.fx.types import TRTDataType, TRTTensor

import tensorrt as trt

_LOGGER: logging.Logger = logging.getLogger(__name__)


@@ -545,6 +546,8 @@ def to_numpy(
elif isinstance(value, torch.Tensor):
if value.is_quantized:
value = value.dequantize()
elif value.dtype == torch.bfloat16:
value = value.to(torch.float)

output = value.cpu().detach().contiguous().numpy()

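
The to_numpy change is needed because NumPy has no bfloat16 dtype, so BF16 weights are upcast to float32 before the .numpy() call. A small illustration of the failure mode this avoids (the snippet is illustrative, not from the commit):

import torch

w = torch.randn(4, dtype=torch.bfloat16)
# w.numpy() raises a TypeError since NumPy does not support bfloat16;
# upcasting first, as to_numpy now does, avoids it.
arr = w.to(torch.float).cpu().detach().contiguous().numpy()
print(arr.dtype)  # float32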
82 changes: 82 additions & 0 deletions tests/py/dynamo/models/test_dtype_support.py
@@ -176,3 +176,85 @@ def forward(self, x):
DECIMALS_OF_AGREEMENT,
msg=f"Torch outputs and TRT outputs don't match close enough.",
)


class TestBF16Support(TestCase):
@unittest.skipIf(
not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime,
"Torch-TensorRT Runtime is not available",
)
def test_bf16_cpp(self):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
self.relu = torch.nn.ReLU()

def forward(self, x):
out = self.conv(x)
out = self.relu(out)
return out

in_tensor = torch.randn((1, 3, 224, 224), device="cuda", dtype=torch.bfloat16)
mod = MyModule().to(torch.device("cuda")).to(torch.bfloat16)

exp_mod = torch.export.export(mod, (in_tensor,))
trt_mod = torch_tensorrt.dynamo.compile(
exp_mod,
inputs=[in_tensor],
pass_through_build_failures=True,
enabled_precisions={torch.float, torch.bfloat16, torch.half},
min_block_size=1,
use_python_runtime=False,
)

torch_model_results = mod(in_tensor)
optimized_model_results = trt_mod(in_tensor)

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
msg=f"Torch outputs and TRT outputs don't match close enough.",
)

def test_bf16_py(self):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
self.relu = torch.nn.ReLU()

def forward(self, x):
out = self.conv(x)
out = self.relu(out)
return out

in_tensor = torch.randn((1, 3, 224, 224), device="cuda", dtype=torch.bfloat16)
mod = MyModule().to(torch.device("cuda")).to(torch.bfloat16)

exp_mod = torch.export.export(mod, (in_tensor,))
trt_mod = torch_tensorrt.dynamo.compile(
exp_mod,
inputs=[in_tensor],
pass_through_build_failures=True,
enabled_precisions={torch.float, torch.bfloat16, torch.half},
min_block_size=1,
use_python_runtime=True,
)

torch_model_results = mod(in_tensor)
optimized_model_results = trt_mod(in_tensor)

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
msg=f"Torch outputs and TRT outputs don't match close enough.",
)
