Skip to content

Commit

Permalink
[oneDNN] disable caching oneDNN primitives in matmul v2, Reduce grad and elementwise_add grad, expand_v2 (#35132)
Browse files Browse the repository at this point in the history

* - grad caching disabled of matmul_v1

- compilation fix

- compilation fix

* - reduction removed

* - Matmul v2 disabled caching

* Draft of further changes

* - workaround for reducegrad

* - fixes to UT

* - fix to compilation

* - another fix

* - fix
  • Loading branch information
jczaja authored Aug 26, 2021
1 parent 8dc050d commit 31f0221
Show file tree
Hide file tree
Showing 8 changed files with 179 additions and 225 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,8 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
} else {
// Broadcasting
platform::ReductionMKLDNNHandler<T> handler_sum(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine,
ctx.GetPlace(), dout, dy,
ctx.InputName(framework::GradVarName("Out")),
CalculateBroadcastedDims(dout, dy));
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy));
auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
auto reduction_p = handler_sum.AcquireForwardPrimitive();
reduction_p->execute(astream, {{DNNL_ARG_SRC, *reorder_src_memory_p},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,8 @@ class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> {
// Reduction is needed for broadcasting scenario
if (dout->dims() != dy->dims()) {
platform::ReductionMKLDNNHandler<T> handler_sum(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, mkldnn_engine,
ctx.GetPlace(), dout, dy,
ctx.InputName(framework::GradVarName("Out")),
CalculateBroadcastedDims(dout, dy));
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, mkldnn_engine,
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy));
auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
auto reduction_p = handler_sum.AcquireForwardPrimitive();
// As source we use mem object with results from binary operation
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/operators/mkldnn/expand_v2_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ class ExpandMKLDNNKernel : public paddle::framework::OpKernel<T> {
out->Resize(paddle::framework::make_ddim(out_new_dims));
out->set_format(x_format_tag);
paddle::platform::BroadcastDataMKLDNNHandler<T> handler(
dnnl::algorithm::binary_add, dev_ctx, onednn_engine, ctx.GetPlace(),
out, x, 0.0f, 1.0f, ctx.InputName("X"), x_vec_dims);
dnnl::algorithm::binary_add, onednn_engine, ctx.GetPlace(), out, x,
0.0f, 1.0f, x_vec_dims);

auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(out);
Expand Down Expand Up @@ -136,8 +136,8 @@ class ExpandGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
paddle::platform::GetMKLDNNFormat(reorder_dst_memory_p->get_desc()));
} else {
paddle::platform::ReductionMKLDNNHandler<T> handler(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine,
ctx.GetPlace(), dout, dx, ctx.InputName("X"), dx_vec_dims);
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
ctx.GetPlace(), dout, dx, dx_vec_dims);

auto src_memory_p = handler.AcquireSrcMemory(dout);
auto dst_memory_p = handler.AcquireDstMemory(dx);
Expand Down
106 changes: 49 additions & 57 deletions paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,58 +83,52 @@ static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext& dev_ctx,

