Skip to content

Commit

Permalink
Add bf16 support
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanglirong1999 committed Oct 23, 2023
1 parent 399581b commit 1a16786
Showing 1 changed file with 15 additions and 19 deletions.
34 changes: 15 additions & 19 deletions paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -302,10 +302,15 @@ class FCMKLDNNHandler
this->dev_ctx_.GetBlob(residual_key));
if (!memory_p) {
auto dims = this->fwd_pd_->dst_desc().get_dims();
if (phi::funcs::is_int8<T_in>()) {
auto data_type = residual->dtype() == phi::DataType::INT8
? dnnl::memory::data_type::s8
: dnnl::memory::data_type::u8;
if (phi::funcs::is_int8<T_in>() || phi::funcs::is_bfloat16<T_in>()) {
constexpr bool is_int8 = phi::funcs::is_int8<T_in>();
auto data_type = dnnl::memory::data_type::bf16;
if (is_int8) {
data_type = residual->dtype() == phi::DataType::INT8
? dnnl::memory::data_type::s8
: dnnl::memory::data_type::u8;
}

auto src_0_md =
dnnl::memory::desc(dims, data_type, dnnl::memory::format_tag::ab);
auto src_1_md = dnnl::memory::desc(
Expand All @@ -315,20 +320,11 @@ class FCMKLDNNHandler
std::vector<float> src_data(phi::product(residual->dims()),
1.f / scale_data);

dnnl::memory src_0_mem;
if (residual->dtype() == phi::DataType::INT8) {
const int8_t* input_data = residual->data<int8_t>();
src_0_mem =
dnnl::memory(src_0_md,
this->dev_ctx_.GetEngine(),
phi::funcs::to_void_cast<int8_t>(input_data));
} else {
const uint8_t* input_data = residual->data<uint8_t>();
src_0_mem =
dnnl::memory(src_0_md,
this->dev_ctx_.GetEngine(),
phi::funcs::to_void_cast<uint8_t>(input_data));
}
dnnl::memory src_0_mem =
dnnl::memory(src_0_md, this->dev_ctx_.GetEngine());
void* residual_ptr = const_cast<void*>(residual->data());
src_0_mem.set_data_handle(residual_ptr);

auto src_1_mem =
dnnl::memory(src_1_md,
this->dev_ctx_.GetEngine(),
Expand Down Expand Up @@ -607,7 +603,7 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
fc_args.insert({DNNL_ARG_BIAS, *bias_memory_p});
}

if (phi::funcs::is_int8<T_in>()) {
if (phi::funcs::is_int8<T_in>() || phi::funcs::is_bfloat16<T_in>()) {
handler.SetScalesIfNeeded(&fc_args);
}

Expand Down

0 comments on commit 1a16786

Please sign in to comment.