@@ -54,6 +54,7 @@ struct SpatialTransformerParam : public dmlc::Parameter<SpatialTransformerParam>
54
54
TShape target_shape;
55
55
int transform_type;
56
56
int sampler_type;
57
+ dmlc::optional<bool > cudnn_off;
57
58
DMLC_DECLARE_PARAMETER (SpatialTransformerParam) {
58
59
int shape[] = {0 , 0 };
59
60
DMLC_DECLARE_FIELD (target_shape).set_default (TShape (shape, shape + 2 ))
@@ -62,6 +63,8 @@ struct SpatialTransformerParam : public dmlc::Parameter<SpatialTransformerParam>
62
63
.describe (" transformation type" );
63
64
DMLC_DECLARE_FIELD (sampler_type).add_enum (" bilinear" , st::kBilinear )
64
65
.describe (" sampling type" );
66
+ DMLC_DECLARE_FIELD (cudnn_off).set_default (dmlc::optional<bool >())
67
+ .describe (" whether to turn cudnn off" );
65
68
}
66
69
};
67
70
@@ -101,11 +104,11 @@ class SpatialTransformerOp : public Operator {
101
104
}
102
105
Copy (grid_dst, workspace, grid_dst.stream_ );
103
106
for (index_t batch = 0 ; batch < data.size (0 ); batch++) {
104
- if (param_.transform_type == st::kAffine ) {
105
- // Legacy approach shown here for comparison:
106
- // grid_src[batch] = dot(loc[batch], grid_dst);
107
- linalg_gemm (loc[batch], grid_dst, grid_src[batch], false , false , s);
108
- }
107
+ if (param_.transform_type == st::kAffine ) {
108
+ // Legacy approach shown here for comparison:
109
+ // grid_src[batch] = dot(loc[batch], grid_dst);
110
+ linalg_gemm (loc[batch], grid_dst, grid_src[batch], false , false , s);
111
+ }
109
112
}
110
113
if (param_.sampler_type == st::kBilinear ) {
111
114
BilinearSamplingForward (out, data, grid_src);
@@ -136,11 +139,11 @@ class SpatialTransformerOp : public Operator {
136
139
BilinearSamplingBackward (gdata, grid_src, grad, data);
137
140
}
138
141
for (index_t batch = 0 ; batch < data.size (0 ); batch++) {
139
- if (param_.transform_type == st::kAffine ) {
140
- // Legacy approach shown here for comparison:
141
- // gloc[batch] = dot(grid_src[batch], grid_dst.T());
142
- linalg_gemm (grid_src[batch], grid_dst, gloc[batch], false , true , s);
143
- }
142
+ if (param_.transform_type == st::kAffine ) {
143
+ // Legacy approach shown here for comparison:
144
+ // gloc[batch] = dot(grid_src[batch], grid_dst.T());
145
+ linalg_gemm (grid_src[batch], grid_dst, gloc[batch], false , true , s);
146
+ }
144
147
}
145
148
}
146
149
0 commit comments