Merge pull request apache#21 from GodBlessZhk/ts
revert roi_pooling op to the original version
winstywang authored Aug 30, 2017
2 parents 544d34c + 6701876 commit e5f4b37
Showing 3 changed files with 50 additions and 50 deletions.
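The revert replaces the batched ROI layout with the original flat one: the rois input drops from a 3D [batch, num_rois, 5] tensor back to 2D [num_rois, 5], where each row carries its own batch index, and the pooled outputs drop from 5D back to 4D [num_rois, c, pooled_h, pooled_w]. A minimal sketch of the restored row layout (plain C++; the RoiRow alias and the sample values are illustrative only, not part of the operator):

    #include <array>
    #include <vector>

    // One ROI row in the restored [num_rois, 5] layout:
    // {batch_index, x1, y1, x2, y2}, coordinates in input-image space.
    using RoiRow = std::array<float, 5>;

    int main() {
      // Three ROIs drawn from two images of the batch; the first column
      // says which image each ROI belongs to.
      std::vector<RoiRow> rois = {
          {0.f, 10.f, 10.f, 80.f, 60.f},
          {0.f, 24.f, 16.f, 120.f, 96.f},
          {1.f, 5.f, 5.f, 40.f, 40.f},
      };
      (void)rois;
      return 0;
    }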
36 changes: 18 additions & 18 deletions src/operator/roi_pooling-inl.h
@@ -60,14 +60,14 @@ class ROIPoolingOp : public Operator {
size_t expected = 2;
CHECK_EQ(in_data.size(), expected);
CHECK_EQ(out_data.size(), expected);
- CHECK_EQ(out_data[roipool::kOut].shape_[1], in_data[roipool::kBox].shape_[1]);
- CHECK_EQ(out_data[roipool::kMaxIdx].shape_[1], in_data[roipool::kBox].shape_[1]);
+ CHECK_EQ(out_data[roipool::kOut].shape_[0], in_data[roipool::kBox].shape_[0]);
+ CHECK_EQ(out_data[roipool::kMaxIdx].shape_[0], in_data[roipool::kBox].shape_[0]);
Stream<xpu> *s = ctx.get_stream<xpu>();

Tensor<xpu, 4, DType> data = in_data[roipool::kData].get<xpu, 4, DType>(s);
- Tensor<xpu, 3, DType> bbox = in_data[roipool::kBox].get<xpu, 3, DType>(s);
- Tensor<xpu, 5, DType> out = out_data[roipool::kOut].get<xpu, 5, DType>(s);
- Tensor<xpu, 5, DType> max_idx = out_data[roipool::kMaxIdx].get<xpu, 5, DType>(s);
+ Tensor<xpu, 2, DType> bbox = in_data[roipool::kBox].get<xpu, 2, DType>(s);
+ Tensor<xpu, 4, DType> out = out_data[roipool::kOut].get<xpu, 4, DType>(s);
+ Tensor<xpu, 4, DType> max_idx = out_data[roipool::kMaxIdx].get<xpu, 4, DType>(s);
CHECK_EQ(data.CheckContiguous(), true);
CHECK_EQ(bbox.CheckContiguous(), true);
CHECK_EQ(out.CheckContiguous(), true);
@@ -88,19 +88,19 @@ class ROIPoolingOp : public Operator {
size_t expected = 2;
CHECK_EQ(in_data.size(), expected);
CHECK_EQ(out_data.size(), expected);
- CHECK_EQ(out_grad[roipool::kOut].shape_[1], in_data[roipool::kBox].shape_[1]);
- CHECK_EQ(out_data[roipool::kMaxIdx].shape_[1], in_data[roipool::kBox].shape_[1]);
+ CHECK_EQ(out_grad[roipool::kOut].shape_[0], in_data[roipool::kBox].shape_[0]);
+ CHECK_EQ(out_data[roipool::kMaxIdx].shape_[0], in_data[roipool::kBox].shape_[0]);
CHECK_NE(req[roipool::kData], kWriteInplace) <<
"ROIPooling: Backward doesn't support kWriteInplace.";
CHECK_NE(req[roipool::kBox], kWriteInplace) <<
"ROIPooling: Backward doesn't support kWriteInplace.";
Stream<xpu> *s = ctx.get_stream<xpu>();

- Tensor<xpu, 5, DType> grad_out = out_grad[roipool::kOut].get<xpu, 5, DType>(s);
- Tensor<xpu, 3, DType> bbox = in_data[roipool::kBox].get<xpu, 3, DType>(s);
- Tensor<xpu, 5, DType> max_idx = out_data[roipool::kMaxIdx].get<xpu, 5, DType>(s);
+ Tensor<xpu, 4, DType> grad_out = out_grad[roipool::kOut].get<xpu, 4, DType>(s);
+ Tensor<xpu, 2, DType> bbox = in_data[roipool::kBox].get<xpu, 2, DType>(s);
+ Tensor<xpu, 4, DType> max_idx = out_data[roipool::kMaxIdx].get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> grad_in = in_grad[roipool::kData].get<xpu, 4, DType>(s);
- Tensor<xpu, 3, DType> grad_roi = in_grad[roipool::kBox].get<xpu, 3, DType>(s);
+ Tensor<xpu, 2, DType> grad_roi = in_grad[roipool::kBox].get<xpu, 2, DType>(s);
CHECK_EQ(grad_out.CheckContiguous(), true);
CHECK_EQ(bbox.CheckContiguous(), true);
CHECK_EQ(max_idx.CheckContiguous(), true);
@@ -161,18 +161,18 @@ class ROIPoolingProp : public OperatorProperty {
TShape dshape = in_shape->at(roipool::kData);
CHECK_EQ(dshape.ndim(), 4U) << "data should be a 4D tensor";

- // bbox: [batch_size, num_rois, 5]
+ // bbox: [num_rois, 5]
TShape bshape = in_shape->at(roipool::kBox);
- CHECK_EQ(bshape.ndim(), 3U) << "bbox should be a 3D tensor of shape [batch, num_rois, 5]";
- CHECK_EQ(bshape[2], 5U) << "bbox should be a 3D tensor of shape [batch, num_rois, 5]";
+ CHECK_EQ(bshape.ndim(), 2U) << "bbox should be a 2D tensor of shape [batch, 5]";
+ CHECK_EQ(bshape[1], 5U) << "bbox should be a 2D tensor of shape [batch, 5]";

- // out: [batch_size, num_rois, c, pooled_h, pooled_w]
- // max_idx: [batch_size, num_rois, c, pooled_h, pooled_w]
+ // out: [num_rois, c, pooled_h, pooled_w]
+ // max_idx: [num_rois, c, pooled_h, pooled_w]
out_shape->clear();
out_shape->push_back(
- Shape5(dshape[0], bshape[1], dshape[1], param_.pooled_size[0], param_.pooled_size[1]));
+ Shape4(bshape[0], dshape[1], param_.pooled_size[0], param_.pooled_size[1]));
out_shape->push_back(
- Shape5(dshape[0], bshape[1], dshape[1], param_.pooled_size[0], param_.pooled_size[1]));
+ Shape4(bshape[0], dshape[1], param_.pooled_size[0], param_.pooled_size[1]));
return true;
}

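For reference, the restored shape rule can be checked in isolation: with data of shape [N, C, H, W] and rois of shape [num_rois, 5], both outputs come out as [num_rois, C, pooled_h, pooled_w]. A small standalone sketch of that computation (plain C++, not the mshadow API; the function name is illustrative):

    #include <array>
    #include <cstdio>

    // Restored InferShape rule: out and max_idx share the shape
    // [num_rois, channels, pooled_h, pooled_w].
    std::array<int, 4> RoiPoolOutShape(const std::array<int, 4> &data_shape,  // [N, C, H, W]
                                       const std::array<int, 2> &bbox_shape,  // [num_rois, 5]
                                       int pooled_h, int pooled_w) {
      return {bbox_shape[0], data_shape[1], pooled_h, pooled_w};
    }

    int main() {
      // e.g. a feature map over 2 images, 256 channels, 128 ROIs, 7x7 bins.
      auto out = RoiPoolOutShape({2, 256, 38, 50}, {128, 5}, 7, 7);
      std::printf("[%d, %d, %d, %d]\n", out[0], out[1], out[2], out[3]);  // [128, 256, 7, 7]
      return 0;
    }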
30 changes: 15 additions & 15 deletions src/operator/roi_pooling.cc
@@ -18,10 +18,10 @@ using std::ceil;

namespace mshadow {
template<typename Dtype>
- inline void ROIPoolForward(const Tensor<cpu, 5, Dtype> &out,
+ inline void ROIPoolForward(const Tensor<cpu, 4, Dtype> &out,
const Tensor<cpu, 4, Dtype> &data,
- const Tensor<cpu, 3, Dtype> &bbox,
- const Tensor<cpu, 5, Dtype> &max_idx,
+ const Tensor<cpu, 2, Dtype> &bbox,
+ const Tensor<cpu, 4, Dtype> &max_idx,
const float spatial_scale_,
const float pad_ratio_) {
const Dtype *bottom_data = data.dptr_;
@@ -31,10 +31,10 @@ inline void ROIPoolForward(const Tensor<cpu, 5, Dtype> &out,
const int channels_ = data.size(1);
const int height_ = data.size(2);
const int width_ = data.size(3);
- const int pooled_height_ = out.size(3);
- const int pooled_width_ = out.size(4);
+ const int pooled_height_ = out.size(2);
+ const int pooled_width_ = out.size(3);

- const int num_rois = bbox.size(1);
+ const int num_rois = bbox.size(0);
const int batch_size = data.size(0);
const int data_size = data.size(1) * data.size(2) * data.size(3);
// For each ROI R = [batch_index x1 y1 x2 y2]: max pool over R
@@ -101,21 +101,21 @@ inline void ROIPoolForward(const Tensor<cpu, 5, Dtype> &out,
}
// Increment all data pointers by one channel
batch_data += data.size(2) * data.size(3);
- top_data += out.size(3) * out.size(4);
- argmax_data += max_idx.size(3) * max_idx.size(4);
+ top_data += out.size(2) * out.size(3);
+ argmax_data += max_idx.size(2) * max_idx.size(3);
}
// Increment ROI data pointer
- bottom_rois += bbox.size(2);
+ bottom_rois += bbox.size(1);
}

return;
}
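The part of the forward pass elided above is the standard Fast R-CNN bin computation: scale the ROI corners by spatial_scale, split the ROI into pooled_h x pooled_w bins with floor/ceil, clamp each bin to the feature map, and max-pool inside it. A simplified single-channel sketch of that index math (illustrative only — it ignores pad_ratio and the argmax bookkeeping that max_idx records):

    #include <algorithm>
    #include <cmath>
    #include <limits>
    #include <vector>

    // Max-pool one (ph, pw) bin of one ROI on a single-channel feature map.
    // Simplified sketch of the elided loop body: no pad_ratio, no argmax.
    float PoolOneBin(const std::vector<float> &feat, int height, int width,
                     float x1, float y1, float x2, float y2,  // ROI in image coords
                     float spatial_scale, int pooled_h, int pooled_w,
                     int ph, int pw) {
      // Map ROI corners into feature-map coordinates.
      const int roi_start_w = static_cast<int>(std::round(x1 * spatial_scale));
      const int roi_start_h = static_cast<int>(std::round(y1 * spatial_scale));
      const int roi_end_w = static_cast<int>(std::round(x2 * spatial_scale));
      const int roi_end_h = static_cast<int>(std::round(y2 * spatial_scale));
      const int roi_h = std::max(roi_end_h - roi_start_h + 1, 1);
      const int roi_w = std::max(roi_end_w - roi_start_w + 1, 1);
      const float bin_h = static_cast<float>(roi_h) / pooled_h;
      const float bin_w = static_cast<float>(roi_w) / pooled_w;
      // Bin (ph, pw) covers [hstart, hend) x [wstart, wend), clamped to the map.
      const int hstart = std::min(std::max(roi_start_h + static_cast<int>(std::floor(ph * bin_h)), 0), height);
      const int hend = std::min(std::max(roi_start_h + static_cast<int>(std::ceil((ph + 1) * bin_h)), 0), height);
      const int wstart = std::min(std::max(roi_start_w + static_cast<int>(std::floor(pw * bin_w)), 0), width);
      const int wend = std::min(std::max(roi_start_w + static_cast<int>(std::ceil((pw + 1) * bin_w)), 0), width);
      const bool is_empty = (hend <= hstart) || (wend <= wstart);
      float best = is_empty ? 0.f : -std::numeric_limits<float>::max();
      for (int h = hstart; h < hend; ++h)
        for (int w = wstart; w < wend; ++w)
          best = std::max(best, feat[h * width + w]);
      return best;
    }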

template<typename Dtype>
inline void ROIPoolBackwardAcc(const Tensor<cpu, 4, Dtype> &in_grad,
- const Tensor<cpu, 5, Dtype> &out_grad,
- const Tensor<cpu, 3, Dtype> &bbox,
- const Tensor<cpu, 5, Dtype> &max_idx,
+ const Tensor<cpu, 4, Dtype> &out_grad,
+ const Tensor<cpu, 2, Dtype> &bbox,
+ const Tensor<cpu, 4, Dtype> &max_idx,
const float spatial_scale_,
const float pad_ratio_) {
const Dtype *top_diff = out_grad.dptr_;
@@ -127,10 +127,10 @@ inline void ROIPoolBackwardAcc(const Tensor<cpu, 4, Dtype> &in_grad,
const int channels_ = in_grad.size(1);
const int height_ = in_grad.size(2);
const int width_ = in_grad.size(3);
- const int pooled_height_ = out_grad.size(4);
- const int pooled_width_ = out_grad.size(5);
+ const int pooled_height_ = out_grad.size(2);
+ const int pooled_width_ = out_grad.size(3);

- const int num_rois = bbox.size(1);
+ const int num_rois = bbox.size(0);

for (int b = 0; b < batch_size_; ++b) {
for (int c = 0; c < channels_; ++c) {
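The truncated backward loop routes gradients through the stored argmax: each pooled output cell contributes its gradient to exactly the input location that won the max in the forward pass, and contributions are accumulated rather than overwritten. A minimal single-channel sketch of that routing (illustrative; it assumes max_idx holds flat h * width + w indices with -1 for empty bins, a simplification of the operator's actual bookkeeping):

    #include <cstddef>
    #include <vector>

    // Accumulate ROI-pooling gradients for one channel of one ROI.
    // grad_out and max_idx have pooled_h * pooled_w entries; grad_in has
    // height * width entries for the matching input feature map.
    void AccumulateRoiGrad(const std::vector<float> &grad_out,
                           const std::vector<int> &max_idx,
                           std::vector<float> *grad_in) {
      for (std::size_t i = 0; i < grad_out.size(); ++i) {
        const int idx = max_idx[i];
        if (idx >= 0) (*grad_in)[idx] += grad_out[i];  // accumulate, never overwrite
      }
    }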
34 changes: 17 additions & 17 deletions src/operator/roi_pooling.cu
@@ -89,10 +89,10 @@ __global__ void ROIPoolForwardKernel(const int count, const Dtype* bottom_data,
}

template<typename Dtype>
- inline void ROIPoolForward(const Tensor<gpu, 5, Dtype> &out,
+ inline void ROIPoolForward(const Tensor<gpu, 4, Dtype> &out,
const Tensor<gpu, 4, Dtype> &data,
- const Tensor<gpu, 3, Dtype> &bbox,
- const Tensor<gpu, 5, Dtype> &max_idx,
+ const Tensor<gpu, 2, Dtype> &bbox,
+ const Tensor<gpu, 4, Dtype> &max_idx,
const float spatial_scale,
const float pad_ratio) {
const Dtype *bottom_data = data.dptr_;
@@ -103,8 +103,8 @@ inline void ROIPoolForward(const Tensor<gpu, 5, Dtype> &out,
const int channels = data.size(1);
const int height = data.size(2);
const int width = data.size(3);
- const int pooled_height = out.size(3);
- const int pooled_width = out.size(4);
+ const int pooled_height = out.size(2);
+ const int pooled_width = out.size(3);
const int gridSize = (count + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
dim3 dimGrid(kMaxGridDim, (gridSize + kMaxGridDim - 1) / kMaxGridDim);
dim3 dimBlock(kMaxThreadsPerBlock);
@@ -195,22 +195,22 @@ __global__ void ROIPoolBackwardAccKernel(const int count, const Dtype* top_diff,

template<typename Dtype>
inline void ROIPoolBackwardAcc(const Tensor<gpu, 4, Dtype> &in_grad,
- const Tensor<gpu, 5, Dtype> &out_grad,
- const Tensor<gpu, 3, Dtype> &bbox,
- const Tensor<gpu, 5, Dtype> &max_idx,
+ const Tensor<gpu, 4, Dtype> &out_grad,
+ const Tensor<gpu, 2, Dtype> &bbox,
+ const Tensor<gpu, 4, Dtype> &max_idx,
const float spatial_scale,
const float pad_ratio) {
const Dtype *top_diff = out_grad.dptr_;
const Dtype *bottom_rois = bbox.dptr_;
Dtype *bottom_diff = in_grad.dptr_;
Dtype *argmax_data = max_idx.dptr_;
const int count = in_grad.shape_.Size();
- const int num_rois = bbox.size(0) * bbox.size(1);
+ const int num_rois = bbox.size(0);
const int channels = in_grad.size(1);
const int height = in_grad.size(2);
const int width = in_grad.size(3);
- const int pooled_height = out_grad.size(3);
- const int pooled_width = out_grad.size(4);
+ const int pooled_height = out_grad.size(2);
+ const int pooled_width = out_grad.size(3);
const int gridSize = (count + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
dim3 dimGrid(kMaxGridDim, (gridSize + kMaxGridDim - 1) / kMaxGridDim);
dim3 dimBlock(kMaxThreadsPerBlock);
@@ -224,20 +224,20 @@ inline void ROIPoolBackwardAcc(const Tensor<gpu, 4, Dtype> &in_grad,
} // namespace cuda

template<typename Dtype>
- inline void ROIPoolForward(const Tensor<gpu, 5, Dtype> &out,
+ inline void ROIPoolForward(const Tensor<gpu, 4, Dtype> &out,
const Tensor<gpu, 4, Dtype> &data,
- const Tensor<gpu, 3, Dtype> &bbox,
- const Tensor<gpu, 5, Dtype> &max_idx,
+ const Tensor<gpu, 2, Dtype> &bbox,
+ const Tensor<gpu, 4, Dtype> &max_idx,
const float spatial_scale,
const float pad_ratio) {
cuda::ROIPoolForward(out, data, bbox, max_idx, spatial_scale, pad_ratio);
}

template<typename Dtype>
inline void ROIPoolBackwardAcc(const Tensor<gpu, 4, Dtype> &in_grad,
- const Tensor<gpu, 5, Dtype> &out_grad,
- const Tensor<gpu, 3, Dtype> &bbox,
- const Tensor<gpu, 5, Dtype> &max_idx,
+ const Tensor<gpu, 4, Dtype> &out_grad,
+ const Tensor<gpu, 2, Dtype> &bbox,
+ const Tensor<gpu, 4, Dtype> &max_idx,
const float spatial_scale,
const float pad_ratio) {
cuda::ROIPoolBackwardAcc(in_grad, out_grad, bbox, max_idx, spatial_scale, pad_ratio);
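Both GPU paths launch one thread per element and size the grid with the ceil-division pattern visible above, folding any overflow beyond the x-dimension limit into a second grid dimension. A sketch of that arithmetic (plain C++; the two constants are placeholders standing in for the kMaxThreadsPerBlock and kMaxGridDim values used above):

    #include <cstdio>

    // Placeholder values standing in for kMaxThreadsPerBlock / kMaxGridDim.
    constexpr int kThreadsPerBlock = 1024;
    constexpr int kGridDimX = 65535;

    int main() {
      const int count = 128 * 256 * 7 * 7;  // one thread per output element
      const int grid_size = (count + kThreadsPerBlock - 1) / kThreadsPerBlock;  // ceil division
      const int grid_y = (grid_size + kGridDimX - 1) / kGridDimX;  // as in dimGrid(kGridDimX, grid_y)
      std::printf("gridSize = %d -> dimGrid(%d, %d)\n", grid_size, kGridDimX, grid_y);
      return 0;
    }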
