Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add inlines
Browse files Browse the repository at this point in the history
  • Loading branch information
azai91 committed Oct 20, 2018
1 parent 3c556ca commit 4a12d6a
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions tests/cpp/include/test_mkldnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

using namespace mxnet;

static mkldnn::memory::primitive_desc GetMemPD(const TShape s, int dtype,
inline static mkldnn::memory::primitive_desc GetMemPD(const TShape s, int dtype,
mkldnn::memory::format format) {
mkldnn::memory::dims dims(s.ndim());
for (size_t i = 0; i < dims.size(); i++)
Expand All @@ -46,7 +46,7 @@ static mkldnn::memory::primitive_desc GetMemPD(const TShape s, int dtype,
return mkldnn::memory::primitive_desc(desc, CpuEngine::Get()->get_engine());
}

static mkldnn::memory::primitive_desc GetExpandedMemPD(
inline static mkldnn::memory::primitive_desc GetExpandedMemPD(
mkldnn::memory::primitive_desc pd, float scale, int dim = 0) {
CHECK(dim < pd.desc().data.ndims) << "dimension cannot be larger than total dimensions of input";
nnvm::TShape s(pd.desc().data.ndims);
Expand All @@ -63,7 +63,7 @@ struct TestArrayShapes {
};

// Init arrays with the default layout.
static void InitDefaultArray(NDArray *arr, bool is_rand = false) {
inline static void InitDefaultArray(NDArray *arr, bool is_rand = false) {
const TBlob &blob = arr->data();
mshadow::default_real_t *data = blob.dptr<mshadow::default_real_t>();
int size = blob.Size();
Expand All @@ -78,14 +78,14 @@ static void InitDefaultArray(NDArray *arr, bool is_rand = false) {


// Init arrays with the specified layout.
static void InitMKLDNNArray(NDArray *arr, const mkldnn::memory::primitive_desc &pd,
inline static void InitMKLDNNArray(NDArray *arr, const mkldnn::memory::primitive_desc &pd,
bool is_rand = false) {
InitDefaultArray(arr, is_rand);
arr->MKLDNNDataReorderAsync(pd);
arr->WaitToRead();
}

static bool IsSameShape(mkldnn::memory::primitive_desc pd, TShape shape) {
inline static bool IsSameShape(mkldnn::memory::primitive_desc pd, TShape shape) {
if (pd.desc().data.ndims != shape.ndim()) return false;
for (size_t i = 0; i < shape.ndim(); i++)
if (pd.desc().data.dims[i] != shape[i]) return false;
Expand All @@ -97,7 +97,7 @@ static bool IsSameShape(mkldnn::memory::primitive_desc pd, TShape shape) {
// it's specific for certain array shapes. It covers at least one special format
// for each of the formats: nchw, oihw, goihw.
// To test the logic of the code in NDArray, these formats should be enough.
static std::vector<mkldnn::memory::format> GetMKLDNNFormat(size_t num_dims, int dtype) {
inline static std::vector<mkldnn::memory::format> GetMKLDNNFormat(size_t num_dims, int dtype) {
if (num_dims == 4) {
mkldnn::memory::dims data_dims{1, 3, 224, 224};
mkldnn::memory::desc data_md{data_dims, get_mkldnn_type(dtype),
Expand Down Expand Up @@ -148,7 +148,7 @@ static std::vector<mkldnn::memory::format> GetMKLDNNFormat(size_t num_dims, int
}
}

static TestArrayShapes GetTestArrayShapes() {
inline static TestArrayShapes GetTestArrayShapes() {
int dtype = mshadow::DataType<mshadow::default_real_t>::kFlag;
std::vector<TShape> shapes;
std::vector<mkldnn::memory::primitive_desc> pds;
Expand Down Expand Up @@ -240,7 +240,7 @@ enum ArrayTypes {
All = 8191,
};

std::string CreateShapeString(int value, int dim) {
inline std::string CreateShapeString(int value, int dim) {
std::stringstream ss;
ss << "(";
for (int i = 0; i < dim; i++) {
Expand All @@ -251,7 +251,7 @@ std::string CreateShapeString(int value, int dim) {
return ss.str();
}

void PrintVerifyMsg(const NDArrayAttrs &arr1, const NDArrayAttrs &arr2) {
inline void PrintVerifyMsg(const NDArrayAttrs &arr1, const NDArrayAttrs &arr2) {
TShape t1 = arr1.arr.shape();
TShape t2 = arr2.arr.shape();
std::stringstream ss;
Expand Down Expand Up @@ -281,7 +281,7 @@ void PrintVerifyMsg(const NDArrayAttrs &arr1, const NDArrayAttrs &arr2) {
*
* num_inputs / dim arguments used to scale shape (used for concat backwards to enlarge input shapes)
*/
std::vector<NDArrayAttrs> GetTestInputArrays(
inline std::vector<NDArrayAttrs> GetTestInputArrays(
int types = ArrayTypes::All, bool rand = false,
int num_inputs = 1, int dim = 0) {
TestArrayShapes tas = GetTestArrayShapes();
Expand Down Expand Up @@ -395,7 +395,7 @@ std::vector<NDArrayAttrs> GetTestInputArrays(
*
* Optional num_inputs / dim args can be passed to modify input shape (used for Concat test)
*/
std::vector<NDArrayAttrs> GetTestOutputArrays(
inline std::vector<NDArrayAttrs> GetTestOutputArrays(
const TShape &shp,
const std::vector<mkldnn::memory::primitive_desc> &pds,
std::vector<float>scale = {1}, bool rand = true, int types = ArrayTypes::All) {
Expand Down Expand Up @@ -506,7 +506,7 @@ std::vector<NDArrayAttrs> GetTestOutputArrays(
* Determines axis ndarrays are concatenated by
* Used to verify concat/concat backwards operator
*/
int GetDim(TShape input_shape, TShape output_shape) {
inline int GetDim(TShape input_shape, TShape output_shape) {
CHECK(input_shape.Size() != output_shape.Size());
for (size_t i = 0; i < input_shape.ndim(); i++) {
if (input_shape[i] != output_shape[i])
Expand All @@ -519,21 +519,21 @@ int GetDim(TShape input_shape, TShape output_shape) {
* Calculates the size of continuous block of array inside larger concatenated array
* Used to verify concat/concat backwards operator
*/
int GetBlockSize(TShape shape, int dim) {
inline int GetBlockSize(TShape shape, int dim) {
int block_size = 1;
for (int i = shape.ndim() - 1; i >= dim; i--)
block_size *= shape[i];
return block_size;
}

int CalculateWidthPoolOutput(int width, int kernel, int padding, int stride) {
inline int CalculateWidthPoolOutput(int width, int kernel, int padding, int stride) {
return (width - kernel + 2 * padding) / stride + 1;
}

using VerifyFunc = std::function<void (const std::vector<NDArray *> &in_arrs,
const std::vector<NDArray *> &out_arrs)>;

void VerifyAddRequest(const std::vector<NDArray*> &in_arrs,
inline void VerifyAddRequest(const std::vector<NDArray*> &in_arrs,
const std::vector<NDArray*> &original_outputs,
const std::vector<NDArray*> &new_outputs,
VerifyFunc verify_fn) {
Expand All @@ -548,7 +548,7 @@ void VerifyAddRequest(const std::vector<NDArray*> &in_arrs,
verify_fn(in_arrs, tmp_outputs);
}

void VerifyCopyResult(const std::vector<NDArray *> &in_arrs,
inline void VerifyCopyResult(const std::vector<NDArray *> &in_arrs,
const std::vector<NDArray *> &out_arrs) {
NDArray tmp1 = in_arrs[0]->Reorder2Default();
NDArray tmp2 = out_arrs[0]->Reorder2Default();
Expand All @@ -559,7 +559,7 @@ void VerifyCopyResult(const std::vector<NDArray *> &in_arrs,
tmp1.shape().Size() * sizeof(mshadow::default_real_t)), 0);
}

void VerifySumResult(const std::vector<NDArray *> &in_arrs,
inline void VerifySumResult(const std::vector<NDArray *> &in_arrs,
const std::vector<NDArray *> &out_arrs) {
NDArray in1 = in_arrs[0]->Reorder2Default();
NDArray in2 = in_arrs[1]->Reorder2Default();
Expand All @@ -574,5 +574,5 @@ void VerifySumResult(const std::vector<NDArray *> &in_arrs,
ASSERT_EQ(d1[i] + d2[i], o[i]);
}

#endif
#endif // MXNET_USE_MKLDNN
#endif // TEST_MKLDNN_H_

0 comments on commit 4a12d6a

Please sign in to comment.