[cherry-pick] support data_format='NHWC' for prelu channel mode (#38495)
* support data_format='NHWC' for prelu channel mode (#37019)

* support data_format='NHWC' for prelu channel mode

* fix prelu weight shape for NHWC in static mode (#38310)
GuoxiaWang authored Dec 29, 2021
1 parent c3cee12 commit c111340
Showing 15 changed files with 414 additions and 127 deletions.
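For context, the user-visible effect of this commit: PReLU's 'channel' mode no longer assumes the channel axis is dimension 1. Below is a minimal usage sketch, assuming the Python-side `data_format` argument that this change adds to `paddle.nn.functional.prelu` (the Python files are among the 15 changed files but are not shown in the hunks below):

```python
import paddle
import paddle.nn.functional as F

# NHWC input: shape [N, H, W, C] with C = 3 channels.
x = paddle.randn([1, 4, 4, 3])
# Channel mode: one learnable slope per channel, so the weight has C elements.
w = paddle.to_tensor([0.25, 0.25, 0.25])

# data_format tells PReLU which axis holds the channels; it is assumed here
# to be exposed by this PR's Python-side changes (not shown in these hunks).
y = F.prelu(x, w, data_format="NHWC")
print(y.shape)  # [1, 4, 4, 3]
```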
11 changes: 8 additions & 3 deletions paddle/fluid/inference/tensorrt/convert/prelu_op.cc
@@ -34,6 +34,11 @@ class PReluOpConverter : public OpConverter {
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
// Get attrs
std::string mode = BOOST_GET_CONST(std::string, op_desc.GetAttr("mode"));
std::string data_format = "NCHW";
if (op_desc.HasAttr("data_format")) {
data_format =
BOOST_GET_CONST(std::string, op_desc.GetAttr("data_format"));
}
auto* alpha_var = scope.FindVar(op_desc.Input("Alpha")[0]);
auto* alpha_tensor = alpha_var->GetMutable<framework::LoDTensor>();

@@ -47,7 +52,7 @@ class PReluOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
plugin::PReluPluginDynamic* plugin = new plugin::PReluPluginDynamic(
alpha_data, alpha_tensor_temp->numel(), mode);
alpha_data, alpha_tensor_temp->numel(), mode, data_format);
layer = engine_->AddDynamicPlugin(&input, input_num, plugin);
} else {
#if IS_TRT_VERSION_GE(7000)
@@ -74,8 +79,8 @@ class PReluOpConverter : public OpConverter {
layer = TRT_ENGINE_ADD_LAYER(engine_, ParametricReLU, *input,
*alpha_layer_output);
#else
plugin::PReluPlugin* plugin =
new plugin::PReluPlugin(alpha_data, alpha_tensor_temp->numel(), mode);
plugin::PReluPlugin* plugin = new plugin::PReluPlugin(
alpha_data, alpha_tensor_temp->numel(), mode, data_format);
layer = engine_->AddPlugin(&input, input_num, plugin);
#endif
}
6 changes: 4 additions & 2 deletions paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu
@@ -69,10 +69,11 @@ int PReluPlugin::enqueue(int batch_size, const void *const *inputs,
}

if (mode_ == "channel") {
bool channel_last = data_format_ == "NHWC";
operators::math::PreluChannelWiseDirectCUDAFunctor<float>
prelu_channel_wise;
prelu_channel_wise(stream, input, alpha, output, input_dims.d[0],
input_dims.d[1], numel);
input_dims.d[1], channel_last, numel);
} else if (mode_ == "element") {
operators::math::PreluElementWiseDirectCUDAFunctor<float>
prelu_element_wise;
@@ -168,10 +169,11 @@ int PReluPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
}

if (mode_ == "channel") {
bool channel_last = data_format_ == "NHWC";
operators::math::PreluChannelWiseDirectCUDAFunctor<float>
prelu_channel_wise;
prelu_channel_wise(stream, input, alpha, output, input_dims.d[0],
input_dims.d[1], numel);
input_dims.d[1], channel_last, numel);
} else if (mode_ == "element") {
operators::math::PreluElementWiseDirectCUDAFunctor<float>
prelu_element_wise;
22 changes: 15 additions & 7 deletions paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h
@@ -32,11 +32,12 @@ class PReluPlugin : public PluginTensorRT {
std::vector<float> weight_;
float* p_gpu_weight_;
std::string mode_;
std::string data_format_;

public:
size_t getSerializationSize() const TRT_NOEXCEPT override {
return getBaseSerializationSize() + SerializedSize(mode_.c_str()) +
SerializedSize(weight_);
SerializedSize(data_format_.c_str()) + SerializedSize(weight_);
}

// TRT will call this func when we need to serialize the configuration of
@@ -46,11 +47,12 @@ class PReluPlugin : public PluginTensorRT {
serializeBase(buffer);
SerializeValue(&buffer, weight_);
SerializeValue(&buffer, mode_.c_str());
SerializeValue(&buffer, data_format_.c_str());
}

PReluPlugin(const float* weight, const int weight_num,
std::string const& mode)
: mode_(mode) {
std::string const& mode, std::string const& data_format)
: mode_(mode), data_format_(data_format) {
weight_.resize(weight_num);
std::copy(weight, weight + weight_num, weight_.data());
}
@@ -63,13 +65,17 @@ class PReluPlugin : public PluginTensorRT {
const char* prelu_mode;
DeserializeValue(&serialData, &serialLength, &prelu_mode);
mode_ = std::string(prelu_mode);
const char* prelu_data_format;
DeserializeValue(&serialData, &serialLength, &prelu_data_format);
data_format_ = std::string(prelu_data_format);
}
~PReluPlugin() {}
int initialize() TRT_NOEXCEPT override;
void terminate() TRT_NOEXCEPT override;

PReluPlugin* clone() const TRT_NOEXCEPT override {
auto* ptr = new PReluPlugin(weight_.data(), weight_.size(), mode_);
auto* ptr =
new PReluPlugin(weight_.data(), weight_.size(), mode_, data_format_);
ptr->p_gpu_weight_ = p_gpu_weight_;
return ptr;
}
@@ -108,16 +114,17 @@ REGISTER_TRT_PLUGIN_V2(PReluPluginCreator);
class PReluPluginDynamic : public DynamicPluginTensorRT {
public:
PReluPluginDynamic(const float* weight, const int weight_num,
std::string const& mode)
: mode_(mode) {
std::string const& mode, std::string const& data_format)
: mode_(mode), data_format_(data_format) {
weight_.resize(weight_num);
std::copy(weight, weight + weight_num, weight_.data());
}

PReluPluginDynamic(void const* serialData, size_t serialLength);
~PReluPluginDynamic() {}
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
auto ptr = new PReluPluginDynamic(weight_.data(), weight_.size(), mode_);
auto ptr = new PReluPluginDynamic(weight_.data(), weight_.size(), mode_,
data_format_);
ptr->p_gpu_weight_ = p_gpu_weight_;
return ptr;
}
Expand Down Expand Up @@ -167,6 +174,7 @@ class PReluPluginDynamic : public DynamicPluginTensorRT {
std::vector<float> weight_;
float* p_gpu_weight_;
std::string mode_;
std::string data_format_;
};
#endif

33 changes: 26 additions & 7 deletions paddle/fluid/operators/math/prelu.cu
@@ -25,9 +25,9 @@ inline static int PADDLE_GET_BLOCKS(const int N) {
}

template <typename T>
__global__ void PReluChannelWiseKernel(const T *input, const T *alpha,
T *output, size_t channel_num,
size_t plane_size, size_t numel) {
__global__ void PReluChannelFirstWiseKernel(const T *input, const T *alpha,
T *output, size_t channel_num,
size_t plane_size, size_t numel) {
CUDA_KERNEL_LOOP(index, numel) {
size_t temp = index / plane_size;
size_t channel_index = temp % channel_num;
@@ -38,6 +38,19 @@ __global__ void PReluChannelWiseKernel(const T *input, const T *alpha,
}
}

template <typename T>
__global__ void PReluChannelLastWiseKernel(const T *input, const T *alpha,
T *output, size_t channel_num,
size_t numel) {
CUDA_KERNEL_LOOP(index, numel) {
size_t channel_index = index % channel_num;
T scale = alpha[channel_index];
T x = input[index];
T zero = static_cast<T>(0);
output[index] = (x > zero) ? x : scale * x;
}
}

template <typename T>
__global__ void PReluElementWiseKernel(const T *input, const T *alpha,
T *output, size_t spatial_size,
@@ -65,10 +78,16 @@ __global__ void PReluScalarKernel(const T *input, const T *alpha, T *output,
template <typename T>
void PreluChannelWiseDirectCUDAFunctor<T>::operator()(
gpuStream_t stream, const T *input, const T *alpha, T *output,
size_t batch_size, size_t channel, size_t numel) {
PReluChannelWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
stream>>>(input, alpha, output, channel,
numel / batch_size / channel, numel);
size_t batch_size, size_t channel, bool channel_last, size_t numel) {
if (channel_last) {
PReluChannelLastWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
stream>>>(input, alpha, output, channel,
numel);
} else {
PReluChannelFirstWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
stream>>>(
input, alpha, output, channel, numel / batch_size / channel, numel);
}
}

template <typename T>
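The two kernels above differ only in how they recover the channel index from the flattened element index. A NumPy mirror of that indexing, as an illustrative sketch rather than Paddle code:

```python
import numpy as np

def prelu_channel(x, alpha, channel_last):
    # Illustrative mirror of the two CUDA kernels' channel indexing.
    flat = x.reshape(-1)
    i = np.arange(flat.size)
    if channel_last:
        # PReluChannelLastWiseKernel: channel is the fastest-varying axis,
        # so the channel index is simply index % channel_num.
        c = i % x.shape[-1]
    else:
        # PReluChannelFirstWiseKernel: each channel owns a contiguous plane
        # of H*W elements, hence (index / plane_size) % channel_num.
        channel_num = x.shape[1]
        plane_size = flat.size // (x.shape[0] * channel_num)
        c = (i // plane_size) % channel_num
    return np.where(flat > 0, flat, alpha[c] * flat).reshape(x.shape)
```

This also shows why the channel-last kernel takes no plane_size argument: when channels are last, the channel index depends only on `index % channel_num`.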
3 changes: 2 additions & 1 deletion paddle/fluid/operators/math/prelu.h
@@ -31,7 +31,8 @@ template <typename T>
class PreluChannelWiseDirectCUDAFunctor {
public:
void operator()(gpuStream_t stream, const T *input, const T *alpha, T *output,
size_t batch_size, size_t channel, size_t numel);
size_t batch_size, size_t channel, bool channel_last,
size_t numel);
};

template <typename T>
19 changes: 14 additions & 5 deletions paddle/fluid/operators/mkldnn/prelu_mkldnn_op.cc
@@ -34,7 +34,7 @@ class PReluMKLDNNHandler
const mkldnn::engine engine, platform::Place cpu_place,
const Tensor* x, const Tensor* weights,
const std::string& uniq_name, const std::string& mode,
bool is_test = false)
const std::string& data_format, bool is_test = false)
: platform::MKLDNNHandlerT<T, dnnl::prelu_forward, dnnl::prelu_backward>(
dev_ctx, engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
@@ -49,8 +49,13 @@ class PReluMKLDNNHandler
if (weights->dims().size() != x->dims().size()) {
auto new_weights_dims = std::vector<int64_t>(x->dims().size(), 1);
if (mode == "channel") {
new_weights_dims[1] =
*std::max_element(weights_dims.begin(), weights_dims.end());
if (data_format == "NHWC") {
new_weights_dims[x->dims().size() - 1] =
*std::max_element(weights_dims.begin(), weights_dims.end());
} else {
new_weights_dims[1] =
*std::max_element(weights_dims.begin(), weights_dims.end());
}
}
weights_dims = std::move(new_weights_dims);
}
@@ -110,9 +115,11 @@ class PReluMKLDNNKernel : public framework::OpKernel<T> {
auto* out = ctx.Output<Tensor>("Out");
const bool is_test = ctx.Attr<bool>("is_test");
const auto mode = ctx.Attr<std::string>("mode");
const auto data_format = ctx.Attr<std::string>("data_format");

PReluMKLDNNHandler<T> handler(dev_ctx, onednn_engine, ctx.GetPlace(), x,
alpha, ctx.InputName("X"), mode, is_test);
alpha, ctx.InputName("X"), mode, data_format,
is_test);

auto src_memory_p = handler.AcquireSrcMemory(x);
auto weights_memory_p =
@@ -149,9 +156,11 @@ class PReluGradMKLDNNKernel : public framework::OpKernel<T> {
auto* alpha = ctx.Input<Tensor>("Alpha");
const bool is_test = ctx.Attr<bool>("is_test");
const auto mode = ctx.Attr<std::string>("mode");
const auto data_format = ctx.Attr<std::string>("data_format");

PReluMKLDNNHandler<T> handler(dev_ctx, onednn_engine, ctx.GetPlace(), x,
alpha, framework::GradVarName("X"), mode);
alpha, framework::GradVarName("X"), mode,
data_format);

auto src_memory_p = handler.AcquireSrcMemory(x);
auto weights_memory_p =
36 changes: 30 additions & 6 deletions paddle/fluid/operators/prelu_op.cc
@@ -38,19 +38,40 @@ class PReluOp : public framework::OperatorWithKernel {
"But recevied alpha's size: %d.",
product(ctx->GetInputDim("Alpha"))));
} else if (mode == "channel") {
PADDLE_ENFORCE_EQ(product(ctx->GetInputDim("Alpha")), x_dim[1],
platform::errors::InvalidArgument(
"For mode 'channel', size of weight Alpha must be "
"equal to the number of channels of input(x). But "
"recevied alpha's size: %d, x_dim[1]: %d",
product(ctx->GetInputDim("Alpha")), x_dim[1]));
auto x_rank = x_dim.size();
PADDLE_ENFORCE_GE(x_rank, 2,
platform::errors::InvalidArgument(
"For mode 'channel', rank of input X must be "
"equal or larger than 2. But recevied X's "
"rank: %d",
x_rank));
const std::string data_format_str =
ctx->Attrs().Get<std::string>("data_format");
PADDLE_ENFORCE_EQ(data_format_str == "NCHW" || data_format_str == "NHWC",
true,
platform::errors::InvalidArgument(
"For mode 'channel', data_format must be one of "
"NCHW and NHWC. But recevied data_format: %s",
data_format_str));
if (data_format_str == "NCHW") {
PADDLE_ENFORCE_EQ(
product(ctx->GetInputDim("Alpha")) == x_dim[1], true,
platform::errors::InvalidArgument(
"For mode 'channel', size of weight Alpha must be "
"equal to the number of channels of input(x). But "
"received alpha's size: %d, x_dim[1]: %d",
product(ctx->GetInputDim("Alpha")), x_dim[1]));
} else {
PADDLE_ENFORCE_EQ(
product(ctx->GetInputDim("Alpha")) == x_dim[x_rank - 1], true,
platform::errors::InvalidArgument(
"For mode 'channel', size of weight Alpha must be "
"equal to the number of channels of input(x). But "
"received alpha's size: %d, x_dim[%d]: %d",
product(ctx->GetInputDim("Alpha")), x_rank - 1,
x_dim[x_rank - 1]));
}

} else if (mode == "element") {
auto alpha_dim = ctx->GetInputDim("Alpha");
auto alpha_rank = alpha_dim.size();
@@ -134,6 +155,9 @@ There are modes:
)DOC");
AddAttr<std::string>("mode", "The mode for inputs to share weights.")
.SetDefault("all");
AddAttr<std::string>("data_format",
"Data format that specifies the layout of input")
.SetDefault("NCHW");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false)
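Taken together, the new 'channel' branch of PReluOp::InferShape reduces to: require rank >= 2, require a known layout, then compare Alpha's element count against the channel dimension that layout selects. A rough Python mirror (illustrative only, not Paddle code):

```python
def check_channel_mode(x_shape, alpha_numel, data_format):
    # Rough sketch of the 'channel' branch of PReluOp::InferShape.
    assert len(x_shape) >= 2, "mode 'channel' requires input rank >= 2"
    assert data_format in ("NCHW", "NHWC"), "unsupported data_format"
    channel_axis = 1 if data_format == "NCHW" else len(x_shape) - 1
    assert alpha_numel == x_shape[channel_axis], (
        "Alpha size must equal the channel count of X")
```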