Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[WIP] Fix compilation for large tensor with BLAS=MKL
Browse files Browse the repository at this point in the history
POC for compilation with:
-DUSE_INT64_TENSOR_SIZE=ON -DUSE_BLAS=mkl
It internally sets MKL_USE_ILP64.
  • Loading branch information
anko-intel committed Sep 26, 2020
1 parent e2aacce commit b4d19fc
Show file tree
Hide file tree
Showing 13 changed files with 241 additions and 187 deletions.
19 changes: 19 additions & 0 deletions 3rdparty/mshadow/mshadow/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,12 @@ extern "C" {
#include <cblas.h>
}
#elif MSHADOW_USE_MKL
#if MSHADOW_INT64_TENSOR_SIZE == 1
// Define MKL_INT here so that MKL uses exactly the same 64-bit integer type
// definitions as mshadow. If MKL_INT is not defined here, the MKL headers
// define it as long long int instead.
#define MKL_INT int64_t
#define MKL_UINT uint64_t
#endif
#include <mkl_blas.h>
#include <mkl_cblas.h>
#include <mkl_vsl.h>
Expand Down Expand Up @@ -320,6 +326,13 @@ const float kPi = 3.1415926f;
typedef index_t openmp_index_t;
#endif

// Use `!` rather than the alternative token `not`: MSVC does not accept the
// alternative operator spellings in preprocessor conditions by default, and
// the rest of this file uses `&&`/`!` style.
#if MSHADOW_USE_MKL && !MSHADOW_USE_CUDA
// lapack_index_t could be replaced by index_t (and removed) once every
// supported BLAS/LAPACK library handles large-tensor (64-bit) indices; the
// non-MKL and CUDA paths below still use 32-bit `int` indices.
typedef index_t lapack_index_t;
#else
typedef int lapack_index_t;
#endif

/*! \brief float point type that will be used in default by mshadow */
typedef float default_real_t;

