This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit f8ed533

haojin2 authored and eric-haibin-lin committed
add cudnn_off parameter to SpatialTransformer Op and fix the inconsistency between CPU & GPU code (#12557)
1 parent e213286 commit f8ed533
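
For context, a minimal sketch of how the new flag is used from the Python symbol API. The symbol names and shapes below are illustrative (a simplified stand-in for the localisation network in the updated test), not part of the commit:

import mxnet as mx

data = mx.sym.Variable('data')
loc = mx.sym.FullyConnected(data=data, num_hidden=6)  # 6 affine parameters
# cudnn_off=True forces the generic GPU implementation; leaving the flag
# unset (the default) or passing False keeps the cuDNN path when available.
st = mx.sym.SpatialTransformer(data=data, loc=loc, target_shape=(10, 10),
                               transform_type="affine", sampler_type="bilinear",
                               cudnn_off=True)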

File tree: 3 files changed (+24 -13 lines)


src/operator/spatial_transformer-inl.h (+13 -10)
@@ -54,6 +54,7 @@ struct SpatialTransformerParam : public dmlc::Parameter<SpatialTransformerParam>
   TShape target_shape;
   int transform_type;
   int sampler_type;
+  dmlc::optional<bool> cudnn_off;
   DMLC_DECLARE_PARAMETER(SpatialTransformerParam) {
     int shape[] = {0, 0};
     DMLC_DECLARE_FIELD(target_shape).set_default(TShape(shape, shape + 2))
@@ -62,6 +63,8 @@ struct SpatialTransformerParam : public dmlc::Parameter<SpatialTransformerParam>
     .describe("transformation type");
     DMLC_DECLARE_FIELD(sampler_type).add_enum("bilinear", st::kBilinear)
     .describe("sampling type");
+    DMLC_DECLARE_FIELD(cudnn_off).set_default(dmlc::optional<bool>())
+    .describe("whether to turn cudnn off");
   }
 };

@@ -101,11 +104,11 @@ class SpatialTransformerOp : public Operator {
     }
     Copy(grid_dst, workspace, grid_dst.stream_);
     for (index_t batch = 0; batch < data.size(0); batch++) {
-       if (param_.transform_type == st::kAffine) {
-         // Legacy approach shown here for comparison:
-         // grid_src[batch] = dot(loc[batch], grid_dst);
-         linalg_gemm(loc[batch], grid_dst, grid_src[batch], false, false, s);
-       }
+      if (param_.transform_type == st::kAffine) {
+        // Legacy approach shown here for comparison:
+        // grid_src[batch] = dot(loc[batch], grid_dst);
+        linalg_gemm(loc[batch], grid_dst, grid_src[batch], false, false, s);
+      }
     }
     if (param_.sampler_type == st::kBilinear) {
       BilinearSamplingForward(out, data, grid_src);
@@ -136,11 +139,11 @@ class SpatialTransformerOp : public Operator {
       BilinearSamplingBackward(gdata, grid_src, grad, data);
     }
     for (index_t batch = 0; batch < data.size(0); batch++) {
-       if (param_.transform_type == st::kAffine) {
-         // Legacy approach shown here for comparison:
-         // gloc[batch] = dot(grid_src[batch], grid_dst.T());
-         linalg_gemm(grid_src[batch], grid_dst, gloc[batch], false, true, s);
-       }
+      if (param_.transform_type == st::kAffine) {
+        // Legacy approach shown here for comparison:
+        // gloc[batch] = dot(grid_src[batch], grid_dst.T());
+        linalg_gemm(grid_src[batch], grid_dst, gloc[batch], false, true, s);
+      }
     }
   }
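The per-batch GEMM above is the affine grid generation step. A rough NumPy illustration of what grid_src[batch] = dot(loc[batch], grid_dst) computes; the (x, y, 1) row layout of the homogeneous target grid and the example affine parameters are assumptions for illustration:

import numpy as np

h, w = 10, 10
# Assumed: loc_b is one batch's 2x3 affine matrix (the localisation
# network's 6 outputs); here an identity transform plus a small shift.
loc_b = np.array([[1.0, 0.0, 0.1],
                  [0.0, 1.0, -0.2]])
ys, xs = np.meshgrid(np.linspace(-1, 1, h), np.linspace(-1, 1, w), indexing='ij')
grid_dst = np.stack([xs.ravel(), ys.ravel(), np.ones(h * w)])  # (3, h*w)
grid_src_b = loc_b @ grid_dst  # (2, h*w) source sampling coordinates
assert grid_src_b.shape == (2, h * w)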
src/operator/spatial_transformer.cu (+6 -1)
@@ -121,6 +121,7 @@ __global__ void BilinearSamplingBackwardKernel(const int i_c, const int i_h,
       if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
         atomicAdd((g_input + data_index + i_w),
                   *(grad + grad_index) * (1.0 - top_left_y_w) * top_left_x_w);
+        bottom_left_v = *(data + data_index + i_w);
       }
       if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
         atomicAdd((g_input + data_index + i_w + 1),
@@ -194,7 +195,11 @@ Operator* CreateOp<gpu>(SpatialTransformerParam param, int dtype) {
   Operator *op = NULL;
 #if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
   MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-    op = new CuDNNSpatialTransformerOp<DType>(param);
+    if (param.cudnn_off.has_value() && param.cudnn_off.value()) {
+      op = new SpatialTransformerOp<gpu, DType>(param);
+    } else {
+      op = new CuDNNSpatialTransformerOp<DType>(param);
+    }
   })
 #else
   MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {

tests/python/gpu/test_operator_gpu.py (+5 -2)
@@ -749,7 +749,6 @@ def test_grid_generator_with_type():
     check_consistency(sym, ctx_list, grad_req="add")


-@unittest.skip("test fails intermittently. temporarily disabled till it gets fixed. https://github.com/apache/incubator-mxnet/issues/11839")
 @with_seed()
 def test_spatial_transformer_with_type():
     data = mx.sym.Variable('data')
@@ -758,11 +757,15 @@ def test_spatial_transformer_with_type():
     loc = mx.sym.Activation(data=loc, act_type='relu')
     loc = mx.sym.FullyConnected(data=loc, num_hidden=6)
     sym = mx.sym.SpatialTransformer(data=data, loc=loc, target_shape=(10, 10),
-                                    transform_type="affine", sampler_type="bilinear")
+                                    transform_type="affine", sampler_type="bilinear", cudnn_off=True)
     ctx_list = [{'ctx': mx.gpu(0), 'data': (1, 5, 10, 10), 'type_dict': {'data': np.float64}},
                 {'ctx': mx.cpu(0), 'data': (1, 5, 10, 10), 'type_dict': {'data': np.float64}}]
     check_consistency(sym, ctx_list)
     check_consistency(sym, ctx_list, grad_req="add")
+    sym = mx.sym.SpatialTransformer(data=data, loc=loc, target_shape=(10, 10),
+                                    transform_type="affine", sampler_type="bilinear", cudnn_off=False)
+    check_consistency(sym, ctx_list)
+    check_consistency(sym, ctx_list, grad_req="add")


 @with_seed()

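With the flaky-test skip removed, the test now runs the CPU/GPU consistency check twice, once per cudnn_off setting. On a GPU host it can presumably be invoked with nose, MXNet's test runner at the time, e.g. nosetests tests/python/gpu/test_operator_gpu.py:test_spatial_transformer_with_type.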