Convolution operator #4042
Changes from 3 commits
@@ -0,0 +1,102 @@ (new file: Conv2D operator definition, gradient op, and CPU kernel registration)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/gemm_conv_op.h"

namespace paddle {
namespace operators {

int outputSize(int input_size, int filter_size, int padding, int stride) {
  int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
  return output_size;
}
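// Worked example (editorial, illustrative only): input_size = 32,
// filter_size = 3, padding = 1, stride = 1 gives
// (32 - 3 + 2 * 1) / 1 + 1 = 32, i.e. a same-size output.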

class Conv2DOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    auto in = ctx.Input<Tensor>("Input");
    auto filter = ctx.Input<Tensor>("Filter");
    auto out = ctx.Output<Tensor>("Output");
    PADDLE_ENFORCE_EQ(in->dims().size(), 4, "Conv2DOp intput should be 4-D.");
Review comment: intput -> input
Reply: Done.
    PADDLE_ENFORCE_EQ(filter->dims().size(), 4,
                      "Conv2DOp filter should be 4-D.");

    std::vector<int> strides = Attr<std::vector<int>>("strides");
    std::vector<int> paddings = Attr<std::vector<int>>("paddings");
    auto output_height =
        outputSize(in->dims()[2], filter->dims()[2], paddings[0], strides[0]);
    auto output_width =
        outputSize(in->dims()[3], filter->dims()[3], paddings[1], strides[1]);
    out->Resize(
        {in->dims()[0], filter->dims()[0], output_height, output_width});
  }
};
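// Shape walk-through (editorial, illustrative only): with Input {8, 3, 32, 32}
// (NCHW), Filter {64, 3, 3, 3} (MCHW), strides {1, 1} and paddings {1, 1},
// both spatial dims give (32 - 3 + 2 * 1) / 1 + 1 = 32, so Output is resized
// to {8, 64, 32, 32}.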

class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
Review comment: Can we put …
Reply: I think we do not need to write a Conv2DOpMaker for CudnnConv.
 public:
  Conv2DOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput(
        "Input",
        "The input tensor of convolution operator. "
        "The format of the input tensor is NCHW, where N is batch size, C is "
        "the number of channels, and H and W are the height and width of the "
        "image.");
    AddInput(
        "Filter",
        "The filter tensor of convolution operator. "
        "The format of the filter tensor is MCHW, where M is the number of "
        "output image channels, C is the number of input image channels, and "
        "H and W are the height and width of the filter.");
    AddOutput("Output",
              "The output tensor of convolution operator. "
              "The format of the output tensor is also NCHW.");
    AddComment(R"DOC(
The convolution operation calculates the output based on
the input, filter, strides and paddings parameters.
)DOC");
    AddAttr<std::vector<int>>("strides", "strides of convolution operator.");
    AddAttr<std::vector<int>>("paddings", "paddings of convolution operator.");
  }
};

class Conv2DOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    auto in = ctx.Input<Tensor>("Input");
    auto filter = ctx.Input<Tensor>("Filter");
    auto d_in = ctx.Output<Tensor>(framework::GradVarName("Input"));
    auto d_filter = ctx.Output<Tensor>(framework::GradVarName("Filter"));

    d_in->Resize(in->dims());
    d_filter->Resize(filter->dims());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(conv2d, ops::Conv2DOp, ops::Conv2DOpMaker, conv2d_grad,
            ops::Conv2DOpGrad);

REGISTER_OP_CPU_KERNEL(conv2d,
                       ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    conv2d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);

Review comment: The current build system requires that the filename match the registered operator name. Maybe rename them both to …
Reply: Done.
@@ -0,0 +1,22 @@ (new file: GPU kernel registration)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/gemm_conv_op.h"

namespace ops = paddle::operators;

REGISTER_OP_GPU_KERNEL(conv2d,
                       ops::GemmConvKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
    conv2d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>);
@@ -0,0 +1,184 @@ (new file: paddle/operators/gemm_conv_op.h, the GEMM-based convolution kernels)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/im2col.h"
#include "paddle/operators/math/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename Place, typename T>
class GemmConvKernel : public framework::OpKernel {
Review comment: We'll write the 3D convolution later. Should we distinguish the names?
Reply: Done.
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
    Tensor* filter = const_cast<Tensor*>(context.Input<Tensor>("Filter"));
    Tensor* output = context.Output<Tensor>("Output");
    output->mutable_data<T>(context.GetPlace());

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    auto filter_dims = filter->dims();

    int batch_size = input->dims()[0];
    int input_channels = input->dims()[1];
    int filter_height = filter->dims()[filter->dims().size() - 2];
    int filter_width = filter->dims()[filter->dims().size() - 1];
    int output_height = output->dims()[2];
    int output_width = output->dims()[3];

    paddle::operators::math::Im2ColFunctor<
        paddle::operators::math::ColFormat::kCFO, Place, T>
        im2col;
    framework::DDim col_shape = {input_channels, filter_height, filter_width,
                                 output_height, output_width};
    Tensor col;
    col.mutable_data<T>(col_shape, context.GetPlace());

    auto* device_context =
        const_cast<platform::DeviceContext*>(context.device_context_);

    framework::DDim input_shape = {input->dims()[1], input->dims()[2],
                                   input->dims()[3]};
    framework::DDim filter_matrix_shape = {
        filter->dims()[0],
        filter->dims()[1] * filter->dims()[2] * filter->dims()[3]};
    framework::DDim col_matrix_shape = {
        input_channels * filter_height * filter_width,
        output_height * output_width};
    framework::DDim output_matrix_shape = {
        output->dims()[1], output->dims()[2] * output->dims()[3]};
    filter->Resize(filter_matrix_shape);

    // convolution operator: im2col + gemm, one sample at a time
    for (int i = 0; i < batch_size; i++) {
      // im2col: unfold the input slice into a column matrix
      Tensor in_slice = input->Slice<T>(i, i + 1);
      in_slice.Resize(input_shape);
      col.Resize(col_shape);
      im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1],
             device_context);

      // gemm: Output = Filter * col
      Tensor out_slice = output->Slice<T>(i, i + 1);
      out_slice.Resize(output_matrix_shape);
      col.Resize(col_matrix_shape);
      math::matmul<Place, T>(*filter, false, col, false, T(1.0), &out_slice,
                             T(0.0), device_context);
    }
    filter->Resize(filter_dims);
  }
};
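// Note (editorial): `col` is one buffer viewed two ways. Resize() only
// changes the dims, not the allocation, so the {C, KH, KW, OH, OW} view used
// by im2col and the {C*KH*KW, OH*OW} matrix view used by the GEMM share the
// same memory.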

template <typename Place, typename T>
class GemmConvGradKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
    Tensor* filter = const_cast<Tensor*>(context.Input<Tensor>("Filter"));
    const Tensor* output_grad =
        context.Input<Tensor>(framework::GradVarName("Output"));
    Tensor* input_grad =
        context.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad =
        context.Output<Tensor>(framework::GradVarName("Filter"));
    input_grad->mutable_data<T>(context.GetPlace());
    filter_grad->mutable_data<T>(context.GetPlace());

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    auto filter_dims = filter->dims();

    int batch_size = input->dims()[0];
    int input_channels = input->dims()[1];
    int filter_height = filter->dims()[filter->dims().size() - 2];
    int filter_width = filter->dims()[filter->dims().size() - 1];
    int output_height = output_grad->dims()[2];
    int output_width = output_grad->dims()[3];

    paddle::operators::math::Col2ImFunctor<
        paddle::operators::math::ColFormat::kCFO, Place, T>
        col2im;
    paddle::operators::math::Im2ColFunctor<
        paddle::operators::math::ColFormat::kCFO, Place, T>
        im2col;
    Tensor col;
    framework::DDim col_shape = {input_channels, filter_height, filter_width,
                                 output_height, output_width};
    col.mutable_data<T>(col_shape, context.GetPlace());

    auto* device_context =
        const_cast<platform::DeviceContext*>(context.device_context_);

    framework::DDim input_shape = {input->dims()[1], input->dims()[2],
                                   input->dims()[3]};
    framework::DDim filter_matrix_shape = {
        filter->dims()[0],
        filter->dims()[1] * filter->dims()[2] * filter->dims()[3]};
    framework::DDim col_matrix_shape = {
        input_channels * filter_height * filter_width,
        output_height * output_width};
    framework::DDim output_matrix_shape = {
        output_grad->dims()[1],
        output_grad->dims()[2] * output_grad->dims()[3]};
    filter->Resize(filter_matrix_shape);
    filter_grad->Resize(filter_matrix_shape);

    // zero-initialize both gradients before accumulation
    auto t1 = framework::EigenVector<T>::Flatten(*filter_grad);
    t1.device(context.GetEigenDevice<Place>()) = t1.constant(static_cast<T>(0));
    auto t2 = framework::EigenVector<T>::Flatten(*input_grad);
    t2.device(context.GetEigenDevice<Place>()) = t2.constant(static_cast<T>(0));
Review comment: Shouldn't the gradient be cleared here? The weights fed into different Ops may be shared.
Reply: If the weights are shared, the framework is responsible for merging the two parts of the gradients.

    // convolution backward input operator: gemm + col2im
    // convolution backward weight operator: im2col + gemm
    for (int i = 0; i < batch_size; i++) {
      // gemm: d(col) = Filter^T * d(Output)
      Tensor out_slice = output_grad->Slice<T>(i, i + 1);
      out_slice.Resize(output_matrix_shape);
      col.Resize(col_matrix_shape);
      math::matmul<Place, T>(*filter, true, out_slice, false, T(1.0), &col,
                             T(0.0), device_context);

      // col2im: scatter d(col) back into d(Input)
      Tensor in_grad_slice = input_grad->Slice<T>(i, i + 1);
      in_grad_slice.Resize(input_shape);
      col.Resize(col_shape);
      col2im(in_grad_slice, col, strides[0], strides[1], paddings[0],
             paddings[1], device_context);

      // im2col: unfold the input slice again for the filter gradient
      Tensor in_slice = input->Slice<T>(i, i + 1);
      in_slice.Resize(input_shape);
      col.Resize(col_shape);
      im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1],
             device_context);

      // gemm: d(Filter) += d(Output) * col^T (beta = 1.0 accumulates)
      col.Resize(col_matrix_shape);
      math::matmul<Place, T>(out_slice, false, col, true, T(1.0), filter_grad,
                             T(1.0), device_context);
    }
    filter->Resize(filter_dims);
    filter_grad->Resize(filter_dims);
  }
};
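// Backward identities (editorial note), per sample, with Filter viewed as an
// M x (C*KH*KW) matrix, col as (C*KH*KW) x (OH*OW), and dOut as M x (OH*OW):
//   d(col)    = Filter^T * dOut, scattered back to d(Input) by col2im;
//   d(Filter) = dOut * col^T, accumulated across the batch via beta = 1.0.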

}  // namespace operators
}  // namespace paddle
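To make the im2col + GEMM formulation concrete, here is a minimal standalone sketch, separate from the PR itself. It handles one image, one output channel, and float data; all names are illustrative and nothing here is Paddle API.

// Editorial sketch of im2col + GEMM convolution (illustrative only).
#include <cstdio>
#include <vector>

// Unfold a C x H x W image into a (C*KH*KW) x (OH*OW) column matrix.
void im2col(const std::vector<float>& img, int C, int H, int W, int KH, int KW,
            int pad, int stride, std::vector<float>& col, int OH, int OW) {
  for (int c = 0; c < C; ++c)
    for (int kh = 0; kh < KH; ++kh)
      for (int kw = 0; kw < KW; ++kw) {
        int row = (c * KH + kh) * KW + kw;
        for (int oh = 0; oh < OH; ++oh)
          for (int ow = 0; ow < OW; ++ow) {
            int h = oh * stride - pad + kh;
            int w = ow * stride - pad + kw;
            float v = (h >= 0 && h < H && w >= 0 && w < W)
                          ? img[(c * H + h) * W + w]
                          : 0.0f;  // zero padding
            col[row * (OH * OW) + oh * OW + ow] = v;
          }
      }
}

int main() {
  // Input: 1 channel, 4x4; filter: 1 output channel, 3x3; pad 1, stride 1.
  int C = 1, H = 4, W = 4, KH = 3, KW = 3, pad = 1, stride = 1;
  int OH = (H - KH + 2 * pad) / stride + 1;  // same formula as outputSize()
  int OW = (W - KW + 2 * pad) / stride + 1;
  std::vector<float> img(C * H * W, 1.0f);   // all-ones image
  std::vector<float> filter(KH * KW, 1.0f);  // all-ones 3x3 filter (M = 1)
  std::vector<float> col(C * KH * KW * OH * OW);
  im2col(img, C, H, W, KH, KW, pad, stride, col, OH, OW);

  // GEMM: output (M x OH*OW) = filter (M x C*KH*KW) * col (C*KH*KW x OH*OW).
  std::vector<float> out(OH * OW, 0.0f);
  for (int k = 0; k < C * KH * KW; ++k)
    for (int j = 0; j < OH * OW; ++j) out[j] += filter[k] * col[k * OH * OW + j];

  for (int i = 0; i < OH; ++i) {
    for (int j = 0; j < OW; ++j) printf("%4.0f", out[i * OW + j]);
    printf("\n");
  }
  return 0;
}

With an all-ones image and filter, interior outputs are 9 (a full 3x3 patch) and border outputs are 4 or 6 because of the zero padding, mirroring the per-sample Output = Filter * col step in GemmConvKernel.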
Review comment (on outputSize): This function is also used in conv3d, pooling2d and pooling3d. Should it be written in one place?
Reply: I think this can be fixed in the next PR. At present it is not clear where the best place for this function is.