Expand Down Expand Up @@ -447,6 +460,12 @@ struct DataType<bool> {
/*! \brief type enum value for default real type */
const int default_type_flag = DataType<default_real_t>::kFlag;

/*! \brief TypeFlag value for type of indexes */
const int index_type_flag = DataType<index_t>::kFlag;

/*! \brief TypeFlag value for the index type used by the BLAS/LAPACK wrappers */
const int blas_index_type_flag = DataType<lapack_index_t>::kFlag;

/*! layout flag */
enum LayoutFlag {
kNCHW = 0,
Expand Down
7 changes: 7 additions & 0 deletions cmake/ChooseBlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ elseif(BLAS STREQUAL "Open" OR BLAS STREQUAL "open")
add_definitions(-DMSHADOW_USE_MKL=0)
add_definitions(-DMXNET_USE_BLAS_OPEN=1)
elseif(BLAS STREQUAL "MKL" OR BLAS STREQUAL "mkl")
# Large tensor support (USE_INT64_TENSOR_SIZE) implies 64-bit BLAS integers,
# so force MKL's ILP64 interface on in that case.
if (USE_INT64_TENSOR_SIZE)
set(MKL_USE_ILP64 ON CACHE BOOL "enable using ILP64 in MKL" FORCE)
else()
# MKL_USE_ILP64 is derived from USE_INT64_TENSOR_SIZE and must not be set
# on its own; fail early instead of building a mismatched configuration.
if(MKL_USE_ILP64)
message(FATAL_ERROR "MKL_USE_ILP64 cannot be set without USE_INT64_TENSOR_SIZE; Please set USE_INT64_TENSOR_SIZE instead.")
endif()
endif()
find_package(MKL REQUIRED)
include_directories(SYSTEM ${MKL_INCLUDE_DIR})
list(APPEND mshadow_LINKER_LIBS ${MKL_LIBRARIES})
Expand Down
2 changes: 2 additions & 0 deletions include/mxnet/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ typedef mshadow::cpu cpu;
typedef mshadow::gpu gpu;
/*! \brief index type, usually an unsigned integer */
typedef mshadow::index_t index_t;
/*! \brief index type used by the BLAS/LAPACK wrappers */
typedef mshadow::lapack_index_t lapack_index_t;
/*! \brief data type that will be used to store ndarray */
typedef mshadow::default_real_t real_t;
/*! \brief operator structure from NNVM */
Expand Down
69 changes: 37 additions & 32 deletions src/operator/c_lapack_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,9 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
// The following functions differ in signature from the
// MXNET_LAPACK-signature and have to be wrapped.
#define MXNET_LAPACK_CWRAP_GELQF(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##gelqf(int matrix_layout, int m, int n, \
dtype *a, int lda, dtype *tau, \
dtype *work, int lwork) { \
inline int MXNET_LAPACK_##prefix##gelqf(int matrix_layout, lapack_index_t m, lapack_index_t n, \
dtype *a, lapack_index_t lda, dtype *tau, \
dtype *work, lapack_index_t lwork) { \
if (lwork != -1) { \
return LAPACKE_##prefix##gelqf(matrix_layout, m, n, a, lda, tau); \
} \
Expand All @@ -278,9 +278,9 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
MXNET_LAPACK_CWRAP_GELQF(d, double)

#define MXNET_LAPACK_CWRAP_ORGLQ(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##orglq(int matrix_layout, int m, int n, \
dtype *a, int lda, dtype *tau, \
dtype *work, int lwork) { \
inline int MXNET_LAPACK_##prefix##orglq(int matrix_layout, lapack_index_t m, lapack_index_t n, \
dtype *a, lapack_index_t lda, dtype *tau, \
dtype *work, lapack_index_t lwork) { \
if (lwork != -1) { \
return LAPACKE_##prefix##orglq(matrix_layout, m, n, m, a, lda, tau); \
} \
Expand All @@ -291,9 +291,9 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
MXNET_LAPACK_CWRAP_ORGLQ(d, double)

#define MXNET_LAPACK_CWRAP_GEQRF(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##geqrf(int matrix_layout, int m, int n, \
dtype *a, int lda, dtype *tau, \
dtype *work, int lwork) { \
inline int MXNET_LAPACK_##prefix##geqrf(int matrix_layout, lapack_index_t m, lapack_index_t n, \
dtype *a, lapack_index_t lda, dtype *tau, \
dtype *work, lapack_index_t lwork) { \
if (lwork != -1) { \
return LAPACKE_##prefix##geqrf(matrix_layout, m, n, a, lda, tau); \
} \
Expand All @@ -304,9 +304,9 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
MXNET_LAPACK_CWRAP_GEQRF(d, double)

#define MXNET_LAPACK_CWRAP_ORGQR(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##orgqr(int matrix_layout, int m, int n, int k, \
dtype *a, int lda, dtype *tau, \
dtype *work, int lwork) { \
inline int MXNET_LAPACK_##prefix##orgqr(int matrix_layout, lapack_index_t m, lapack_index_t n, \
lapack_index_t k, dtype *a, lapack_index_t lda, \
dtype *tau, dtype *work, lapack_index_t lwork) { \
if (lwork != -1) { \
return LAPACKE_##prefix##orgqr(matrix_layout, m, n, k, a, lda, tau); \
} \
Expand All @@ -322,9 +322,10 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
// We also have to allocate at least one DType element as workspace as the
// calling code assumes that the workspace has at least that size.
#define MXNET_LAPACK_CWRAP_SYEVD(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##syevd(int matrix_layout, char uplo, int n, dtype *a, \
int lda, dtype *w, dtype *work, int lwork, \
int *iwork, int liwork) { \
inline int MXNET_LAPACK_##prefix##syevd(int matrix_layout, char uplo, lapack_index_t n, \
dtype *a, lapack_index_t lda, dtype *w, \
dtype *work, lapack_index_t lwork, \
lapack_index_t *iwork, lapack_index_t liwork) { \
if (lwork != -1) { \
char o(loup(uplo, (matrix_layout == MXNET_LAPACK_ROW_MAJOR))); \
return LAPACKE_##prefix##syevd(LAPACK_COL_MAJOR, 'V', o, n, a, lda, w); \
Expand All @@ -344,9 +345,9 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
// We also have to allocate at least m - 1 DType elements as workspace as the internal
// LAPACKE function needs it to store `superb`. (see MKL documentation)
#define MXNET_LAPACK_CWRAP_GESVD(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##gesvd(int matrix_layout, int m, int n, dtype* ut, \
int ldut, dtype* s, dtype* v, int ldv, \
dtype* work, int lwork) { \
inline int MXNET_LAPACK_##prefix##gesvd(int matrix_layout, lapack_index_t m, lapack_index_t n, \
dtype* ut, lapack_index_t ldut, dtype* s, dtype* v, \
lapack_index_t ldv, dtype* work, lapack_index_t lwork) { \
if (lwork != -1) { \
return LAPACKE_##prefix##gesvd(matrix_layout, 'S', 'O', m, n, v, ldv, s, ut, ldut, \
v, ldv, work); \
Expand All @@ -360,11 +361,12 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
// Computes the singular value decomposition of a general rectangular matrix
// using a divide and conquer method.
#define MXNET_LAPACK_CWRAP_GESDD(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##gesdd(int matrix_layout, int m, int n, \
dtype *a, int lda, dtype *s, \
dtype *u, int ldu, \
dtype *vt, int ldvt, \
dtype *work, int lwork, int *iwork) { \
inline int MXNET_LAPACK_##prefix##gesdd(int matrix_layout, lapack_index_t m, lapack_index_t n, \
dtype *a, lapack_index_t lda, dtype *s, \
dtype *u, lapack_index_t ldu, \
dtype *vt, lapack_index_t ldvt, \
dtype *work, lapack_index_t lwork, \
lapack_index_t *iwork) { \
if (lwork != -1) { \
return LAPACKE_##prefix##gesdd(matrix_layout, 'O', m, n, a, lda, \
s, u, ldu, vt, ldvt); \
Expand All @@ -376,8 +378,9 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
MXNET_LAPACK_CWRAP_GESDD(d, double)

#define MXNET_LAPACK_CWRAP_GETRI(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##getri(int matrix_layout, int n, dtype *a, int lda, \
int *ipiv, dtype *work, int lwork) { \
inline int MXNET_LAPACK_##prefix##getri(int matrix_layout, lapack_index_t n, dtype *a, \
lapack_index_t lda, lapack_index_t *ipiv, \
dtype *work, lapack_index_t lwork) { \
if (lwork != -1) { \
return LAPACKE_##prefix##getri(matrix_layout, n, a, lda, ipiv); \
} \
Expand All @@ -389,10 +392,11 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {

#define MXNET_LAPACK_CWRAP_GEEV(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##geev(int matrix_layout, char jobvl, char jobvr, \
int n, dtype *a, int lda, \
lapack_index_t n, dtype *a, lapack_index_t lda, \
dtype *wr, dtype *wi, \
dtype *vl, int ldvl, dtype *vr, int ldvr, \
dtype *work, int lwork) { \
dtype *vl, lapack_index_t ldvl, dtype *vr, \
lapack_index_t ldvr, \
dtype *work, lapack_index_t lwork) { \
if (lwork != -1) { \
return LAPACKE_##prefix##geev(matrix_layout, jobvl, jobvr, \
n, a, lda, wr, wi, vl, ldvl, vr, ldvr); \
Expand All @@ -404,10 +408,11 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
MXNET_LAPACK_CWRAP_GEEV(d, double)

#define MXNET_LAPACK_CWRAP_GELSD(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##gelsd(int matrix_layout, int m, int n, int nrhs, \
dtype *a, int lda, dtype *b, int ldb, \
dtype *s, dtype rcond, int *rank, \
dtype *work, int lwork, int *iwork) { \
inline int MXNET_LAPACK_##prefix##gelsd(int matrix_layout, lapack_index_t m, lapack_index_t n, \
lapack_index_t nrhs, dtype *a, lapack_index_t lda, \
dtype *b, lapack_index_t ldb, dtype *s, dtype rcond, \
lapack_index_t *rank, dtype *work, lapack_index_t lwork, \
lapack_index_t *iwork) { \
if (lwork != -1) { \
return LAPACKE_##prefix##gelsd(matrix_layout, m, n, nrhs, a, lda, b, ldb, \
s, rcond, rank); \
Expand Down
10 changes: 5 additions & 5 deletions src/operator/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ void linalg_syevd(const Tensor<xpu, 2, DType>& A,
// This function determines the amount of workspace needed for linalg_syevd
// which is returned as number of elements of type DType.
template<typename xpu, typename DType>
int linalg_syevd_workspace_query(const Tensor<xpu, 2, DType>& A,
lapack_index_t linalg_syevd_workspace_query(const Tensor<xpu, 2, DType>& A,
const Tensor<xpu, 1, DType>& L,
Stream<xpu> *s = 0);

Expand Down Expand Up @@ -224,13 +224,13 @@ int linalg_gesvd_workspace_query(const Tensor<xpu, 2, DType>& UT,
// don't throw error when A is non-invertible matrix.
template<typename xpu, typename DType>
void linalg_getrf(const Tensor<xpu, 2, DType>& A,
const Tensor<xpu, 1, int>& pivot,
const Tensor<xpu, 1, lapack_index_t>& pivot,
bool check_singular,
Stream<xpu> *s = 0);

template<typename xpu, typename DType>
void linalg_batch_getrf(const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 2, int>& pivot,
const Tensor<xpu, 2, lapack_index_t>& pivot,
bool check_singular,
Stream<xpu> *s = 0);

Expand All @@ -244,7 +244,7 @@ void linalg_batch_getrf(const Tensor<xpu, 3, DType>& A,
// - LU is also the output parameter (overwritten by inverse(A))
template<typename xpu, typename DType>
void linalg_getri(const Tensor<xpu, 2, DType>& LU,
const Tensor<xpu, 1, int>& pivot, \
const Tensor<xpu, 1, lapack_index_t>& pivot, \
const Tensor<xpu, 1, DType>& work,
Stream<xpu> *s = 0);

Expand Down Expand Up @@ -274,7 +274,7 @@ void linalg_batch_inverse(const Tensor<xpu, 3, DType>& A,
// from LU and pivot using temp workspace, the result is stored back to LU
template<typename xpu, typename DType>
void linalg_batch_det_backward_helper(const Tensor<xpu, 3, DType>& LU,
const Tensor<xpu, 2, int>& pivot,
const Tensor<xpu, 2, lapack_index_t>& pivot,
const Tensor<xpu, 1, DType>& det,
const Tensor<xpu, 3, DType>& temp,
const DType zero_det,
Expand Down
Loading

0 comments on commit b4d19fc

Please sign in to comment.