Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PHI] Migrate mul_grad kernel #48061

Merged
merged 38 commits into from
Nov 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
b2c89e6
cleanup unused code
Silv3S Nov 8, 2022
1be88bc
unify is_int8 is_bfloat16
Silv3S Nov 8, 2022
4f82616
Simplify matmul_v2 FWD kernel
Silv3S Nov 8, 2022
f5375fd
remove RunKernel methods
Silv3S Nov 8, 2022
9c927fa
remove import namespace
Silv3S Nov 8, 2022
6dd70a1
remove headers
Silv3S Nov 8, 2022
f763164
Merge branch 'develop' into mkldnn_cleanup
Silv3S Nov 8, 2022
f613c3f
Merge branch 'PaddlePaddle:develop' into mkldnn_cleanup
Silv3S Nov 9, 2022
02392d3
clean fluid/phi cross imports
Silv3S Nov 9, 2022
a3c0c61
remove fluid axpy_handler
Silv3S Nov 9, 2022
fbf1605
delete fluid methods
Silv3S Nov 9, 2022
e8dcc47
Merge branch 'mkldnn_cleanup' into axpy_fluid_phi
Silv3S Nov 9, 2022
cc3b784
activations
Silv3S Nov 9, 2022
0994e24
OneDNNMemDesc
Silv3S Nov 9, 2022
6dbfddf
MKLDNNFormatForSize
Silv3S Nov 9, 2022
de56962
MatchShapeToLayout
Silv3S Nov 9, 2022
1bc636b
MKLDNNMemoryFormat
Silv3S Nov 9, 2022
387a16b
MKLDNNFormat
Silv3S Nov 9, 2022
7536353
ReorderMKLDNNHandler
Silv3S Nov 9, 2022
38427c2
to_void_cast
Silv3S Nov 9, 2022
6ff7998
Merge branch 'axpy_fluid_phi' into mkldnn_cleanup
Silv3S Nov 10, 2022
206afcb
review suggestions
Silv3S Nov 10, 2022
ee87e9c
interpolate
Silv3S Nov 10, 2022
6ca8717
Merge branch 'develop' into mkldnn_cleanup
Silv3S Nov 10, 2022
0c9ca31
remove fluid dependency
Silv3S Nov 14, 2022
d0d94d4
Merge branch 'develop' into mkldnn_cleanup
Silv3S Nov 14, 2022
4780111
init
Silv3S Nov 14, 2022
17684e0
ExecuteMatMulV2
Silv3S Nov 14, 2022
efb3932
rm fluid kernel
Silv3S Nov 15, 2022
2c968e7
Merge branch 'develop' into phi_matmul_grad_kernel
Silv3S Nov 15, 2022
2e531cb
matmul_grad
Silv3S Nov 15, 2022
650d136
remove mutable_data
Silv3S Nov 15, 2022
faea539
Merge branch 'PaddlePaddle:develop' into phi_matmul_grad_kernel
Silv3S Nov 15, 2022
8fe8e46
Merge branch 'PaddlePaddle:develop' into phi_matmul_grad_kernel
Silv3S Nov 16, 2022
37a7627
mul_grad
Silv3S Nov 16, 2022
1d39b65
Merge branch 'develop' into phi_mul_grad_kernel
Silv3S Nov 18, 2022
d954897
Merge branch 'develop' into phi_mul_grad_kernel
Silv3S Nov 18, 2022
5d46bf4
Merge branch 'PaddlePaddle:develop' into phi_mul_grad_kernel
Silv3S Nov 18, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 0 additions & 83 deletions paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -489,83 +489,6 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
}
};

template <typename XT>
class MulGradMKLDNNKernel : public MulMKLDNNKernel<XT> {
 public:
  // Entry point required by the framework; forwards to the implementation.
  void Compute(const ExecutionContext &ctx) const override { RunKernel(ctx); }

