Skip to content

Commit

Permalink
Merge pull request #1214 from sguada/global_pooling
Browse files Browse the repository at this point in the history
Add global_pooling to cover the whole input by setting kernel size to bottom size
  • Loading branch information
shelhamer committed Oct 3, 2014
2 parents aee2efb + 760ffaa commit 576931e
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 6 deletions.
1 change: 1 addition & 0 deletions include/caffe/vision_layers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ class PoolingLayer : public Layer<Dtype> {
int channels_;
int height_, width_;
int pooled_height_, pooled_width_;
bool global_pooling_;
Blob<Dtype> rand_idx_;
Blob<int> max_idx_;
};
Expand Down
32 changes: 26 additions & 6 deletions src/caffe/layers/pooling_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,18 @@ template <typename Dtype>
void PoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
PoolingParameter pool_param = this->layer_param_.pooling_param();
CHECK(!pool_param.has_kernel_size() !=
if (pool_param.global_pooling()) {
CHECK(!(pool_param.has_kernel_size() ||
pool_param.has_kernel_h() || pool_param.has_kernel_w()))
<< "With Global_pooling: true Filter size cannot specified";
} else {
CHECK(!pool_param.has_kernel_size() !=
!(pool_param.has_kernel_h() && pool_param.has_kernel_w()))
<< "Filter size is kernel_size OR kernel_h and kernel_w; not both";
CHECK(pool_param.has_kernel_size() ||
CHECK(pool_param.has_kernel_size() ||
(pool_param.has_kernel_h() && pool_param.has_kernel_w()))
<< "For non-square filters both kernel_h and kernel_w are required.";
}
CHECK((!pool_param.has_pad() && pool_param.has_pad_h()
&& pool_param.has_pad_w())
|| (!pool_param.has_pad_h() && !pool_param.has_pad_w()))
Expand All @@ -31,11 +37,17 @@ void PoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
&& pool_param.has_stride_w())
|| (!pool_param.has_stride_h() && !pool_param.has_stride_w()))
<< "Stride is stride OR stride_h and stride_w are required.";
if (pool_param.has_kernel_size()) {
kernel_h_ = kernel_w_ = pool_param.kernel_size();
global_pooling_ = pool_param.global_pooling();
if (global_pooling_) {
kernel_h_ = bottom[0]->height();
kernel_w_ = bottom[0]->width();
} else {
kernel_h_ = pool_param.kernel_h();
kernel_w_ = pool_param.kernel_w();
if (pool_param.has_kernel_size()) {
kernel_h_ = kernel_w_ = pool_param.kernel_size();
} else {
kernel_h_ = pool_param.kernel_h();
kernel_w_ = pool_param.kernel_w();
}
}
CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero.";
CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero.";
Expand All @@ -51,6 +63,10 @@ void PoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
stride_h_ = pool_param.stride_h();
stride_w_ = pool_param.stride_w();
}
if (global_pooling_) {
CHECK(pad_h_ == 0 && pad_w_ == 0 && stride_h_ == 1 && stride_w_ == 1)
<< "With Global_pooling: true; only pad = 0 and stride = 1";
}
if (pad_h_ != 0 || pad_w_ != 0) {
CHECK(this->layer_param_.pooling_param().pool()
== PoolingParameter_PoolMethod_AVE
Expand All @@ -68,6 +84,10 @@ void PoolingLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
channels_ = bottom[0]->channels();
height_ = bottom[0]->height();
width_ = bottom[0]->width();
if (global_pooling_) {
kernel_h_ = bottom[0]->height();
kernel_w_ = bottom[0]->width();
}
pooled_height_ = static_cast<int>(ceil(static_cast<float>(
height_ + 2 * pad_h_ - kernel_h_) / stride_h_)) + 1;
pooled_width_ = static_cast<int>(ceil(static_cast<float>(
Expand Down
3 changes: 3 additions & 0 deletions src/caffe/proto/caffe.proto
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,9 @@ message PoolingParameter {
CUDNN = 2;
}
optional Engine engine = 11 [default = DEFAULT];
// If global_pooling is true, pool over the entire spatial extent of the
// bottom blob by setting kernel_h = bottom->height and kernel_w = bottom->width
optional bool global_pooling = 12 [default = false];
}

// Message that stores parameters used by PowerLayer
Expand Down
14 changes: 14 additions & 0 deletions src/caffe/test/test_pooling_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,20 @@ TYPED_TEST(PoolingLayerTest, TestSetupPadded) {
EXPECT_EQ(this->blob_top_->width(), 3);
}

TYPED_TEST(PoolingLayerTest, TestSetupGlobalPooling) {
  typedef typename TypeParam::Dtype Dtype;
  // Configure an average-pooling layer with global_pooling enabled; the
  // kernel size is intentionally left unset so it is derived from the bottom.
  LayerParameter layer_param;
  PoolingParameter* pool_param = layer_param.mutable_pooling_param();
  pool_param->set_pool(PoolingParameter_PoolMethod_AVE);
  pool_param->set_global_pooling(true);
  PoolingLayer<Dtype> layer(layer_param);
  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
  // Global pooling must collapse each channel's spatial extent to 1x1 while
  // leaving num and channels untouched.
  EXPECT_EQ(this->blob_top_->width(), 1);
  EXPECT_EQ(this->blob_top_->height(), 1);
  EXPECT_EQ(this->blob_top_->channels(), this->blob_bottom_->channels());
  EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_->num());
}

/*
TYPED_TEST(PoolingLayerTest, PrintBackward) {
LayerParameter layer_param;
Expand Down

0 comments on commit 576931e

Please sign in to comment.