batchnorm specify channel axis and performance optimizations for batchnorm #6411
Changes from 3 commits
@@ -32,13 +32,16 @@ enum BatchNormOpOutputs {kOut, kMean, kVar};  // req, out_data
 enum BatchNormOpAuxiliary {kMovingMean, kMovingVar};  // aux_states
 }  // namespace batchnorm
 
+constexpr int DEFAULT_CHANNEL_AXIS = 1;
+
 /*! \brief Parameters for BatchNorm operator */
 struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
   float eps;
   float momentum;
   bool fix_gamma;
   bool use_global_stats;
   bool output_mean_var;
+  int channel_axis;
[Review thread on `int channel_axis;`]
Reviewer: channel_axis -> axis
Author: I don't understand this comment.
Reviewer: Rename the argument to `axis`.
Author: done
Reviewer: rename to axis
Author: ok
   bool cudnn_off;
   DMLC_DECLARE_PARAMETER(BatchNormParam) {
     DMLC_DECLARE_FIELD(eps).set_default(1e-3f)
@@ -54,6 +57,8 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
               "This will force change batch-norm into a scale shift operator.");
     DMLC_DECLARE_FIELD(output_mean_var).set_default(false)
     .describe("Output All,normal mean and var");
+    DMLC_DECLARE_FIELD(channel_axis).set_default(DEFAULT_CHANNEL_AXIS)
+      .describe("Specify which shape axis the channel is specified");
     DMLC_DECLARE_FIELD(cudnn_off).set_default(false)
     .describe("Do not select CUDNN operator, if available");
   }
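As a usage sketch only (assuming dmlc::Parameter's standard string-kwargs Init interface, and with the header above available), the new field would be set like any other operator parameter; the value "-1" here is an illustrative choice for channel-last data:

// Sketch: setting the new field through dmlc's kwargs interface (value illustrative).
BatchNormParam param;
param.Init(std::vector<std::pair<std::string, std::string>>{
    {"channel_axis", "-1"}});  // -1 selects the trailing (last) axis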
@@ -207,21 +212,26 @@ class BatchNormProp : public OperatorProperty {
     CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
     const TShape &dshape = in_shape->at(0);
 
+    CHECK_GE(param_.channel_axis, -1) << "Invalid channel axis: " << param_.channel_axis;
[Review thread on the CHECK_GE line]
Reviewer: Use proper parsing: axis = axis + ndim()
Author: I don't understand this comment.
Reviewer: axis = -1 means the first dim from the right, axis = -2 means the second dim from the right. You are only special-casing -1 here. We need to support all of them.
Author: Oh, I see. I wasn't aware of this requirement. I was under the impression it was like some other system I was looking at, where there's a "first" flag and a "last" flag only (Torch, maybe?) -- 0 would be first, -1 would be last. Ok, I will change.
Author: done
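A minimal sketch of the general rule the reviewer is asking for (helper name hypothetical, not part of the PR): a negative axis counts from the right, so it is normalized by adding ndim before bounds-checking:

// Hypothetical helper illustrating the reviewer's suggestion; not the PR's code.
inline int NormalizeAxis(const int axis, const int ndim) {
  const int real_axis = axis < 0 ? axis + ndim : axis;  // -1 -> ndim-1, -2 -> ndim-2, ...
  CHECK(real_axis >= 0 && real_axis < ndim) << "Invalid channel axis: " << axis;
  return real_axis;
}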
+
+    const int channelCount = param_.channel_axis == -1
+      ? dshape[dshape.ndim() - 1] : dshape[param_.channel_axis];
+
     if (dshape.ndim() == 0) {
       return false;
     }
 
-    in_shape->at(1) = TShape(Shape1(dshape[1]));
-    in_shape->at(2) = TShape(Shape1(dshape[1]));
+    in_shape->at(1) = TShape(Shape1(channelCount));
+    in_shape->at(2) = TShape(Shape1(channelCount));
 
     out_shape->clear();
     out_shape->push_back(dshape);                // kOut
-    out_shape->push_back(Shape1(dshape[1]));     // kMean
-    out_shape->push_back(Shape1(dshape[1]));     // kVar
+    out_shape->push_back(Shape1(channelCount));  // kMean
+    out_shape->push_back(Shape1(channelCount));  // kVar
 
     aux_shape->clear();
-    aux_shape->push_back(Shape1(dshape[1]));     // kMovingMean
-    aux_shape->push_back(Shape1(dshape[1]));     // kMovingVar
+    aux_shape->push_back(Shape1(channelCount));  // kMovingMean
+    aux_shape->push_back(Shape1(channelCount));  // kMovingVar
     return true;
   }
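For a concrete illustration (standalone sketch, hypothetical shape values): with an NHWC input of shape (8, 32, 32, 64) and channel_axis = -1, channelCount resolves to 64, so gamma, beta, kMean, kVar, kMovingMean, and kMovingVar all get shape (64,) rather than the old hard-coded dshape[1]:

#include <cstdio>
#include <vector>

int main() {
  const std::vector<int> dshape = {8, 32, 32, 64};  // NHWC, hypothetical values
  const int channel_axis = -1;
  // Mirrors the hunk above: -1 selects the last axis, otherwise index directly.
  const int channelCount = channel_axis == -1
      ? dshape[dshape.size() - 1] : dshape[channel_axis];
  std::printf("channelCount = %d\n", channelCount);  // prints: channelCount = 64
  return 0;
}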
@@ -329,6 +339,128 @@ class BatchNormProp : public OperatorProperty {
   BatchNormParam param_;
 };  // class BatchNormProp
 
+namespace batchnorm {
+
+#if defined(__CUDACC__)
+#define __bn_hostonly__ __host__
+#define __bn_hostdevinl__ __host__ __device__ __forceinline__
[Review thread on the macro definitions]
Reviewer: These macros already exist in mshadow.
Author: ok
+#define __bn_localinline__ __forceinline__
+#else
+#define __bn_hostonly__
+#define __bn_hostdevinl__ inline
+#define __bn_localinline__ inline
+#endif
+
+template<typename DType>
+class BNTensor3 {
[Review thread on class BNTensor3]
Reviewer: Why define this instead of using mshadow::Tensor<xpu, 3>?
Author: You told me not to use mshadow::Tensor, so I've never inspected it.
Reviewer: Did I? What was the reason? I don't remember.
Author: I also don't see how mshadow::Tensor<xpu, 3> knows about the channel axis, and along with my member functions, mshadow::Tensor doesn't offer any added value that I can see.
Reviewer: I think I was referring to "don't use mshadow's template evaluation features", like Tensor x; x = 1;. But you can leave it like this for now.
Author: ok
+  enum { OUTER, CHANNEL, INNER, COUNT };
+
+ public:
+  __bn_hostonly__ inline BNTensor3(const TBlob& blob, const int indexOfChannel)
+    : dptr_(blob.dptr<DType>())
+      , indexOfChannel_(indexOfChannel == -1 ? (blob.shape_.ndim() - 1) : indexOfChannel) {
+    shape_[OUTER] = 1;
+    for (size_t i = 0; i < indexOfChannel_; ++i) {
+      shape_[OUTER] *= blob.shape_[i];
+    }
+    shape_[CHANNEL] = blob.shape_[indexOfChannel_];
+    shape_[INNER] = 1;
+    for (size_t i = indexOfChannel_ + 1, n = blob.shape_.ndim(); i < n; ++i) {
+      shape_[INNER] *= blob.shape_[i];
+    }
+  }
+
+  __bn_hostonly__ inline BNTensor3(DType *p, const TShape& shape, const int indexOfChannel)
+    : dptr_(p)
+      , indexOfChannel_(indexOfChannel == -1 ? (shape.ndim() - 1) : indexOfChannel) {
+    shape_[OUTER] = 1;
+    for (size_t i = 0; i < indexOfChannel_; ++i) {
+      shape_[OUTER] *= shape[i];
+    }
+    shape_[CHANNEL] = shape[indexOfChannel_];
+    shape_[INNER] = 1;
+    for (size_t i = indexOfChannel_ + 1, n = shape.ndim(); i < n; ++i) {
+      shape_[INNER] *= shape[i];
+    }
+  }
+
+  __bn_localinline__ bool IsEmpty() const {
+    return dptr_ == nullptr;
+  }
+
+  __bn_hostdevinl__ size_t Size() const {
+    size_t n = 1;
+    for (int i = 0; i < COUNT; ++i) {
+      n *= shape_[i];
+    }
+    return n;
+  }
+
+  __bn_hostdevinl__ size_t ChannelCount() const {
+    return shape_[CHANNEL];
+  }
+
+  __bn_hostdevinl__ size_t OuterSize() const {
+    return shape_[OUTER];
+  }
+
+  __bn_hostdevinl__ size_t InnerSize() const {
+    return shape_[INNER];
+  }
+
+  /*! \brief start of a given channel's spatial data */
+  __bn_hostdevinl__ size_t StartOffset(const size_t channel) const {
+    return channel * InnerSize();
+  }
+
+  /*! \brief This is the amount to skip to the next same-channel data.
+   * This is the number of elements to skip from one past the end of the current spatial data
+   * to the next start of the same channel's "spatial data".
+   * It is assumed that the pointer being calculated points just beyond the
+   * end of the last block of spatial data,
+   * i.e. RGBRGB <-- 2
+   *      RRGGBB <-- 4
+   **/
+  __bn_hostdevinl__ size_t SkipLengthToNextSameChannelData() const {
+    return (ChannelCount() - 1) * InnerSize();
+  }
+
+  __bn_hostdevinl__ size_t offset(const size_t outer,
+                                  const size_t channel,
+                                  const size_t i) const {
+    const size_t spatial_size = InnerSize();
+    const size_t skip_length = SkipLengthToNextSameChannelData();
+    size_t off = StartOffset(channel);
+    off += outer * shape_[CHANNEL] * shape_[INNER];
+    const size_t skips = i / spatial_size;
+    off += (1 + skip_length) * skips;
+    off += i % spatial_size;
+    return off;
+  }
+
+  __bn_hostdevinl__ DType& get_ref(const size_t batch,
+                                   const size_t channel,
+                                   const size_t i) {
+    const size_t off = offset(batch, channel, i);
+    return dptr_[off];
+  }
+
+  __bn_hostdevinl__ const DType& get_ref(const size_t batch,
+                                         const size_t channel,
+                                         const size_t i) const {
+    const size_t off = offset(batch, channel, i);
+    return dptr_[off];
+  }
+
+  DType *dptr_;
+  size_t indexOfChannel_;
+  size_t shape_[COUNT];
+};
+
+extern volatile bool disable_mkl;
+
+}  // namespace batchnorm
+
 #endif  // DMLC_USE_CXX11
 }  // namespace op
 }  // namespace mxnet
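For illustration, a standalone sketch (not the PR's code; shape values hypothetical) mirroring BNTensor3's shape collapse: everything before the channel axis folds into OUTER and everything after into INNER, which reproduces the "RGBRGB <-- 2" (interleaved) and "RRGGBB <-- 4" (planar) skip lengths from the comment above:

#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical helper replicating the constructor's outer/channel/inner logic.
static void CollapseShape(const std::vector<size_t>& shape, int channel_axis,
                          size_t* outer, size_t* channel, size_t* inner) {
  const int ndim = static_cast<int>(shape.size());
  const int axis = channel_axis < 0 ? channel_axis + ndim : channel_axis;  // general negative-axis rule
  assert(axis >= 0 && axis < ndim);
  *outer = 1;
  for (int i = 0; i < axis; ++i) *outer *= shape[i];
  *channel = shape[axis];
  *inner = 1;
  for (int i = axis + 1; i < ndim; ++i) *inner *= shape[i];
}

int main() {
  size_t o, c, in;
  CollapseShape({2, 3, 4, 5}, 1, &o, &c, &in);   // NCHW: planar, "RRGGBB"-style
  std::printf("NCHW: outer=%zu channel=%zu inner=%zu skip=%zu\n", o, c, in, (c - 1) * in);
  CollapseShape({2, 4, 5, 3}, -1, &o, &c, &in);  // NHWC: interleaved, "RGBRGB"-style
  std::printf("NHWC: outer=%zu channel=%zu inner=%zu skip=%zu\n", o, c, in, (c - 1) * in);
  return 0;
}

For the NHWC case the inner size collapses to 1, so the skip to the next same-channel element is ChannelCount() - 1; for NCHW the skip spans the other channels' full spatial blocks.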
[Review thread]
Reviewer: This should go into the batchnorm namespace.
Author: ok