 private:
  // Backward pass of mul: dX = dOut * Y^T and dY = X^T * dOut, each computed
  // with the oneDNN MatMulV2 primitive on 2-D (matrix-flattened) views.
  void RunKernel(const ExecutionContext &ctx) const {
    const auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    const auto &engine = dev_ctx.GetEngine();

    const auto *x = ctx.Input<LoDTensor>("X");
    const auto *y = ctx.Input<LoDTensor>("Y");
    const auto *dout =
        ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"));

    auto *dx = ctx.Output<LoDTensor>(framework::GradVarName("X"));
    auto *dy = ctx.Output<LoDTensor>(framework::GradVarName("Y"));

    const int x_num_col_dims = ctx.Attr<int>("x_num_col_dims");
    const int y_num_col_dims = ctx.Attr<int>("y_num_col_dims");

    // Collapse inputs to 2-D matrices when they carry extra dimensions.
    const Tensor x_mat = x->dims().size() > 2
                             ? framework::ReshapeToMatrix(*x, x_num_col_dims)
                             : static_cast<const Tensor &>(*x);
    const Tensor y_mat = y->dims().size() > 2
                             ? framework::ReshapeToMatrix(*y, y_num_col_dims)
                             : static_cast<const Tensor &>(*y);

    // View the output gradient as the matching 2-D matrix.
    Tensor dout_mat = *dout;
    dout_mat.Resize({phi::flatten_to_2d(x->dims(), x_num_col_dims)[0],
                     phi::flatten_to_2d(y->dims(), y_num_col_dims)[1]});

    // The MatMulV2 handler expects 3-D shapes, so prepend a batch dim of 1.
    const std::vector<int64_t> x_dims{1, x_mat.dims()[0], x_mat.dims()[1]};
    const std::vector<int64_t> y_dims{1, y_mat.dims()[0], y_mat.dims()[1]};
    const std::vector<int64_t> dout_dims{
        1, dout_mat.dims()[0], dout_mat.dims()[1]};

    if (dx != nullptr) {
      dx->set_lod(x->lod());
      // dX = dOut * Y^T
      this->ExecuteMatMul(ctx,
                          dev_ctx,
                          engine,
                          ctx.GetPlace(),
                          &dout_mat,
                          dout_dims,
                          false,
                          &y_mat,
                          y_dims,
                          true,
                          static_cast<Tensor *>(dx));
    }
    if (dy != nullptr) {
      dy->set_lod(y->lod());
      // dY = X^T * dOut
      this->ExecuteMatMul(ctx,
                          dev_ctx,
                          engine,
                          ctx.GetPlace(),
                          &x_mat,
                          x_dims,
                          true,
                          &dout_mat,
                          dout_dims,
                          false,
                          static_cast<Tensor *>(dy));
    }
  }
};

} // namespace operators
} // namespace paddle

Expand All @@ -578,9 +501,3 @@ REGISTER_OP_KERNEL(mul,
ops::MulMKLDNNINT8Kernel<int8_t>,
ops::MulMKLDNNKernel<paddle::platform::bfloat16>,
ops::MulMKLDNNKernel<float>);

// Register the oneDNN mul_grad kernel on CPU for float and bfloat16.
REGISTER_OP_KERNEL(mul_grad,
                   MKLDNN,
                   ::paddle::platform::CPUPlace,
                   ops::MulGradMKLDNNKernel<paddle::platform::bfloat16>,
                   ops::MulGradMKLDNNKernel<float>);
41 changes: 41 additions & 0 deletions paddle/phi/backends/onednn/onednn_reuse.h
Original file line number Diff line number Diff line change
Expand Up @@ -1912,6 +1912,47 @@ class MatmulOneDNNHandler
}
};

template <typename T>
static void ExecuteMul(const OneDNNContext& dev_ctx,
Silv3S marked this conversation as resolved.
Show resolved Hide resolved
const DenseTensor& x,
const DenseTensor& y,
const std::vector<int64_t>& x_dims,
const std::vector<int64_t>& y_dims,
bool trans_x,
bool trans_y,
DenseTensor* out) {
static const std::vector<int64_t> vec_placeholder;
MatmulOneDNNHandler<T, T, T> handler(dev_ctx,
x_dims,
y_dims,
trans_x,
trans_y,
vec_placeholder,
vec_placeholder,
false);

const auto src_memory_p = handler.AcquireSrcMemory(&x);
const auto weights_memory_p = handler.AcquireWeightsMemory(&y);
const auto dst_memory_p = handler.AcquireDstMemory(dev_ctx, out);

auto matmul_p = handler.AcquireForwardPrimitive();

std::unordered_map<int, dnnl::memory> matmul_args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};

