This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

batchnorm specify channel axis and performance optimizations for batchnorm #6411

Status: Merged (57 commits, Jun 5, 2017)
The diff below shows changes from 3 of the 57 commits.

Commits:
- 8751b1c Add channel_axis to batch norm, performance improvements (May 23, 2017)
- eff7fe7 rearrange tests a bit (May 23, 2017)
- c1b7bf5 rearrange tests a bit (May 23, 2017)
- cca8d09 CR changes (May 23, 2017)
- 8af48f9 cpp package link issue (May 23, 2017)
- 91748ab Fix: MSVC wants all parallel omp to be int (May 24, 2017)
- bfdbd37 CR comments, expand legal negative axes (May 24, 2017)
- 624b75c lint (May 24, 2017)
- 2043260 Merge branch 'master' into channelaxis_pr (cjolivier01, May 24, 2017)
- c885f7e lint (May 24, 2017)
- 13345c8 Merge branch 'master' into channelaxis_pr (cjolivier01, May 24, 2017)
- 71c67ff Merge branch 'master' into channelaxis_pr (cjolivier01, May 24, 2017)
- d90e89b Merge branch 'master' into channelaxis_pr (cjolivier01, May 24, 2017)
- 5d01f7a Fix download link (#6431) (kevinthesun, May 24, 2017)
- 2d5f3ec Add release note (#6434) (kevinthesun, May 25, 2017)
- 5136af5 Fixing tutorials. (#6436) (pracheer, May 25, 2017)
- dd68848 Formatting fixes (#6433) (Roshrini, May 25, 2017)
- b183142 doc bash 2-5, for pack, unpack, pack_img and unpack_img (#6140) (jiayue666, May 25, 2017)
- 8d2b046 fixing the early stop for maximize = T (#5915) (May 25, 2017)
- 8b8b0c4 Improve style (#6445) (kevinthesun, May 25, 2017)
- ddc80b6 Correction (#6444) (Roshrini, May 25, 2017)
- 2277640 Update documentation for MXNetDataIter in io.py (#6000) (#6113) (danithaca, May 25, 2017)
- f58b5b4 fix member variable name: make them end with underline (#6438) (vsooda, May 25, 2017)
- 9d3db06 Fix minor issues with api pages. (#6410) (pracheer, May 25, 2017)
- 75d5107 Update documentation for mxnet.ndarray.GridGenerator. (#6430) (indhub, May 25, 2017)
- 8431980 Update documentation for deconvolution operation. (#6184) (indhub, May 26, 2017)
- 8f5e46d skip lines that have %matplotlib (#6451) (nswamy, May 26, 2017)
- 66ec2fc Fixing some more broken links before v0.10 release (#6449) (sandeep-krishnamurthy, May 26, 2017)
- 46beb61 close #4838 (#6452) (thirdwing, May 26, 2017)
- ae80135 Fix linear regression (#6432) (kevinthesun, May 26, 2017)
- 7975c92 Pre-trained model tutorial fixes. (#6453) (pracheer, May 26, 2017)
- 1d5783d Nightly test tutorial (#6447) (kevinthesun, May 26, 2017)
- d9be1c2 [R] captcha example (#6443) (thirdwing, May 26, 2017)
- 345301d skip lines that have %matplotlib (#6459) (nswamy, May 26, 2017)
- 78e0653 Fix cudnn_deconv not guarding no_bias (#6456) (reminisce, May 26, 2017)
- d41bf5f fix batchNorm cpp example (#6454) (vsooda, May 26, 2017)
- b90ac59 Fixing up issues in install guide (#6463) (sandeep-krishnamurthy, May 26, 2017)
- 266bb33 Fixing copy code functionality for bash command (#6465) (sandeep-krishnamurthy, May 26, 2017)
- 62637cd Residual unroll (#6397) (szha, May 26, 2017)
- 795e892 Linear regression Tutorial link (#6468) (pracheer, May 26, 2017)
- 81b1483 bump up version number for release (#6462) (szha, May 26, 2017)
- c1281b4 [R][DOC] update R installation guide (#6457) (thirdwing, May 26, 2017)
- 7a006b2 Use sphinx==1.3.5 in Dockerfile.doc (#6470) (nswamy, May 27, 2017)
- 10cca47 Add 0.10 release info to README.md and NEWS.md (#6471) (nswamy, May 27, 2017)
- 803b00c Update im2rec.py (#6473) (wenxuanxie, May 27, 2017)
- 747f00b Change Interface of NDArray & TBlob for DLPack Compatible (#6345) (ZihengJiang, May 30, 2017)
- 47b6876 change 'channel_axis' parameter to 'axis' (May 30, 2017)
- a93c581 Merge branch 'master' into channelaxis_pr (cjolivier01, May 30, 2017)
- cb6d604 Change DEFAULT_CHANNEL_AXIS to DEFAULT_AXIS (May 30, 2017)
- 483e9a8 wait for dlpack PR to go through (May 30, 2017)
- 333724e Merge branch 'master' into channelaxis_pr (cjolivier01, May 31, 2017)
- ff6fe48 Trigger build (May 31, 2017)
- c128d15 Merge branch 'master' into channelaxis_pr (cjolivier01, May 31, 2017)
- 4761597 Merge branch 'master' into channelaxis_pr (cjolivier01, Jun 1, 2017)
- 7547385 Merge branch 'master' into channelaxis_pr (cjolivier01, Jun 2, 2017)
- 95b0766 Merge branch 'master' into channelaxis_pr (cjolivier01, Jun 2, 2017)
- 4365243 Merge branch 'master' into channelaxis_pr (piiswrong, Jun 5, 2017)
6 changes: 3 additions & 3 deletions src/common/cuda_utils.h
@@ -9,8 +9,6 @@
#include <dmlc/logging.h>
#include <mshadow/base.h>

-#if MXNET_USE_CUDA
-
/*! \brief Macros/inlines to assist CLion to parse Cuda files (*.cu, *.cuh) */
#ifdef __JETBRAINS_IDE__
#define __CUDACC__ 1
@@ -22,12 +20,14 @@
inline void __syncthreads() {}
inline void __threadfence_block() {}
template<class T> inline T __clz(const T val) { return val; }
-struct __cuda_fake_struct { int x; int y; };
+struct __cuda_fake_struct { int x; int y; int z; };
extern __cuda_fake_struct blockDim;
extern __cuda_fake_struct threadIdx;
extern __cuda_fake_struct blockIdx;
#endif

+#if MXNET_USE_CUDA
+
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <curand.h>
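The change moves the CLion helper block outside the MXNET_USE_CUDA guard and adds a z member to the fake index structs. The z field matters because kernels read the third component of the built-in index variables; a minimal, hypothetical kernel (illustrative only, not from this PR) that the IDE must be able to parse:

// Hypothetical 3-D kernel: with the old two-member fake struct, CLion would
// flag threadIdx.z / blockIdx.z / blockDim.z as unknown members.
__global__ void scale3d(float *data, int nx, int ny, int nz) {
  const int x = blockIdx.x * blockDim.x + threadIdx.x;
  const int y = blockIdx.y * blockDim.y + threadIdx.y;
  const int z = blockIdx.z * blockDim.z + threadIdx.z;  // needs the new 'z' field
  if (x < nx && y < ny && z < nz) {
    data[(z * ny + y) * nx + x] *= 2.0f;  // example work: scale every element
  }
}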
144 changes: 138 additions & 6 deletions src/operator/batch_norm-inl.h
@@ -32,13 +32,16 @@ enum BatchNormOpOutputs {kOut, kMean, kVar};  // req, out_data
enum BatchNormOpAuxiliary {kMovingMean, kMovingVar}; // aux_states
} // namespace batchnorm

+constexpr int DEFAULT_CHANNEL_AXIS = 1;

Contributor: This should go into the batchnorm namespace.
Author (cjolivier01): ok

/*! \brief Parameters for BatchNorm operator */
struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
float eps;
float momentum;
bool fix_gamma;
bool use_global_stats;
bool output_mean_var;
+int channel_axis;

Contributor: channel_axis -> axis
Author (cjolivier01, May 23, 2017): I don't understand this comment
Contributor: rename the argument to axis
Author: done
Contributor: rename to axis
Author: ok

bool cudnn_off;
DMLC_DECLARE_PARAMETER(BatchNormParam) {
DMLC_DECLARE_FIELD(eps).set_default(1e-3f)
@@ -54,6 +57,8 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
"This will force change batch-norm into a scale shift operator.");
DMLC_DECLARE_FIELD(output_mean_var).set_default(false)
.describe("Output All,normal mean and var");
+DMLC_DECLARE_FIELD(channel_axis).set_default(DEFAULT_CHANNEL_AXIS)
+.describe("Specify which shape axis the channel is specified on");
DMLC_DECLARE_FIELD(cudnn_off).set_default(false)
.describe("Do not select CUDNN operator, if available");
}
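For context, dmlc parameter structs like this one are filled from string key/value pairs. A hedged sketch of setting the new field (assumes the standard dmlc-core Init(kwargs) interface and this header; the wrapper function name is illustrative):

#include <string>
#include <utility>
#include <vector>

void InitBatchNormParamExample() {
  // channel_axis = -1 selects the last dimension, e.g. for NHWC layouts.
  BatchNormParam param;
  param.Init(std::vector<std::pair<std::string, std::string>>{
      {"eps", "1e-3"}, {"momentum", "0.9"}, {"channel_axis", "-1"}});
}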
@@ -207,21 +212,26 @@ class BatchNormProp : public OperatorProperty {
CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
const TShape &dshape = in_shape->at(0);

+CHECK_GE(param_.channel_axis, -1) << "Invalid channel axis: " << param_.channel_axis;

Contributor: use proper parsing: axis = axis + ndim()
Author (cjolivier01): I don't understand this comment
Contributor: axis = -1 means the first dim from the right, axis = -2 means the second dim from the right. You are only special-casing -1 here. We need to support all of them.
Author (cjolivier01, May 24, 2017): Oh, I see. I wasn't aware of this requirement. I was under the impression it was like some other system I was looking at where there's a "first" flag and a "last" flag only (Torch, maybe?): 0 would be first, -1 would be last. Ok, I will change.
Author: done

+const int channelCount = param_.channel_axis == -1
+    ? dshape[dshape.ndim() - 1] : dshape[param_.channel_axis];

if (dshape.ndim() == 0) {
return false;
}

-in_shape->at(1) = TShape(Shape1(dshape[1]));
-in_shape->at(2) = TShape(Shape1(dshape[1]));
+in_shape->at(1) = TShape(Shape1(channelCount));
+in_shape->at(2) = TShape(Shape1(channelCount));

out_shape->clear();
out_shape->push_back(dshape); // kOut
-out_shape->push_back(Shape1(dshape[1]));  // kMean
-out_shape->push_back(Shape1(dshape[1]));  // kVar
+out_shape->push_back(Shape1(channelCount));  // kMean
+out_shape->push_back(Shape1(channelCount));  // kVar

aux_shape->clear();
-aux_shape->push_back(Shape1(dshape[1]));  // kMovingMean
-aux_shape->push_back(Shape1(dshape[1]));  // kMovingVar
+aux_shape->push_back(Shape1(channelCount));  // kMovingMean
+aux_shape->push_back(Shape1(channelCount));  // kMovingVar
return true;
}
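To make the inference rule concrete: every per-channel tensor (gamma, beta, mean, var, and the moving statistics) gets shape (C,), where C is the extent of the data's channel axis. A small self-contained illustration (assumed example shapes, not this PR's code):

#include <cstdio>
#include <vector>

int main() {
  const std::vector<int> nchw = {32, 3, 224, 224};  // channel_axis = 1  -> C = 3
  const std::vector<int> nhwc = {32, 224, 224, 3};  // channel_axis = -1 -> C = 3
  std::printf("NCHW, axis 1:  gamma/beta/mean/var shape (%d,)\n", nchw[1]);
  std::printf("NHWC, axis -1: gamma/beta/mean/var shape (%d,)\n", nhwc[3]);
  return 0;
}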

@@ -329,6 +339,128 @@
BatchNormParam param_;
}; // class BatchNormProp

namespace batchnorm {

#if defined(__CUDACC__)
#define __bn_hostonly__ __host__
#define __bn_hostdevinl__ __host__ __device__ __forceinline__
#define __bn_localinline__ __forceinline__
#else
#define __bn_hostonly__
#define __bn_hostdevinl__ inline
#define __bn_localinline__ inline
#endif

Contributor: These macros already exist in mshadow; for example, this is MSHADOW_XINLINE.
Author (cjolivier01): ok
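A sketch of the substitution the reviewer suggests (assumes mshadow/base.h is on the include path; MSHADOW_XINLINE is mshadow's host/device force-inline macro, and the replacement is not applied at this point in the PR):

#include <mshadow/base.h>

// MSHADOW_XINLINE already expands to a force-inlined __device__ __host__
// qualifier under nvcc, and to a plain force-inline in CPU-only builds.
#define __bn_hostdevinl__ MSHADOW_XINLINE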

template<typename DType>
class BNTensor3 {
Contributor: Why define this instead of using mshadow::Tensor<xpu, 3>?
Author (cjolivier01): You told me not to use mshadow::Tensor, so I've never inspected it.
Contributor: Did I? What was the reason? I don't remember.
Author (cjolivier01, May 23, 2017): I also don't see how mshadow::Tensor<xpu, 3> knows about the channel axis, and along with my member functions, mshadow::Tensor doesn't offer any added value that I can see.
Contributor: I think I was referring to not using mshadow's template evaluation features, like Tensor x; x = 1;. Only using the data structure is fine. But you can leave it like this for now.
Author: ok
enum { OUTER, CHANNEL, INNER, COUNT };

public:
__bn_hostonly__ inline BNTensor3(const TBlob& blob, const int indexOfChannel)
: dptr_(blob.dptr<DType>())
, indexOfChannel_(indexOfChannel == -1 ? (blob.shape_.ndim() - 1) : indexOfChannel) {
shape_[OUTER] = 1;
for (size_t i = 0; i < indexOfChannel_; ++i) {
shape_[OUTER] *= blob.shape_[i];
}
shape_[CHANNEL] = blob.shape_[indexOfChannel_];
shape_[INNER] = 1;
for (size_t i = indexOfChannel_ + 1, n = blob.shape_.ndim(); i < n; ++i) {
shape_[INNER] *= blob.shape_[i];
}
}

__bn_hostonly__ inline BNTensor3(DType *p, const TShape& shape, const int indexOfChannel)
: dptr_(p)
, indexOfChannel_(indexOfChannel == -1 ? (shape.ndim() - 1) : indexOfChannel) {
shape_[OUTER] = 1;
for (size_t i = 0; i < indexOfChannel_; ++i) {
shape_[OUTER] *= shape[i];
}
shape_[CHANNEL] = shape[indexOfChannel_];
shape_[INNER] = 1;
for (size_t i = indexOfChannel_ + 1, n = shape.ndim(); i < n; ++i) {
shape_[INNER] *= shape[i];
}
}

__bn_localinline__ bool IsEmpty() const {
return dptr_ == nullptr;
}

__bn_hostdevinl__ size_t Size() const {
size_t n = 1;
for (int i = 0; i < COUNT; ++i) {
n *= shape_[i];
}
return n;
}

__bn_hostdevinl__ size_t ChannelCount() const {
return shape_[CHANNEL];
}

__bn_hostdevinl__ size_t OuterSize() const {
return shape_[OUTER];
}

__bn_hostdevinl__ size_t InnerSize() const {
return shape_[INNER];
}

/*! \brief start of a given channel's spatial data */
__bn_hostdevinl__ size_t StartOffset(const size_t channel) const {
return channel * InnerSize();
}

/*! \brief The amount to skip to reach the next same-channel data.
* This is the number of elements (not bytes) to skip from one past the end
* of the current spatial data to the start of the next block of the same
* channel's "spatial data".
* It is assumed that the pointer being calculated points just beyond the
* end of the last block of spatial data,
* i.e. RGBRGB <-- 2
*      RRGGBB <-- 4
**/
__bn_hostdevinl__ size_t SkipLengthToNextSameChannelData() const {
return (ChannelCount() - 1) * InnerSize();
}

__bn_hostdevinl__ size_t offset(const size_t outer,
const size_t channel,
const size_t i) const {
const size_t spatial_size = InnerSize();
const size_t skip_length = SkipLengthToNextSameChannelData();
size_t off = StartOffset(channel);
off += outer * shape_[CHANNEL] * shape_[INNER];
const size_t skips = i / spatial_size;
// Each completed spatial block advances by the block itself plus the skip
// to the next same-channel block, i.e. a stride of spatial_size + skip_length.
off += (spatial_size + skip_length) * skips;
off += i % spatial_size;
return off;
}

__bn_hostdevinl__ DType& get_ref(const size_t batch,
const size_t channel,
const size_t i) {
const size_t off = offset(batch, channel, i);
return dptr_[off];
}

__bn_hostdevinl__ const DType& get_ref(const size_t batch,
const size_t channel,
const size_t i) const {
const size_t off = offset(batch, channel, i);
return dptr_[off];
}

DType *dptr_;
size_t indexOfChannel_;
size_t shape_[COUNT];
};
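To make the layout arithmetic concrete, here is a standalone sketch that mirrors BNTensor3's outer/channel/inner factorization (plain C++ with assumed example shapes; not this PR's code):

#include <cstdio>

// View a tensor of shape {2, 3, 4} with channel axis 1 as
// OUTER=2, CHANNEL=3, INNER=4 and walk every element of channel 1.
int main() {
  const size_t outer = 2, channels = 3, inner = 4;
  const size_t start = 1 * inner;              // StartOffset(channel = 1)
  const size_t skip = (channels - 1) * inner;  // SkipLengthToNextSameChannelData()
  for (size_t o = 0; o < outer; ++o) {
    for (size_t i = 0; i < inner; ++i) {
      // offset(o, 1, i): channel start + outer stride + position in block
      const size_t off = start + o * channels * inner + i;
      std::printf("outer=%zu i=%zu -> flat offset %zu (skip between blocks: %zu)\n",
                  o, i, off, skip);
    }
  }
  return 0;
}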

extern volatile bool disable_mkl;

} // namespace batchnorm

#endif // DMLC_USE_CXX11
} // namespace op
} // namespace mxnet