template <typename T>
class MatMulMKLDNNHandler
: public paddle::platform::MKLDNNHandlerT<T, dnnl::matmul> {
: public paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul> {
public:
MatMulMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine engine,
MatMulMKLDNNHandler(const mkldnn::engine engine,
paddle::platform::Place cpu_place, Tensor* x,
bool trans_x, Tensor* y, bool trans_y, Tensor* out,
float scale, const std::string& uniq_name)
: paddle::platform::MKLDNNHandlerT<T, dnnl::matmul>(
dev_ctx, engine, cpu_place,
paddle::platform::CreateKey(dev_ctx, vectorize(x->dims()),
uniq_name)) {
if (!this->isCached()) {
auto mat_dim_x = paddle::operators::math::CreateMatrixDescriptor(
x->dims(), 0, trans_x);
auto mat_dim_y = paddle::operators::math::CreateMatrixDescriptor(
y->dims(), 0, trans_y);

memory::dim x_bs = mat_dim_x.batch_size_;
memory::dim y_bs = mat_dim_y.batch_size_;

memory::dim out_bs = x_bs || y_bs ? std::max(x_bs, y_bs) : 1;
const memory::dim M = mat_dim_x.height_;
const memory::dim N = mat_dim_y.width_;
const memory::dim K = mat_dim_x.width_;

memory::dims x_dims = {x_bs > 0 ? x_bs : 1, M, K};
memory::dims y_dims = {y_bs > 0 ? y_bs : 1, K, N};
memory::dims out_dims = {out_bs, M, N};

memory::dims x_strides =
!trans_x ? memory::dims{M * K, K, 1} : memory::dims{M * K, 1, M};

memory::dims y_strides =
!trans_y ? memory::dims{N * K, N, 1} : memory::dims{N * K, 1, K};
memory::dims out_strides = memory::dims{M * N, N, 1};

auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
auto out_md = memory::desc(out_dims, MKLDNNGetDataType<T>(), out_strides);

dnnl::primitive_attr attrs;
if (scale != 1.0f) attrs.set_output_scales(0, {scale});

this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md);
}
float scale)
: paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine,
cpu_place) {
auto mat_dim_x =
paddle::operators::math::CreateMatrixDescriptor(x->dims(), 0, trans_x);
auto mat_dim_y =
paddle::operators::math::CreateMatrixDescriptor(y->dims(), 0, trans_y);

memory::dim x_bs = mat_dim_x.batch_size_;
memory::dim y_bs = mat_dim_y.batch_size_;

memory::dim out_bs = x_bs || y_bs ? std::max(x_bs, y_bs) : 1;
const memory::dim M = mat_dim_x.height_;
const memory::dim N = mat_dim_y.width_;
const memory::dim K = mat_dim_x.width_;

memory::dims x_dims = {x_bs > 0 ? x_bs : 1, M, K};
memory::dims y_dims = {y_bs > 0 ? y_bs : 1, K, N};
memory::dims out_dims = {out_bs, M, N};

memory::dims x_strides =
!trans_x ? memory::dims{M * K, K, 1} : memory::dims{M * K, 1, M};

memory::dims y_strides =
!trans_y ? memory::dims{N * K, N, 1} : memory::dims{N * K, 1, K};
memory::dims out_strides = memory::dims{M * N, N, 1};

auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
auto out_md = memory::desc(out_dims, MKLDNNGetDataType<T>(), out_strides);

dnnl::primitive_attr attrs;
if (scale != 1.0f) attrs.set_output_scales(0, {scale});

this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md);
}

std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
to_void_cast<T>(input_data),
"@weights_mem_p");
to_void_cast<T>(input_data));
}
};

Expand Down Expand Up @@ -565,7 +559,7 @@ void MatMulGradMKLDNNKernel<T>::ExecuteMatMulGrad(
const ExecutionContext& ctx, const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine& engine, Tensor* x, bool trans_x,
bool is_fold_init_dims_x, Tensor* y, bool trans_y, bool is_fold_init_dims_y,
Tensor* out, int execution_number) const {
Tensor* out) const {
// gradient is calculated in a different way when broadcasting is used
bool need_combine = (x->dims().size() == 3 || y->dims().size() == 3) &&
out->dims().size() == 2;
Expand All @@ -583,10 +577,8 @@ void MatMulGradMKLDNNKernel<T>::ExecuteMatMulGrad(

float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;

MatMulMKLDNNHandler<T> handler(dev_ctx, engine, ctx.GetPlace(), &x_combined,
trans_x, &y_combined, trans_y, out, alpha,
ctx.InputName(framework::GradVarName("Out")) +
std::to_string(execution_number));
MatMulMKLDNNHandler<T> handler(engine, ctx.GetPlace(), &x_combined, trans_x,
&y_combined, trans_y, out, alpha);

const auto src_memory_p = handler.AcquireSrcMemory(&x_combined);
const auto weights_memory_p = handler.AcquireWeightsMemory(&y_combined);
Expand Down Expand Up @@ -645,24 +637,24 @@ void MatMulGradMKLDNNKernel<T>::RunKernel(const ExecutionContext& ctx) const {

if (transpose_x && transpose_y) {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, true, true, &dout,
true, false, dx, 0);
true, false, dx);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true, &x,
true, false, dy, 1);
true, false, dy);
} else if (transpose_x) {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &y, false, false,
&dout, true, false, dx, 0);
&dout, true, false, dx);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, false, false,
&dout, false, true, dy, 1);
&dout, false, true, dy);
} else if (transpose_y) {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false,
&y, false, true, dx, 0);
&y, false, true, dx);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, true, true, &x,
false, true, dy, 1);
false, true, dy);
} else {
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &dout, false, false,
&y, true, false, dx, 0);
&y, true, false, dx);
this->ExecuteMatMulGrad(ctx, dev_ctx, onednn_engine, &x, true, true, &dout,
false, true, dy, 1);
false, true, dy);
}

