Skip to content

Commit

Permalink
fix(//core/conversion/converters): Fix plugin implementation for TRT 7
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed Jul 16, 2020
1 parent cff4211 commit 94d6a0f
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 13 deletions.
3 changes: 2 additions & 1 deletion core/conversion/converters/impl/conv_deconv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@ auto conv_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()

deconv->setStrideNd(stride);
deconv->setPaddingNd(padding);
#if NV_TENSORRT_MAJOR > 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR == 1)
deconv->setDilationNd(dilation);
deconv->setNbGroups(groups);

#endif
new_layer = deconv;
} else {
nvinfer1::IConvolutionLayer* conv;
Expand Down
6 changes: 3 additions & 3 deletions core/conversion/converters/impl/interpolate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
// align_corners not supported in TensorRT, create plugin and run layer through PyTorch
create_plugin(ctx, n, in, "linear1d", in_shape, out_shape, out_size, std::string("linear"));
} else {
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR. true);
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, true);
}
#else
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, align_corners);
Expand Down Expand Up @@ -185,7 +185,7 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
// align_corners not supported in TensorRT, create plugin and run layer through PyTorch
create_plugin(ctx, n, in, "bilinear2d", in_shape, out_shape, out_size, std::string("bilinear"));
} else {
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR. true);
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, true);
}
#else
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, align_corners);
Expand Down Expand Up @@ -217,7 +217,7 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
// align_corners not supported in TensorRT, create plugin and run layer through PyTorch
create_plugin(ctx, n, in, "trilinear3d", in_shape, out_shape, out_size, std::string("trilinear"));
} else {
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR. true);
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, true);
}
#else
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, align_corners);
Expand Down
13 changes: 4 additions & 9 deletions core/conversion/converters/impl/plugins/interpolate_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,13 @@ int InterpolatePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, cons

cudaStreamWaitEvent(torch_stream.stream(), event, 0);

if (mode == "linear") {
if (mode_ == "linear") {
at::upsample_linear1d_out(output, input, {size_[0]}, align_corners_);
} else if (mode == "bilinear") {
} else if (mode_ == "bilinear") {
at::upsample_bilinear2d_out(output, input, {size_[0], size_[1]}, align_corners_);
} else if (mode == "trilinear") {
} else if (mode_ == "trilinear") {
at::upsample_trilinear3d_out(output, input, {size_[0], size_[1], size_[2]}, align_corners_);
} else if (mode == "adaptive_pool2d") {
} else if (mode_ == "adaptive_pool2d") {
at::adaptive_avg_pool2d_out(output, input, {size_[0], size_[1]});
}

Expand Down Expand Up @@ -212,11 +212,6 @@ int InterpolatePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, cons
output = at::adaptive_avg_pool2d(input, {size_[0], size_[1]});
}

output = output.contiguous();
for (int i = 0; i < util::volume(outputDesc->dims); i++) {
std::cout << ((float*)output.data_ptr())[i] << std::endl;
}

cudaMemcpyAsync(outputs[0], output.data_ptr(), util::volume(outputDesc->dims) * sizeof(float), cudaMemcpyHostToDevice, stream);
cudaStreamSynchronize(stream);

Expand Down

0 comments on commit 94d6a0f

Please sign in to comment.