Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-517] add sample ratio for ROI Align (#11145)
Browse files Browse the repository at this point in the history
* add sample ratio

* pylint

* increase size limit for bilinearup

* add test case

* fix typo

* rm comments and cpu back
  • Loading branch information
zhanghang1989 authored and zhreshold committed Jun 29, 2018
1 parent ebfc16e commit e892301
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 31 deletions.
4 changes: 2 additions & 2 deletions src/operator/contrib/bilinear_resize-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ struct BilinearSampleParam : public dmlc::Parameter<BilinearSampleParam> {
int height;
int width;
DMLC_DECLARE_PARAMETER(BilinearSampleParam) {
DMLC_DECLARE_FIELD(height).set_range(1, 1000)
DMLC_DECLARE_FIELD(height).set_range(1, 10000)
.describe("output height (required)");
DMLC_DECLARE_FIELD(width).set_range(1, 1000)
DMLC_DECLARE_FIELD(width).set_range(1, 10000)
.describe("output width (required)");
}
};
Expand Down
3 changes: 3 additions & 0 deletions src/operator/contrib/roi_align-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,16 @@ enum ROIAlignOpOutputs {kOut};
struct ROIAlignParam : public dmlc::Parameter<ROIAlignParam> {
TShape pooled_size;
float spatial_scale;
int sample_ratio;
DMLC_DECLARE_PARAMETER(ROIAlignParam) {
DMLC_DECLARE_FIELD(pooled_size)
.set_expect_ndim(2).enforce_nonzero()
.describe("ROI Align output roi feature map height and width: (h, w)");
DMLC_DECLARE_FIELD(spatial_scale).set_range(0.0, 1.0)
.describe("Ratio of input feature map height (or w) to raw image height (or w). "
"Equals the reciprocal of total stride in convolutional layers");
DMLC_DECLARE_FIELD(sample_ratio).set_default(-1)
.describe("Optional sampling ratio of ROI align, using adaptive size by default.");
}
};

Expand Down
6 changes: 3 additions & 3 deletions src/operator/contrib/roi_align.cc
Original file line number Diff line number Diff line change
Expand Up @@ -440,8 +440,8 @@ void ROIAlignForwardCompute(const nnvm::NodeAttrs& attrs,
DType *top_data = out_data[roialign::kOut].dptr<DType>();

ROIAlignForward<DType>(count, bottom_data, param.spatial_scale, channels,
height, width, pooled_height, pooled_width, -1, bottom_rois,
rois_cols, top_data);
height, width, pooled_height, pooled_width, param.sample_ratio,
bottom_rois, rois_cols, top_data);
})
}