if (dx) {
Expand Down
3 changes: 1 addition & 2 deletions paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ class MatMulGradMKLDNNKernel : public framework::OpKernel<T> {
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine& engine, Tensor* x, bool trans_x,
bool is_fold_init_dims_x, Tensor* y, bool trans_y,
bool is_fold_init_dims_y, Tensor* out,
int execution_number) const;
bool is_fold_init_dims_y, Tensor* out) const;
void RunKernel(const ExecutionContext& ctx) const;
};
} // namespace operators
Expand Down
122 changes: 57 additions & 65 deletions paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,79 +31,72 @@ using paddle::framework::GradVarName;

template <typename T>
class MatMulV2MKLDNNHandler
: public paddle::platform::MKLDNNHandlerT<T, dnnl::matmul> {
: public paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul> {
public:
MatMulV2MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine engine,
MatMulV2MKLDNNHandler(const mkldnn::engine engine,
paddle::platform::Place cpu_place,
const std::vector<int64_t>& x_org_dims, bool trans_x,
const std::vector<int64_t>& y_org_dims, bool trans_y,
const std::string& uniq_name)
: paddle::platform::MKLDNNHandlerT<T, dnnl::matmul>(
dev_ctx, engine, cpu_place,
paddle::platform::CreateKey(dev_ctx, x_org_dims, uniq_name)) {
if (!this->isCached()) {
// M X K * K X N
std::vector<int64_t> x_dims(x_org_dims);
std::vector<int64_t> y_dims(y_org_dims);

const int MB_idx = x_dims.size() - 3;
const int H_idx = x_dims.size() - 2;
const int W_idx = x_dims.size() - 1;

if (trans_x) std::swap(x_dims[H_idx], x_dims[W_idx]);
if (trans_y) std::swap(y_dims[H_idx], y_dims[W_idx]);

const memory::dim M = x_dims[H_idx];
const memory::dim K = x_dims[W_idx];
const memory::dim N = y_dims[W_idx];

std::vector<int64_t> x_strides(x_dims.size() - 3, 1);
std::vector<int64_t> y_strides(x_dims.size() - 3, 1);
std::vector<int64_t> out_strides(x_dims.size() - 3, 1);
std::vector<int64_t> out_ddims(x_dims.size() - 3, 1);

x_strides.reserve(x_dims.size());
y_strides.reserve(x_dims.size());
out_strides.reserve(x_dims.size());

if (!trans_x) {
x_strides.insert(x_strides.end(), {M * K, K, 1});
} else {
x_strides.insert(x_strides.end(), {M * K, 1, M});
}
const std::vector<int64_t>& y_org_dims, bool trans_y)
: paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine,
cpu_place) {
// M X K * K X N
std::vector<int64_t> x_dims(x_org_dims);
std::vector<int64_t> y_dims(y_org_dims);

const int MB_idx = x_dims.size() - 3;
const int H_idx = x_dims.size() - 2;
const int W_idx = x_dims.size() - 1;

if (trans_x) std::swap(x_dims[H_idx], x_dims[W_idx]);
if (trans_y) std::swap(y_dims[H_idx], y_dims[W_idx]);

const memory::dim M = x_dims[H_idx];
const memory::dim K = x_dims[W_idx];
const memory::dim N = y_dims[W_idx];

std::vector<int64_t> x_strides(x_dims.size() - 3, 1);
std::vector<int64_t> y_strides(x_dims.size() - 3, 1);
std::vector<int64_t> out_strides(x_dims.size() - 3, 1);
std::vector<int64_t> out_ddims(x_dims.size() - 3, 1);

x_strides.reserve(x_dims.size());
y_strides.reserve(x_dims.size());
out_strides.reserve(x_dims.size());

if (!trans_x) {
x_strides.insert(x_strides.end(), {M * K, K, 1});
} else {
x_strides.insert(x_strides.end(), {M * K, 1, M});
}

if (!trans_y) {
y_strides.insert(y_strides.end(), {N * K, N, 1});
} else {
y_strides.insert(y_strides.end(), {N * K, 1, K});
}
if (!trans_y) {
y_strides.insert(y_strides.end(), {N * K, N, 1});
} else {
y_strides.insert(y_strides.end(), {N * K, 1, K});
}

out_strides.insert(out_strides.end(), {M * N, N, 1});
out_ddims.insert(out_ddims.end(),
{std::max(x_dims[MB_idx], y_dims[MB_idx]), M, N});
out_strides.insert(out_strides.end(), {M * N, N, 1});
out_ddims.insert(out_ddims.end(),
{std::max(x_dims[MB_idx], y_dims[MB_idx]), M, N});

for (int i = x_dims.size() - 4; i >= 0; --i) {
out_ddims[i] = std::max(x_dims[i], y_dims[i]);
x_strides[i] = x_dims[i + 1] * x_strides[i + 1];
y_strides[i] = y_dims[i + 1] * y_strides[i + 1];
out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
}
for (int i = x_dims.size() - 4; i >= 0; --i) {
out_ddims[i] = std::max(x_dims[i], y_dims[i]);
x_strides[i] = x_dims[i + 1] * x_strides[i + 1];
y_strides[i] = y_dims[i + 1] * y_strides[i + 1];
out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
}

auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
auto out_md =
memory::desc(out_ddims, MKLDNNGetDataType<T>(), out_strides);
auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
auto out_md = memory::desc(out_ddims, MKLDNNGetDataType<T>(), out_strides);

this->AcquireForwardPrimitiveDescriptor(x_md, y_md, out_md);
}
this->AcquireForwardPrimitiveDescriptor(x_md, y_md, out_md);
}

std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
to_void_cast<T>(input_data),
"@weights_mem_p");
to_void_cast<T>(input_data));
}
};

