Add fused_linear_param_grad_add_kernel #51805

Merged: 11 commits, Mar 22, 2023
148 changes: 85 additions & 63 deletions paddle/fluid/operators/fused/fused_gemm_epilogue_op.h
@@ -408,7 +408,7 @@ static cublasLtEpilogue_t GetEpilogueGradType(
}
}

template <typename T, bool TransX, bool TransY>
template <typename T, typename DXT, typename DYT, bool TransX, bool TransY>
void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
const phi::DenseTensor* dout,
const phi::DenseTensor* x,
@@ -421,8 +421,12 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
phi::DenseTensor* dx,
phi::DenseTensor* dy,
phi::DenseTensor* dbias,
bool use_addto) {
bool use_addto_dx,
bool use_addto_dy) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
static_assert(std::is_same<DXT, T>::value || std::is_same<DXT, MT>::value);
static_assert(std::is_same<DYT, T>::value || std::is_same<DYT, MT>::value);

using Trait = FusedGEMMGradTrait<TransX, TransY>;

cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType<T>();
@@ -440,8 +444,8 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
cudaStream_t stream = dev_ctx.stream();

MT alpha = static_cast<MT>(1.0);
MT beta_dx = use_addto ? static_cast<MT>(1.0) : static_cast<MT>(0.0);
MT beta_dy = static_cast<MT>(0.0);
MT beta_dx = use_addto_dx ? static_cast<MT>(1.0) : static_cast<MT>(0.0);
MT beta_dy = use_addto_dy ? static_cast<MT>(1.0) : static_cast<MT>(0.0);

cublasLtMatrixLayout_t dout_desc = nullptr, dout_trans_desc = nullptr;
cublasLtMatrixLayout_t x_desc = nullptr, x_trans_desc = nullptr;
@@ -508,7 +512,11 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
auto b_trans = BoolToCuBlasEnum(Trait::kXGradBTrans);

PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&dx_desc, mat_type, x_col, x_row, x_col));
&dx_desc,
phi::backends::gpu::ToCudaDataType<DXT>(),
x_col,
x_row,
x_col));

PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
&dx_operation_desc, compute_type, scale_type));
@@ -556,7 +564,7 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));

auto* dx_data = dev_ctx.Alloc<T>(dx, dx->numel() * sizeof(T));
auto* dx_data = dev_ctx.Alloc<DXT>(dx, dx->numel() * sizeof(DXT));
const auto* y_data = y->data<T>();
const auto* dout_data = dout->data<T>();
const auto* a_data = kXGradAIsDZ ? dout_data : y_data;
@@ -627,7 +635,11 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
auto b_trans = BoolToCuBlasEnum(Trait::kYGradBTrans);

PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&dy_desc, mat_type, y_col, y_row, y_col));
&dy_desc,
phi::backends::gpu::ToCudaDataType<DYT>(),
y_col,
y_row,
y_col));

PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
&dy_operation_desc, compute_type, scale_type));
@@ -664,7 +676,8 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
sizeof(epiloque_func_for_dy)));

if (dbias) {
auto* dbias_data = dev_ctx.Alloc<T>(dbias, dbias->numel() * sizeof(T));
auto* dbias_data =
dev_ctx.Alloc<DYT>(dbias, dbias->numel() * sizeof(DYT));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc,
@@ -677,7 +690,7 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
dev_ctx.GetPlace(),
workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
auto* dy_data = dev_ctx.Alloc<T>(dy, dy->numel() * sizeof(T));
auto* dy_data = dev_ctx.Alloc<DYT>(dy, dy->numel() * sizeof(DYT));
const auto* dout_data = dout->data<T>();
const auto* x_data = x->data<T>();
const auto* a_data = kYGradAIsDZ ? dout_data : x_data;
@@ -718,7 +731,7 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
}
}

template <typename T>
template <typename T, typename DXT = T, typename DYT = T>
void ComputeFusedGemmEpilogueBackward(const phi::GPUContext& dev_ctx,
const phi::DenseTensor* dout,
const phi::DenseTensor* x,
@@ -733,70 +746,79 @@ void ComputeFusedGemmEpilogueBackward(const phi::GPUContext& dev_ctx,
phi::DenseTensor* dx,
phi::DenseTensor* dy,
phi::DenseTensor* dbias,
bool use_addto = false) {
bool use_addto_dx = false,
bool use_addto_dy = false) {
VLOG(10) << "M=" << M << ", K=" << K << ", N=" << N << ", trans_x=" << trans_x
<< ", trans_y=" << trans_y
<< ", activation_grad=" << activation_grad;

if (trans_x) {
if (trans_y) {
ComputeFusedGemmEpilogueBackwardImpl<T, true, true>(dev_ctx,
dout,
x,
y,
reserve_space,
M,
N,
K,
activation_grad,
dx,
dy,
dbias,
use_addto);
ComputeFusedGemmEpilogueBackwardImpl<T, DXT, DYT, true, true>(
dev_ctx,
dout,
x,
y,
reserve_space,
M,
N,
K,
activation_grad,
dx,
dy,
dbias,
use_addto_dx,
use_addto_dy);
} else {
ComputeFusedGemmEpilogueBackwardImpl<T, true, false>(dev_ctx,
dout,
x,
y,
reserve_space,
M,
N,
K,
activation_grad,
dx,
dy,
dbias,
use_addto);
ComputeFusedGemmEpilogueBackwardImpl<T, DXT, DYT, true, false>(
dev_ctx,
dout,
x,
y,
reserve_space,
M,
N,
K,
activation_grad,
dx,
dy,
dbias,
use_addto_dx,
use_addto_dy);
}
} else {
if (trans_y) {
ComputeFusedGemmEpilogueBackwardImpl<T, false, true>(dev_ctx,
dout,
x,
y,
reserve_space,
M,
N,
K,
activation_grad,
dx,
dy,
dbias,
use_addto);
ComputeFusedGemmEpilogueBackwardImpl<T, DXT, DYT, false, true>(
dev_ctx,
dout,
x,
y,
reserve_space,
M,
N,
K,
activation_grad,
dx,
dy,
dbias,
use_addto_dx,
use_addto_dy);
} else {
ComputeFusedGemmEpilogueBackwardImpl<T, false, false>(dev_ctx,
dout,
x,
y,
reserve_space,
M,
N,
K,
activation_grad,
dx,
dy,
dbias,
use_addto);
ComputeFusedGemmEpilogueBackwardImpl<T, DXT, DYT, false, false>(
dev_ctx,
dout,
x,
y,
reserve_space,
M,
N,
K,
activation_grad,
dx,
dy,
dbias,
use_addto_dx,
use_addto_dy);
}
}
}
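
A hedged call-site sketch (not part of this diff) showing how the new DXT/DYT template arguments and the split use_addto_dx/use_addto_dy flags might be used: the GEMM runs in float16 while the weight gradient dy and the bias gradient are emitted in float32 and accumulated in place into existing master-gradient buffers. Tensor names, shapes, the paddle::operators namespace qualification, and the "none" activation are assumptions; a CUDA build of Paddle is required.

// Hedged sketch: fp16 GEMM with an fp32 master gradient for dy, accumulated
// in place via use_addto_dy. All tensor setup here is hypothetical.
#include <string>

#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h"

namespace example {

using float16 = phi::dtype::float16;

void AccumulateLinearGrads(const phi::GPUContext& dev_ctx,
                           const phi::DenseTensor& dout,  // [M, N], fp16
                           const phi::DenseTensor& x,     // [M, K], fp16
                           const phi::DenseTensor& y,     // [K, N], fp16
                           int64_t M, int64_t N, int64_t K,
                           phi::DenseTensor* dx,          // [M, K], fp16
                           phi::DenseTensor* dy_master,   // [K, N], fp32
                           phi::DenseTensor* dbias_master /* [N], fp32 */) {
  // T = float16, DXT = float16, DYT = float; this satisfies the
  // static_asserts above because DYT equals the MPType of float16.
  // use_addto_dy = true accumulates into dy_master instead of overwriting it.
  paddle::operators::ComputeFusedGemmEpilogueBackward<float16, float16, float>(
      dev_ctx, &dout, &x, &y,
      /*reserve_space=*/nullptr,  // assumed unused for the "none" activation
      M, N, K,
      /*trans_x=*/false, /*trans_y=*/false,
      /*activation_grad=*/"none",
      dx, dy_master, dbias_master,
      /*use_addto_dx=*/false, /*use_addto_dy=*/true);
}

}  // namespace example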
10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
@@ -596,6 +596,16 @@
func : frame
backward : frame_grad

- op : fused_linear_param_grad_add
args : (Tensor x, Tensor dout, Tensor dweight, Tensor dbias, bool multi_precision = true)
output : Tensor(dweight_out), Tensor(dbias_out)
infer_meta:
func : FusedLinearParamGradAddInferMeta
optional : dweight, dbias
kernel:
func : fused_linear_param_grad_add
data_type : dout
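
The new op takes x and dout plus optional dweight/dbias accumulators and returns dweight_out/dbias_out. Based on the kernel name, the InferMeta added below, and the use_addto path in the fused GEMM epilogue backward above, it presumably accumulates the linear-layer parameter gradients in place. A minimal CPU reference sketch of that presumed math (not the CUDA kernel added by this PR), with x flattened to [M, K] and dout to [M, N]:

// Hedged CPU reference of the presumed semantics:
//   dweight_out = dweight + x^T * dout
//   dbias_out   = dbias + column_sum(dout)
#include <cstdint>
#include <vector>

void FusedLinearParamGradAddRef(const std::vector<float>& x,      // [M, K]
                                const std::vector<float>& dout,   // [M, N]
                                int64_t M, int64_t K, int64_t N,
                                std::vector<float>* dweight_out,  // [K, N]
                                std::vector<float>* dbias_out) {  // [N]
  for (int64_t k = 0; k < K; ++k) {
    for (int64_t n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int64_t m = 0; m < M; ++m) {
        acc += x[m * K + k] * dout[m * N + n];  // (x^T * dout)[k, n]
      }
      (*dweight_out)[k * N + n] += acc;  // "grad add": accumulate, don't overwrite
    }
  }
  for (int64_t n = 0; n < N; ++n) {
    float acc = 0.f;
    for (int64_t m = 0; m < M; ++m) acc += dout[m * N + n];
    (*dbias_out)[n] += acc;
  }
}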

- op : gather_nd
args : (Tensor x, Tensor index)
output : Tensor
60 changes: 60 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/concat_funcs.h"

namespace phi {

std::vector<DDim> GetMetaTensorsDim(
@@ -1229,6 +1230,65 @@ void EditDistanceInferMeta(const MetaTensor& hyps,
sequencenum->set_dtype(DataType::FLOAT32);
}

void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
const MetaTensor& dout,
const MetaTensor& dweight,
const MetaTensor& dbias,
bool multi_precision,
MetaTensor* dweight_out,
MetaTensor* dbias_out) {
const auto dtype = dout.dtype();
PADDLE_ENFORCE_EQ(
x.dtype(),
dtype,
phi::errors::InvalidArgument(
"The data type of Input(x) and Input(dout) must be the same."));

const auto& x_dims = x.dims();
const auto& dout_dims = dout.dims();
int rank = dout_dims.size();
PADDLE_ENFORCE_EQ(
x_dims.size(),
rank,
phi::errors::InvalidArgument(
"The shape of Input(x) and Input(dout) do not match: %s vs %s.",
x_dims,
dout_dims));
for (int i = 0; i + 1 < x_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(
x_dims[i],
dout_dims[i],
phi::errors::InvalidArgument(
"The shape of Input(x) and Input(dout) do not match: %s vs %s.",
x_dims,
dout_dims));
}

const phi::DDim& weight_dims = {x_dims[rank - 1], dout_dims[rank - 1]};
if (dweight) {
PADDLE_ENFORCE_EQ(
weight_dims,
dweight.dims(),
phi::errors::InvalidArgument(
"The shape of input(dweight) does not match the other inputs."));
}

const auto mp_dtype =
(dtype == DataType::FLOAT16 || dtype == DataType::BFLOAT16)
? DataType::FLOAT32
: dtype;

if (dbias_out) {
dbias_out->set_dims({weight_dims[1]});
dbias_out->set_dtype(multi_precision ? mp_dtype : dtype);
}

if (dweight_out) {
dweight_out->set_dims(weight_dims);
dweight_out->set_dtype(multi_precision ? mp_dtype : dtype);
}
}
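
A small host-side sketch (hypothetical helper, not part of the PR) that mirrors the shape and dtype rules enforced above: matching leading dims, dweight = [K, N] from the trailing dims of x and dout, dbias = [N], and promotion to float32 when multi_precision is set for an fp16/bf16 input.

// Hedged illustration of the inference rules; names and the string-based
// dtype handling are simplifications for readability.
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

struct GradAddMeta {
  std::vector<int64_t> dweight_dims;
  std::vector<int64_t> dbias_dims;
  std::string dtype;
};

GradAddMeta InferGradAddMeta(const std::vector<int64_t>& x_dims,
                             const std::vector<int64_t>& dout_dims,
                             const std::string& in_dtype,  // e.g. "float16"
                             bool multi_precision) {
  const size_t rank = dout_dims.size();
  assert(x_dims.size() == rank);
  // All leading dims must match; only the last dims (K vs. N) may differ.
  for (size_t i = 0; i + 1 < rank; ++i) assert(x_dims[i] == dout_dims[i]);

  const bool low_precision = (in_dtype == "float16" || in_dtype == "bfloat16");
  GradAddMeta meta;
  meta.dweight_dims = {x_dims[rank - 1], dout_dims[rank - 1]};  // [K, N]
  meta.dbias_dims = {dout_dims[rank - 1]};                      // [N]
  meta.dtype = (multi_precision && low_precision) ? "float32" : in_dtype;
  return meta;
}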

void GenerateProposalsV2InferMeta(const MetaTensor& scores,
const MetaTensor& bbox_deltas,
const MetaTensor& im_shape,
8 changes: 8 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -259,6 +259,14 @@ void EditDistanceInferMeta(const MetaTensor& hyps,
MetaTensor* sequencenum,
MetaTensor* out);

void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
const MetaTensor& dout,
const MetaTensor& dweight,
const MetaTensor& dbias,
bool multi_precision,
MetaTensor* dweight_out,
MetaTensor* dbias_out);

void GenerateProposalsV2InferMeta(const MetaTensor& scores,
const MetaTensor& bbox_deltas,
const MetaTensor& im_shape,
1 change: 1 addition & 0 deletions paddle/phi/kernels/CMakeLists.txt
@@ -103,6 +103,7 @@ file(GLOB kernel_h "*.h" "selected_rows/*.h" "sparse/*.h" "strings/*.h")
file(GLOB kernel_impl_h "impl/*.h" "selected_rows/impl/*.h")
file(GLOB kernel_primitive_h "primitive/*.h")

# fusion ops would be included here
file(
GLOB
kernel_cu
33 changes: 33 additions & 0 deletions paddle/phi/kernels/fusion/fused_linear_param_grad_add_kernel.h
@@ -0,0 +1,33 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void FusedLinearParamGradAdd(const Context &ctx,
const DenseTensor &x,
const DenseTensor &dout,
const paddle::optional<DenseTensor> &dweight,
const paddle::optional<DenseTensor> &dbias,
bool multi_precision,
DenseTensor *dweight_out,
DenseTensor *dbias_out);

} // namespace phi
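
The CUDA implementation and its registration are not shown in this excerpt. For orientation, a typical PHI registration of the declared kernel might look like the sketch below; PD_REGISTER_KERNEL is the standard registration macro, but the dtype list and placement in a .cu translation unit are assumptions and may differ from the actual files in this PR.

// Hypothetical registration sketch; dtype list assumed, not confirmed by this diff.
#include "paddle/phi/kernels/fusion/fused_linear_param_grad_add_kernel.h"

PD_REGISTER_KERNEL(fused_linear_param_grad_add,
                   GPU,
                   ALL_LAYOUT,
                   phi::FusedLinearParamGradAdd,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}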