auto& astream = OneDNNContext::tls().get_stream();
matmul_p->execute(astream, matmul_args);
astream.wait();

// This kernel is flattening dims so then we need to unflattened version
// that should be set in out reshape require plain layout, but
// MatmulV2MKLDNNHanlder enforces one so it should work
out->set_mem_desc(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
}

template <typename T, typename T_out>
void ExecuteMatmul(const OneDNNContext& dev_ctx,
const DenseTensor& x,
Expand Down
50 changes: 50 additions & 0 deletions paddle/phi/kernels/onednn/matmul_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,49 @@ void MatmulGradKernel(const Context &dev_ctx,
dy->Resize(y.dims());
}

// Backward pass of mul (matmul_with_flatten): flattens x / y / out_grad to
// 2-D matrices and computes
//   dX = dOut * Y^T    and    dY = X^T * dOut
// with the shared oneDNN ExecuteMul helper.
template <typename T, typename Context>
void MatmulWithFlattenGradKernel(const Context &dev_ctx,
                                 const DenseTensor &x,
                                 const DenseTensor &y,
                                 const DenseTensor &out_grad,
                                 int x_num_col_dims,
                                 int y_num_col_dims,
                                 DenseTensor *x_grad,
                                 DenseTensor *y_grad) {
  // Collapse inputs to 2-D matrices only when they carry extra dimensions.
  // (The previous code built both reshaped tensors unconditionally and then
  // discarded them for tensors that were already at most 2-D.)
  const DenseTensor x_matrix =
      x.dims().size() > 2
          ? paddle::framework::ReshapeToMatrix(x, x_num_col_dims)
          : x;
  const DenseTensor y_matrix =
      y.dims().size() > 2
          ? paddle::framework::ReshapeToMatrix(y, y_num_col_dims)
          : y;

  // View the output gradient as the matching 2-D matrix.
  DenseTensor dout_matrix = out_grad;
  dout_matrix.Resize({flatten_to_2d(x.dims(), x_num_col_dims)[0],
                      flatten_to_2d(y.dims(), y_num_col_dims)[1]});

  // The MatMulV2 handler expects 3-D shapes, so prepend a batch dim of 1.
  std::vector<int64_t> x_dims = {1, x_matrix.dims()[0], x_matrix.dims()[1]};
  std::vector<int64_t> y_dims = {1, y_matrix.dims()[0], y_matrix.dims()[1]};
  std::vector<int64_t> dout_dims = {
      1, dout_matrix.dims()[0], dout_matrix.dims()[1]};

  if (x_grad != nullptr) {
    x_grad->set_lod(x.lod());
    // dX = dOut * Y^T
    funcs::ExecuteMul<T>(
        dev_ctx, dout_matrix, y_matrix, dout_dims, y_dims, false, true, x_grad);
  }
  if (y_grad != nullptr) {
    y_grad->set_lod(y.lod());
    // dY = X^T * dOut
    funcs::ExecuteMul<T>(
        dev_ctx, x_matrix, dout_matrix, x_dims, dout_dims, true, false, y_grad);
  }
}

} // namespace phi

PD_REGISTER_KERNEL(matmul_grad,
Expand All @@ -161,3 +204,10 @@ PD_REGISTER_KERNEL(matmul_grad,
phi::MatmulGradKernel,
float,
phi::dtype::bfloat16) {}

// Register the oneDNN matmul_with_flatten_grad (mul_grad) kernel for float
// and bfloat16; replaces the removed fluid MulGradMKLDNNKernel registration.
PD_REGISTER_KERNEL(matmul_with_flatten_grad,
                   OneDNN,
                   ONEDNN,
                   phi::MatmulWithFlattenGradKernel,
                   float,
                   phi::dtype::bfloat16) {}