Expand All @@ -122,9 +115,8 @@ class MatMulV2MKLDNNKernel
const Tensor* y, std::vector<int64_t>& y_dims,
bool trans_y, Tensor* out, std::vector<int64_t>& out_dims,
int execution_number = 0) const {
MatMulV2MKLDNNHandler<T> handler(
dev_ctx, onednn_engine, ctx.GetPlace(), x_dims, trans_x, y_dims,
trans_y, ctx.InputName("X") + std::to_string(execution_number));
MatMulV2MKLDNNHandler<T> handler(onednn_engine, ctx.GetPlace(), x_dims,
trans_x, y_dims, trans_y);

const auto src_memory_p = handler.AcquireSrcMemory(x);
const auto weights_memory_p = handler.AcquireWeightsMemory(y);
Expand Down Expand Up @@ -251,8 +243,8 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel<T> {
const Tensor* dx_tmp, Tensor* dx,
std::vector<int64_t> dx_dims) const {
paddle::platform::ReductionMKLDNNHandler<T> handler(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine,
ctx.GetPlace(), dx_tmp, dx, ctx.InputName("X"), dx_dims);
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
ctx.GetPlace(), dx_tmp, dx, dx_dims);

auto src_memory_p = handler.AcquireSrcMemory(dx_tmp);
auto dst_memory_p = handler.AcquireDstMemory(dx);
Expand Down
Loading

0 comments on commit 31f0221

Please sign in to comment.