Skip to content

Commit

Permalink
Improve batch_norm fp16 accuracy (#70)
Browse files Browse the repository at this point in the history
# Description

Use float types for compile-time calculations around batch_norm. Improves fp16 accuracy relative to pytorch.
Fixes # (issue)

## Type of change

Please delete options that are not relevant and/or add your own.

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing functionality to not work as expected)
- This change requires a documentation update

# Checklist:

- [ ] My code follows the style guidelines of this project (You can use the linters)
- [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas and hacks
- [ ] I have made corresponding changes to the documentation
- [ ] I have added tests to verify my fix or my feature
- [ ] New and existing unit tests pass locally with my changes
- [ ] I have added the relevant labels to my PR in so that relevant reviewers are notified
  • Loading branch information
mfeliz-cruise committed Jan 9, 2023
1 parent dc570e4 commit 0257b21
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 5 deletions.
23 changes: 18 additions & 5 deletions core/conversion/converters/impl/batch_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,28 @@ void _batch_norm(
const torch::Tensor& mean,
const torch::Tensor& var,
const float eps) {
auto scale = gamma / torch::sqrt(var + eps);
auto bias = beta - mean * scale;
auto orig_dtype = var.dtype();
// perform compile-time weight calculations in float to improve accuracy
// resulting weights will be embedded as the original dtype
auto calculation_gamma = gamma;
auto calculation_beta = beta;
auto calculation_mean = mean;
auto calculation_var = var;
if (orig_dtype == torch::kHalf) {
calculation_gamma = calculation_gamma.to(torch::kFloat);
calculation_beta = calculation_beta.to(torch::kFloat);
calculation_mean = calculation_mean.to(torch::kFloat);
calculation_var = calculation_var.to(torch::kFloat);
}
auto scale = calculation_gamma / torch::sqrt(calculation_var + eps);
auto bias = calculation_beta - calculation_mean * scale;
LOG_DEBUG("_batch_norm Tensor Scale : " << scale.sizes());
LOG_DEBUG("_batch_norm Tensor bias : " << bias.sizes());

auto scale_weights = Weights(ctx, scale);
auto bias_weights = Weights(ctx, bias);
auto scale_weights = Weights(ctx, scale.to(orig_dtype));
auto bias_weights = Weights(ctx, bias.to(orig_dtype));

auto power = Weights(ctx, at::ones_like(scale));
auto power = Weights(ctx, at::ones_like(scale).to(orig_dtype));
auto bn =
ctx->net->addScaleNd(*input, nvinfer1::ScaleMode::kCHANNEL, bias_weights.data, scale_weights.data, power.data, 1);
bn->setName(util::node_info(n).c_str());
Expand Down
30 changes: 30 additions & 0 deletions tests/core/conversion/converters/test_batch_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,33 @@ TEST(Converters, ATenBatchNormShouldUnpackConvertsCorrectly) {
ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenBatchNormHalfConvertsCorrectly) {
const auto graph = R"IR(
graph(%input : Tensor, %running_var : Half(32, strides=[1], requires_grad=0, device=cuda:0), %running_mean : Half(32, strides=[1], requires_grad=0, device=cuda:0)):
%5 : bool = prim::Constant[value=0]()
%4 : float = prim::Constant[value=0.01]()
%3 : float = prim::Constant[value=0.001]()
%2 : bool = prim::Constant[value=1]()
%8 : Tensor = aten::batch_norm(%input, %running_var, %running_mean, %running_mean, %running_var, %5, %4, %3, %2)
return (%8))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto in = at::randn({2, 32, 5, 5}, {at::kCUDA}).to(at::kHalf);
auto mean = at::ones({32}, {at::kCUDA}).to(at::kHalf);
auto var = at::zeros({32}, {at::kCUDA}).to(at::kHalf);

auto trt_in = at::clone(in);
auto trt_mean = at::clone(mean);
auto trt_var = at::clone(var);

auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {mean, var});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {trt_mean, trt_var});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}, {nvinfer1::DataType::kHALF});

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

0 comments on commit 0257b21

Please sign in to comment.