Skip to content

Commit

Permalink
Merge pull request #6659 from qingqing01/mobile_mem
Browse files Browse the repository at this point in the history
Reduce memory usage in conv layer and RoI layer for mobile inference.
  • Loading branch information
qingqing01 authored Dec 15, 2017
2 parents c907654 + 3496092 commit 480a544
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
5 changes: 5 additions & 0 deletions paddle/function/GemmConvOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ class GemmConvFunction : public ConvFunctionBase {
inputData += inputChannels * inputHeight * inputWidth;
outputData += outputChannels * outputHeight * outputWidth;
}
#ifdef PADDLE_MOBILE_INFERENCE
// Mobile inference builds free memory_ once the loop above has finished, to
// reduce the memory held by the conv layer between calls.
if (Device == DEVICE_TYPE_CPU) {
  delete memory_;
  // Clear the handle so the freed object is never dereferenced or deleted a
  // second time on a later call; `delete nullptr` is a no-op.
  // NOTE(review): assumes memory_ is a raw owning pointer, as the `delete`
  // implies. If it is actually a smart pointer, replace both statements with
  // memory_.reset() -- confirm against the member's declaration.
  memory_ = nullptr;
}
#endif
}
};

Expand Down
27 changes: 18 additions & 9 deletions paddle/gserver/layers/ROIPoolLayer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,15 @@ void ROIPoolLayer::forward(PassType passType) {
size_t poolChannelOffset = pooledHeight_ * pooledWidth_;

real* outputData = outputValue->getData();
Matrix::resizeOrCreate(maxIdxs_,
numROIs,
channels_ * pooledHeight_ * pooledWidth_,
false,
false);
real* argmaxData = maxIdxs_->getData();
real* argmaxData = nullptr;
if (passType != PASS_TEST) {
Matrix::resizeOrCreate(maxIdxs_,
numROIs,
channels_ * pooledHeight_ * pooledWidth_,
false,
false);
argmaxData = maxIdxs_->getData();
}

for (size_t n = 0; n < numROIs; ++n) {
// the first five elements of each RoI should be:
Expand Down Expand Up @@ -128,22 +131,28 @@ void ROIPoolLayer::forward(PassType passType) {
bool isEmpty = (hend <= hstart) || (wend <= wstart);
size_t poolIndex = ph * pooledWidth_ + pw;
outputData[poolIndex] = isEmpty ? 0 : -FLT_MAX;
argmaxData[poolIndex] = -1;
if (argmaxData) {
argmaxData[poolIndex] = -1;
}

for (size_t h = hstart; h < hend; ++h) {
for (size_t w = wstart; w < wend; ++w) {
size_t index = h * width_ + w;
if (batchData[index] > outputData[poolIndex]) {
outputData[poolIndex] = batchData[index];
argmaxData[poolIndex] = index;
if (argmaxData) {
argmaxData[poolIndex] = index;
}
}
}
}
}
}
batchData += channelOffset;
outputData += poolChannelOffset;
argmaxData += poolChannelOffset;
if (argmaxData) {
argmaxData += poolChannelOffset;
}
}
bottomROIs += roiOffset;
}
Expand Down

0 comments on commit 480a544

Please sign in to comment.