Merge pull request #1450 from mfeliz-cruise/michael.feliz/improve_batchnorm_fp16_accuracy

Improve batch_norm fp16 accuracy
peri044 authored Feb 13, 2023
2 parents 3467511 + 0b0666a commit 3bfc052
Showing 2 changed files with 48 additions and 5 deletions.
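
For context (restating what the diff below already encodes): the converter folds batch normalization into a single TensorRT scale layer whose weights are precomputed from gamma, beta, the running mean, and the running variance:

$$\text{scale} = \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}, \qquad \text{bias} = \beta - \mu \cdot \text{scale}$$

When the module weights are fp16, doing this fold directly in half precision can lose accuracy in the intermediate division and square root. The change below carries out the fold in fp32 and casts the resulting scale and bias back to the original dtype before embedding them as engine weights; a short numerical illustration follows the first file's diff.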
23 changes: 18 additions & 5 deletions core/conversion/converters/impl/batch_norm.cpp
@@ -20,15 +20,28 @@ void _batch_norm(
     const torch::Tensor& mean,
     const torch::Tensor& var,
     const float eps) {
-  auto scale = gamma / torch::sqrt(var + eps);
-  auto bias = beta - mean * scale;
+  auto orig_dtype = var.dtype();
+  // perform compile-time weight calculations in float to improve accuracy
+  // resulting weights will be embedded as the original dtype
+  auto calculation_gamma = gamma;
+  auto calculation_beta = beta;
+  auto calculation_mean = mean;
+  auto calculation_var = var;
+  if (orig_dtype == torch::kHalf) {
+    calculation_gamma = calculation_gamma.to(torch::kFloat);
+    calculation_beta = calculation_beta.to(torch::kFloat);
+    calculation_mean = calculation_mean.to(torch::kFloat);
+    calculation_var = calculation_var.to(torch::kFloat);
+  }
+  auto scale = calculation_gamma / torch::sqrt(calculation_var + eps);
+  auto bias = calculation_beta - calculation_mean * scale;
   LOG_DEBUG("_batch_norm Tensor Scale : " << scale.sizes());
   LOG_DEBUG("_batch_norm Tensor bias : " << bias.sizes());
 
-  auto scale_weights = Weights(ctx, scale);
-  auto bias_weights = Weights(ctx, bias);
+  auto scale_weights = Weights(ctx, scale.to(orig_dtype));
+  auto bias_weights = Weights(ctx, bias.to(orig_dtype));
 
-  auto power = Weights(ctx, at::ones_like(scale));
+  auto power = Weights(ctx, at::ones_like(scale).to(orig_dtype));
   auto bn =
       ctx->net->addScaleNd(*input, nvinfer1::ScaleMode::kCHANNEL, bias_weights.data, scale_weights.data, power.data, 1);
   bn->setName(util::node_info(n).c_str());
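Below is a minimal, hypothetical standalone sketch (not part of this commit) of the numerical effect the fold above targets. It assumes libtorch with a CUDA device, and the tensor values are made up for illustration; it compares folding the batch-norm weights entirely in fp16 against folding in fp32 and casting the result, as the new converter code does.

#include <torch/torch.h>
#include <iostream>

int main() {
  // Illustrative constants; eps matches the order of magnitude used in the new test.
  const float eps = 1e-3f;
  auto opts = torch::TensorOptions().device(torch::kCUDA).dtype(torch::kHalf);
  auto gamma = torch::ones({4}, opts);
  auto beta = torch::zeros({4}, opts);
  auto mean = torch::full({4}, 0.5, opts);
  auto var = torch::full({4}, 3.0e-4, opts); // small variance stresses fp16 precision

  // Fold entirely in half precision (previous converter behavior).
  auto scale_half = gamma / torch::sqrt(var + eps);
  auto bias_half = beta - mean * scale_half;

  // Fold in float, then embed the result as half (behavior added by this commit).
  auto scale_f32 = gamma.to(torch::kFloat) / torch::sqrt(var.to(torch::kFloat) + eps);
  auto bias_f32 = beta.to(torch::kFloat) - mean.to(torch::kFloat) * scale_f32;
  auto scale_float = scale_f32.to(torch::kHalf);
  auto bias_float = bias_f32.to(torch::kHalf);

  std::cout << "fp16 fold: scale=" << scale_half[0].item<float>()
            << " bias=" << bias_half[0].item<float>() << "\n";
  std::cout << "fp32 fold: scale=" << scale_float[0].item<float>()
            << " bias=" << bias_float[0].item<float>() << "\n";
}

The final cast back to the original dtype mirrors the design choice in the diff: the folded weights handed to addScaleNd stay in the network's dtype, so only the intermediate arithmetic gains precision.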
30 changes: 30 additions & 0 deletions tests/core/conversion/converters/test_batch_norm.cpp
@@ -137,3 +137,33 @@ TEST(Converters, ATenBatchNormShouldUnpackConvertsCorrectly) {
   ASSERT_TRUE(
       torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }
+
+TEST(Converters, ATenBatchNormHalfConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%input : Tensor, %running_var : Half(32, strides=[1], requires_grad=0, device=cuda:0), %running_mean : Half(32, strides=[1], requires_grad=0, device=cuda:0)):
+        %5 : bool = prim::Constant[value=0]()
+        %4 : float = prim::Constant[value=0.01]()
+        %3 : float = prim::Constant[value=0.001]()
+        %2 : bool = prim::Constant[value=1]()
+        %8 : Tensor = aten::batch_norm(%input, %running_var, %running_mean, %running_mean, %running_var, %5, %4, %3, %2)
+        return (%8))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::randn({2, 32, 5, 5}, {at::kCUDA}).to(at::kHalf);
+  auto mean = at::ones({32}, {at::kCUDA}).to(at::kHalf);
+  auto var = at::zeros({32}, {at::kCUDA}).to(at::kHalf);
+
+  auto trt_in = at::clone(in);
+  auto trt_mean = at::clone(mean);
+  auto trt_var = at::clone(var);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {mean, var});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_mean, trt_var});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}, {nvinfer1::DataType::kHALF});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-2));
+}
