Skip to content

Commit

Permalink
Add bfloat16 floating-point format support based on AMP (apache#17265)
Browse files Browse the repository at this point in the history
* Add Bfloat16

* mshadow support bf16

* rebase bf16 mkldnn1.0

* support bf16 gemm

* resolve fp32 ip bwd bug

* add other bf16 ops

* change func name from fp16 to lp16 (low precision 16), to include bf16

* add amp_cast bf16 support for ndarray

* fix executor copy_params

* add test case for bf16

* remove numpy dtype hook for bf16

* add bf16 type support

* rebase to mxnet master

* add single conv test

* fix symbolic inference

* add dtype check when copy

* add single conv and bn test

* skip fp16 amp_cast test in cpu

* Fix resnet50 first convolution

* Skip first convolution for bfloat16

* support bf16 fallback compute

* recover origin test

* add some bf16 unittests

* fix bf16 bn test, enhance assert_almost_equal_with_err

* using assert_almost_equal_with_err for fallback bn test

* add relu6 bf16 support

* fix lint

* fix subgraph conv with data=0

* mkldnn doesn't support 0 dim tensor

* rm dtype check when copy

* using bf16 tvm

* rm bf16 mnist demo

* use official tvm

* change function name; fix lint error

* fix clang check error:conditional expression is ambiguous; 'float' can be converted to 'mshadow::bfloat::bf16_t' and vice versa

* nvcc compiler build pass

* fix gpu amp cast symbol error

* fix mnist training error

* fix cpp test: Engine.VarVersion error

* workaround cpp failed test mkldnn fc bwd

* to fix mkldnn test_mkldnn_ndarray_slice error

* 1. move some code from np_broadcast_reduce_op_value.cc to np_broadcast_reduce_op_value_part2.cc to pass Win CPU/GPU build (fatal error C1002: compiler is out of heap space in pass 2)
2. rm debug code

* use official dlpack

* rename np_broadcast_reduce_op_value_part2.cc and add some description

* 1. update dlpack url in .gitmodules
2. disable mkldnn fc bwd

* fix remaining NodePtr due to tvm update

* mv some code from mxnet_op.h to mxnet_op_kernel_assign.h to avoid WIN compiler error 'fatal error C1002: compiler is out of heap space in pass 2'

* fix WIN CPU build fail:compiler is out of heap space in pass 2

* fix WIN build fail

* fix lint

* add print for test bf16_concat

* fix bf16 test fail

* disable bf16 concat test

* tmp skip to root cause edge test halt

* fix bf16_bn test error

* enable test_bulk

* tmp rm bf16 to locate edge error

* Revert "tmp rm bf16 to locate edge error"

This reverts commit 7360246.

* add Apache license header

* trigger CI

* make the bf16 bn test more robust

Co-authored-by: Zhennan Qin <zhennan.qin@intel.com>
Co-authored-by: YixinBao <yixin.bao@intel.com>
Co-authored-by: Xinyu Chen <xinyu1.chen@intel.com>
Co-authored-by: Wuxun Zhang <wuxun.zhang@intel.com>
  • Loading branch information
5 people authored and Ubuntu committed Feb 19, 2020
1 parent c00bf36 commit ec3cb6a
Show file tree
Hide file tree
Showing 62 changed files with 2,912 additions and 498 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/dlpack
161 changes: 160 additions & 1 deletion 3rdparty/mshadow/mshadow/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,32 @@ extern "C" {

#include "./half.h"
#include "./half2.h"
#include "./bfloat.h"
/*!
 * \brief Generates the two mixed-type overloads of a binary operator for
 *        half_t and bf16_t, one for each operand order.
 *
 * In both overloads the operands are converted to float, the operator is
 * applied to the two floats, and the result is returned as RTYPE.
 *
 * \param RTYPE return type of the generated overloads
 * \param OP    operator token to overload, e.g. + or <
 */
#define MSHADOW_HALF_BF_OPERATOR(RTYPE, OP)                                    \
  MSHADOW_XINLINE RTYPE operator OP(mshadow::half::half_t lhs,                 \
                                    mshadow::bfloat::bf16_t rhs) {             \
    return static_cast<float>(lhs) OP static_cast<float>(rhs);                 \
  }                                                                            \
  MSHADOW_XINLINE RTYPE operator OP(mshadow::bfloat::bf16_t lhs,               \
                                    mshadow::half::half_t rhs) {               \
    return static_cast<float>(lhs) OP static_cast<float>(rhs);                 \
  }

/*! \brief half_t + bf16_t (either operand order): computed in float, returns float */
MSHADOW_HALF_BF_OPERATOR(float, +)
/*! \brief half_t - bf16_t (either operand order): computed in float, returns float */
MSHADOW_HALF_BF_OPERATOR(float, -)
/*! \brief half_t * bf16_t (either operand order): computed in float, returns float */
MSHADOW_HALF_BF_OPERATOR(float, *)
/*! \brief half_t / bf16_t (either operand order): computed in float, returns float */
MSHADOW_HALF_BF_OPERATOR(float, /)
/*! \brief half_t > bf16_t (either operand order): compared as float, returns bool */
MSHADOW_HALF_BF_OPERATOR(bool, >)
/*! \brief half_t < bf16_t (either operand order): compared as float, returns bool */
MSHADOW_HALF_BF_OPERATOR(bool, <)
/*! \brief half_t >= bf16_t (either operand order): compared as float, returns bool */
MSHADOW_HALF_BF_OPERATOR(bool, >=)
/*! \brief half_t <= bf16_t (either operand order): compared as float, returns bool */
MSHADOW_HALF_BF_OPERATOR(bool, <=)

#include "./logging.h"
/*! \brief namespace for mshadow */
namespace mshadow {
Expand Down Expand Up @@ -312,6 +338,11 @@ enum TypeFlag {
kInt8 = 5,
kInt64 = 6,
kBool = 7,
kInt16 = 8,
kUint16 = 9,
kUint32 = 10,
kUint64 = 11,
kBfloat16 = 12
};

template<typename DType>
Expand Down Expand Up @@ -365,6 +396,11 @@ struct DataType<half::half2_t> {
static const int kLanes = 2;
};
/*! \brief DataType traits for bfloat::bf16_t: type flag kBfloat16, kLanes of 1 */
template<>
struct DataType<bfloat::bf16_t> {
static const int kFlag = kBfloat16;
static const int kLanes = 1;
};
template<>
struct DataType<uint8_t> {
static const int kFlag = kUint8;
static const int kLanes = 1;
Expand Down Expand Up @@ -688,6 +724,11 @@ template<>
MSHADOW_XINLINE half::half_t MinValue<half::half_t>(void) {
return MSHADOW_HALF_MIN;
}
/*!
 * \brief minimum value of bf16
 * \return MSHADOW_BF16_MIN, a constant defined outside this section
 *         (NOTE(review): presumably in bfloat.h - confirm)
 */
template<>
MSHADOW_XINLINE bfloat::bf16_t MinValue<bfloat::bf16_t>(void) {
return MSHADOW_BF16_MIN;
}
/*! \brief minimum value of uint8_t */
template<>
MSHADOW_XINLINE uint8_t MinValue<uint8_t>(void) {
Expand Down Expand Up @@ -765,6 +806,11 @@ template<>
MSHADOW_XINLINE half::half_t MaxValue<half::half_t>(void) {
return MSHADOW_HALF_MAX;
}
/*!
 * \brief maximum value of bf16
 * \return MSHADOW_BF16_MAX, a constant defined outside this section
 *         (NOTE(review): presumably in bfloat.h - confirm)
 */
template<>
MSHADOW_XINLINE bfloat::bf16_t MaxValue<bfloat::bf16_t>(void) {
return MSHADOW_BF16_MAX;
}
/*! \brief maximum value of uint8_t */
template<>
MSHADOW_XINLINE uint8_t MaxValue<uint8_t>(void) {
Expand Down Expand Up @@ -998,6 +1044,7 @@ struct minimum {
};
} // namespace red

#ifndef __NVCC__
#define MSHADOW_TYPE_SWITCH(type, DType, ...) \
switch (type) { \
case mshadow::kFloat32: \
Expand All @@ -1018,6 +1065,12 @@ struct minimum {
{__VA_ARGS__} \
} \
break; \
case mshadow::kBfloat16: \
{ \
typedef mshadow::bfloat::bf16_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kUint8: \
{ \
typedef uint8_t DType; \
Expand Down Expand Up @@ -1045,6 +1098,55 @@ struct minimum {
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}
#else
/*!
 * \brief MSHADOW_TYPE_SWITCH variant used when __NVCC__ is defined.
 *
 * Switches on the runtime type flag `type`; in the matching case it typedefs
 * `DType` to the corresponding C++ type (float, double, half_t, uint8_t,
 * int8_t, int32_t or int64_t) and runs the statements passed as __VA_ARGS__
 * inside a nested scope.
 *
 * This variant has no kBfloat16 case, so a bf16 flag (like any other
 * unlisted flag) reaches the default branch and is reported through
 * LOG(FATAL) as an unknown type enum.
 */
#define MSHADOW_TYPE_SWITCH(type, DType, ...) \
switch (type) { \
case mshadow::kFloat32: \
{ \
typedef float DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat64: \
{ \
typedef double DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat16: \
{ \
typedef mshadow::half::half_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kUint8: \
{ \
typedef uint8_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kInt8: \
{ \
typedef int8_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kInt32: \
{ \
typedef int32_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kInt64: \
{ \
typedef int64_t DType; \
{__VA_ARGS__} \
} \
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}
#endif

#define MSHADOW_TYPE_SWITCH_WITH_HALF2(type, DType, ...) \
switch (type) { \
Expand Down Expand Up @@ -1147,6 +1249,7 @@ struct minimum {
LOG(FATAL) << "Unknown type enum " << type; \
}

#ifndef __NVCC__
#define MSHADOW_REAL_TYPE_SWITCH_EX(type$, DType$, DLargeType$, ...) \
switch (type$) { \
case mshadow::kFloat32: \
Expand All @@ -1170,6 +1273,13 @@ struct minimum {
{__VA_ARGS__} \
} \
break; \
case mshadow::kBfloat16: \
{ \
typedef mshadow::bfloat::bf16_t DType$; \
typedef float DLargeType$; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kUint8: \
LOG(FATAL) << "This operation only support " \
"floating point types not uint8"; \
Expand All @@ -1189,7 +1299,50 @@ struct minimum {
default: \
LOG(FATAL) << "Unknown type enum " << type$; \
}

#else
/*!
 * \brief MSHADOW_REAL_TYPE_SWITCH_EX variant used when __NVCC__ is defined.
 *
 * Switches on the runtime type flag `type$`. For the floating-point flags it
 * typedefs both `DType$` and `DLargeType$` and then runs __VA_ARGS__ in a
 * nested scope:
 *   - kFloat32: DType$ = float,  DLargeType$ = float
 *   - kFloat64: DType$ = double, DLargeType$ = double
 *   - kFloat16: DType$ = half_t, DLargeType$ = float
 *
 * The integer flags kUint8, kInt8, kInt32 and kInt64 do not run __VA_ARGS__;
 * each reports through LOG(FATAL) that only floating point types are
 * supported. Any other flag, including kBfloat16 in this variant, reaches
 * the default branch and is reported as an unknown type enum.
 */
#define MSHADOW_REAL_TYPE_SWITCH_EX(type$, DType$, DLargeType$, ...) \
switch (type$) { \
case mshadow::kFloat32: \
{ \
typedef float DType$; \
typedef float DLargeType$; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat64: \
{ \
typedef double DType$; \
typedef double DLargeType$; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat16: \
{ \
typedef mshadow::half::half_t DType$; \
typedef float DLargeType$; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kUint8: \
LOG(FATAL) << "This operation only support " \
"floating point types not uint8"; \
break; \
case mshadow::kInt8: \
LOG(FATAL) << "This operation only support " \
"floating point types not int8"; \
break; \
case mshadow::kInt32: \
LOG(FATAL) << "This operation only support " \
"floating point types, not int32";\
break; \
case mshadow::kInt64: \
LOG(FATAL) << "This operation only support " \
"floating point types, not int64";\
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type$; \
}
#endif
#define MSHADOW_LAYOUT_SWITCH(layout, Layout, ...) \
switch (layout) { \
case mshadow::kNCHW: \
Expand Down Expand Up @@ -1256,6 +1409,12 @@ struct minimum {
{__VA_ARGS__} \
} \
break; \
case mshadow::kBfloat16: \
{ \
typedef mshadow::bfloat::bf16_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kUint8: \
{ \
typedef uint8_t DType; \
Expand Down
Loading

0 comments on commit ec3cb6a

Please sign in to comment.