-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
conv2d support bfloat16 #32221
conv2d support bfloat16 #32221
Changes from 6 commits
687e28b
252dbcf
747d096
8bab3d7
74bb02d
cd612c5
4069f78
7d3a4d5
c41fe74
f3ca4b8
5a3e730
f23f1d2
15f3315
394d8d4
12cc70b
62fcd51
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -131,6 +131,27 @@ inline ActivationMode StringToActivationMode(const std::string& str) { | |||||||||||
template <typename T> | ||||||||||||
class CudnnDataType; | ||||||||||||
|
||||||||||||
template <> | ||||||||||||
class CudnnDataType<bfloat16> { | ||||||||||||
public: | ||||||||||||
// CUDNN_DATA_BFLOAT16 is not valid before cudnn8.1 | ||||||||||||
#if CUDNN_VERSION_MIN(8, 1, 0) | ||||||||||||
static const cudnnDataType_t type = CUDNN_DATA_BFLOAT16; | ||||||||||||
#else | ||||||||||||
static const cudnnDataType_t type = CUDNN_DATA_HALF; | ||||||||||||
#endif | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. #else分支不需要吧。当cudnn版本 < 8.1时,整个class应该不被编译。所以是不是在class整体头尾分别加上#if和#endif。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 因为conv2d对于bfloat16需要编译成功,代码逻辑中 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cudnn8.1版本以下也不该用half类型,应该直接挂掉。另外,你加的都是 Paddle/paddle/fluid/operators/conv_op.cc Lines 187 to 191 in 79f7ba6
|
||||||||||||
using ScalingParamType = const float; | ||||||||||||
using BatchNormParamType = float; | ||||||||||||
static ScalingParamType* kOne() { | ||||||||||||
static ScalingParamType v = 1.0; | ||||||||||||
return &v; | ||||||||||||
} | ||||||||||||
static ScalingParamType* kZero() { | ||||||||||||
static ScalingParamType v = 0.0; | ||||||||||||
return &v; | ||||||||||||
} | ||||||||||||
}; | ||||||||||||
|
||||||||||||
template <> | ||||||||||||
class CudnnDataType<float16> { | ||||||||||||
public: | ||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -167,6 +167,37 @@ def test_check_grad_no_input(self): | |
globals()[cls_name] = TestConv2DCUDNNFp16 | ||
|
||
|
||
def create_test_cudnn_bf16_class(parent, grad_check=True): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. conv的测试不需要依赖 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 需要的,目前已merge最新代码,同步 |
||
@unittest.skipIf( | ||
not core.is_compiled_with_cuda() or core.cudnn_version() < 8100, | ||
"core is not compiled with CUDA and cudnn version need larger than 8.1.0" | ||
) | ||
class TestConv2DCUDNNBF16(parent): | ||
def init_kernel_type(self): | ||
self.use_cudnn = True | ||
self.dtype = np.uint16 | ||
|
||
def test_check_output(self): | ||
place = core.CUDAPlace(0) | ||
self.check_output_with_place(place, atol=1e-2) | ||
|
||
def test_check_grad_no_filter(self): | ||
place = core.CUDAPlace(0) | ||
if grad_check: | ||
self.check_grad_with_place( | ||
place, ['Input'], 'Output', no_grad_set=set(['Filter'])) | ||
|
||
def test_check_grad_no_input(self): | ||
place = core.CUDAPlace(0) | ||
if grad_check: | ||
self.check_grad_with_place( | ||
place, ['Filter'], 'Output', no_grad_set=set(['Input'])) | ||
|
||
cls_name = "{0}_{1}".format(parent.__name__, "CUDNNBF16") | ||
TestConv2DCUDNNBF16.__name__ = cls_name | ||
globals()[cls_name] = TestConv2DCUDNNBF16 | ||
|
||
|
||
def create_test_channel_last_class(parent): | ||
class TestChannelLastCase(parent): | ||
def init_data_format(self): | ||
|
@@ -554,6 +585,15 @@ def init_group(self): | |
create_test_cudnn_fp16_class(TestWith1x1, grad_check=False) | ||
create_test_cudnn_fp16_class(TestWithInput1x1Filter1x1, grad_check=False) | ||
|
||
#----------------Conv2DCUDNN bf16---------------- | ||
|
||
create_test_cudnn_bf16_class(TestConv2DOp, grad_check=False) | ||
create_test_cudnn_bf16_class(TestWithPad, grad_check=False) | ||
create_test_cudnn_bf16_class(TestWithStride, grad_check=False) | ||
create_test_cudnn_bf16_class(TestWithGroup, grad_check=False) | ||
create_test_cudnn_bf16_class(TestWith1x1, grad_check=False) | ||
create_test_cudnn_bf16_class(TestWithInput1x1Filter1x1, grad_check=False) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 都不检查梯度? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 之前试参考cpu上bf16测试,重新commit代码已默认添加反向测试。 |
||
|
||
#----------------TestDepthwiseConv ----- | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个检查能不能放到一个公共的地方,比如
CudnnDataType<bfloat16>
里面?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CudnnDataType<bfloat16>
里只能做编译期检查,这里直接改为cudnn8.1以下不添加bfloat16数据类型的Kernel。