Expand Down Expand Up @@ -490,7 +490,7 @@ void ROIAlignBackwardCompute(const nnvm::NodeAttrs& attrs,
}
ROIAlignBackward<DType>(count, top_diff, num_rois, param.spatial_scale,
channels, height, width, pooled_height, pooled_width,
-1, grad_in, bottom_rois, rois_cols);
param.sample_ratio, grad_in, bottom_rois, rois_cols);
}
if (kWriteTo == req[roialign::kBox]) {
Fill<false>(s, outputs[1], kWriteTo, static_cast<DType>(0));
Expand Down
21 changes: 2 additions & 19 deletions src/operator/contrib/roi_align.cu
Original file line number Diff line number Diff line change
Expand Up @@ -231,13 +231,6 @@ __device__ void bilinear_interpolate_gradient(
T lx = x - *x_low;
T hy = 1. - ly, hx = 1. - lx;

// reference in forward
// T v1 = bottom_data[*y_low * width + *x_low];
// T v2 = bottom_data[*y_low * width + *x_high];
// T v3 = bottom_data[*y_high * width + *x_low];
// T v4 = bottom_data[*y_high * width + *x_high];
// T val = (w1 * v1 + *w2 * v2 + *w3 * v3 + *w4 * v4);

*w1 = hy * hx, *w2 = hy * lx, *w3 = ly * hx, *w4 = ly * lx;

return;
Expand Down Expand Up @@ -341,16 +334,6 @@ __global__ void RoIAlignBackwardKernel(
offset_bottom_diff + y_high * width + x_low, static_cast<T>(g3));
atomicAdd(
offset_bottom_diff + y_high * width + x_high, static_cast<T>(g4));
/*
gpu_atomic_add(
static_cast<T>(g1), offset_bottom_diff + y_low * width + x_low);
gpu_atomic_add(
static_cast<T>(g2), offset_bottom_diff + y_low * width + x_high);
gpu_atomic_add(
static_cast<T>(g3), offset_bottom_diff + y_high * width + x_low);
gpu_atomic_add(
static_cast<T>(g4), offset_bottom_diff + y_high * width + x_high);
*/
} // if
} // ix
} // iy
Expand Down Expand Up @@ -399,7 +382,7 @@ void ROIAlignForwardCompute(const nnvm::NodeAttrs& attrs,
width,
pooled_height,
pooled_width,
-1,
param.sample_ratio,
bottom_rois,
top_data);
})
Expand Down Expand Up @@ -467,7 +450,7 @@ void ROIAlignBackwardCompute(const nnvm::NodeAttrs& attrs,
width,
pooled_height,
pooled_width,
-1,
param.sample_ratio,
grad_in,
bottom_rois);
})
Expand Down
16 changes: 9 additions & 7 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6376,7 +6376,7 @@ def roialign_forward_backward(data, rois, pooled_size, spatial_scale, sampling_r
out[r, c, ph, pw] = val * 1.0 / count
return out, [dx, drois]

def test_roi_align_value():
def test_roi_align_value(sampling_ratio=0):
ctx=default_context()
dtype = np.float32

Expand All @@ -6387,7 +6387,6 @@ def test_roi_align_value():
pooled_size = (3, 4)

spatial_scale = H * 1.0 / dlen
sampling_ratio = 0
data = mx.nd.array(np.arange(N*C*W*H).reshape((N,C,H,W)), ctx=ctx, dtype = dtype)
# data = mx.nd.random.uniform(0, 1, (N, C, H, W), dtype = dtype)
center_xy = mx.nd.random.uniform(0, dlen, (R, 2), ctx=ctx, dtype = dtype)
Expand All @@ -6400,21 +6399,23 @@ def test_roi_align_value():
rois.attach_grad()
with mx.autograd.record():
output = mx.nd.contrib.ROIAlign(data, rois, pooled_size=pooled_size,
spatial_scale=spatial_scale)
spatial_scale=spatial_scale, sample_ratio=sampling_ratio)
dy = mx.nd.random.uniform(-1, 1, (R, C) + pooled_size, ctx=ctx, dtype = dtype)
output.backward(dy)
real_output, [dx, drois] = roialign_forward_backward(data.asnumpy(), rois.asnumpy(), pooled_size, spatial_scale, sampling_ratio, dy.asnumpy())
real_output, [dx, drois] = roialign_forward_backward(data.asnumpy(), rois.asnumpy(), pooled_size,
spatial_scale, sampling_ratio, dy.asnumpy())
assert np.allclose(output.asnumpy(), real_output)
# It seems that the precision between Cfloat and Pyfloat is different.
assert np.allclose(data.grad.asnumpy(), dx, atol = 1e-5), np.abs(data.grad.asnumpy() - dx).max()
assert np.allclose(rois.grad.asnumpy(), drois)

# modified from test_roipooling()
def test_roi_align_autograd():
ctx=default_context()
def test_roi_align_autograd(sampling_ratio=0):
ctx = default_context()
data = mx.symbol.Variable(name='data')
rois = mx.symbol.Variable(name='rois')
test = mx.symbol.contrib.ROIAlign(data=data, rois=rois, pooled_size=(4, 4), spatial_scale=1)
test = mx.symbol.contrib.ROIAlign(data=data, rois=rois, pooled_size=(4, 4), spatial_scale=1,
sample_ratio=sampling_ratio)

x1 = np.random.rand(4, 1, 12, 12).astype('float64')
x2 = np.array([[0, 1.1, 1.1, 6.2, 6.2], [2, 6.1, 2.1, 8.2, 11.2],
Expand All @@ -6428,6 +6429,7 @@ def test_roi_align_autograd():
numeric_eps=1e-4, rtol=1e-1, atol=1e-4, ctx=ctx)

test_roi_align_value()
test_roi_align_value(2)
test_roi_align_autograd()


Expand Down

0 comments on commit e892301

Please sign in to comment.