[Hackathon No.59] Improve the FP16/BF16 unit tests for the addmm operator #53111

Merged (2 commits) on May 5, 2023
68 changes: 68 additions & 0 deletions paddle/phi/kernels/funcs/blas/blas_impl.cu.h
@@ -1316,6 +1316,74 @@ inline void Blas<phi::GPUContext>::GEMM(bool transA,
});
}

template <>
template <>
inline void Blas<phi::GPUContext>::GEMM(bool transA,
bool transB,
int M,
int N,
int K,
phi::dtype::bfloat16 alpha,
const phi::dtype::bfloat16 *A,
int lda,
const phi::dtype::bfloat16 *B,
int ldb,
phi::dtype::bfloat16 beta,
phi::dtype::bfloat16 *C,
int ldc) const {
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;

PADDLE_ENFORCE_GE(
context_.GetComputeCapability(),
80,
phi::errors::InvalidArgument(
"cublas bf16 gemm requires GPU compute capability >= 80,"
"but received %d",
context_.GetComputeCapability()));

float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);

cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");

context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
cuTransB,
cuTransA,
N,
M,
K,
&h_alpha,
B,
CUDA_R_16BF,
ldb,
A,
CUDA_R_16BF,
lda,
&h_beta,
C,
CUDA_R_16BF,
ldc,
CUDA_R_32F,
algo));
});
#else
// raise error
PADDLE_THROW(phi::errors::Unimplemented(
"cublasGemmEx with bfloat16 is not supported on CUDA versions below 11.0"));

#endif // CUDA_VERSION >= 11000
}
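The operand order in the cublasGemmEx call above (B before A, N before M) is the standard row-major/column-major trick the in-code comment alludes to: cuBLAS works in column-major (Fortran) order, and a row-major matrix reinterpreted as column-major is its transpose, so the wrapper asks for the transposed product instead:

C = A B  <=>  C^T = B^T A^T

Passing (B, A) with dimensions (N, M, K) therefore makes cuBLAS write the N x M column-major matrix C^T, whose memory layout is exactly the row-major M x N result C, with lda, ldb and ldc left unchanged.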

template <>
template <typename T>
void Blas<phi::GPUContext>::AXPY(int n, T alpha, const T *x, T *y) const {
66 changes: 65 additions & 1 deletion paddle/phi/kernels/funcs/blas/blas_impl.hip.h
@@ -751,7 +751,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
context_.GetComputeCapability(),
80,
phi::errors::InvalidArgument(
"rocblas fp16 gemm requires GPU compute capability >= 80,"
"rocblas bf16 gemm requires GPU compute capability >= 80,"
"but received %d",
context_.GetComputeCapability()));

@@ -982,6 +982,70 @@ inline void Blas<phi::GPUContext>::GEMM(bool transA,
});
}

template <>
template <>
inline void Blas<phi::GPUContext>::GEMM(bool transA,
bool transB,
int M,
int N,
int K,
phi::dtype::bfloat16 alpha,
const phi::dtype::bfloat16 *A,
int lda,
const phi::dtype::bfloat16 *B,
int ldb,
phi::dtype::bfloat16 beta,
phi::dtype::bfloat16 *C,
int ldc) const {
// Note that rocblas follows fortran order, so the order is different from
// the cblas convention.
rocblas_operation cuTransA =
transA ? rocblas_operation_transpose : rocblas_operation_none;
rocblas_operation cuTransB =
transB ? rocblas_operation_transpose : rocblas_operation_none;
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(),
80,
phi::errors::InvalidArgument(
"rocblas bf16 gemm requires GPU compute capability >= 80,"
"but received %d",
context_.GetComputeCapability()));

float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
rocblas_gemm_algo algo = rocblas_gemm_algo_standard;

context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
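// Unlike cublasGemmEx, rocblas_gemm_ex takes separate input (C) and output
// (D) matrices plus a trailing solution index and flags; here D aliases C
// and the final (0, 0) selects the default solution with no extra flags.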
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::rocblas_gemm_ex(handle,
cuTransB,
cuTransA,
N,
M,
K,
&h_alpha,
B,
rocblas_datatype_bf16_r,
ldb,
A,
rocblas_datatype_bf16_r,
lda,
&h_beta,
C,
rocblas_datatype_bf16_r,
ldc,
C,
rocblas_datatype_bf16_r,
ldc,
rocblas_datatype_f32_r,
algo,
0,
0));
});
}

template <>
template <typename T>
void Blas<phi::GPUContext>::AXPY(int n, T alpha, const T *x, T *y) const {
10 changes: 8 additions & 2 deletions paddle/phi/kernels/gpu/addmm_grad_kernel.cu
@@ -18,5 +18,11 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/addmm_grad_kernel_impl.h"

PD_REGISTER_KERNEL(
addmm_grad, GPU, ALL_LAYOUT, phi::AddmmGradKernel, float, double) {}
PD_REGISTER_KERNEL(addmm_grad,
GPU,
ALL_LAYOUT,
phi::AddmmGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
9 changes: 8 additions & 1 deletion paddle/phi/kernels/gpu/addmm_kernel.cu
@@ -18,4 +18,11 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/addmm_kernel_impl.h"

PD_REGISTER_KERNEL(addmm, GPU, ALL_LAYOUT, phi::AddmmKernel, float, double) {}
PD_REGISTER_KERNEL(addmm,
GPU,
ALL_LAYOUT,
phi::AddmmKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
104 changes: 94 additions & 10 deletions paddle/phi/kernels/impl/addmm_grad_kernel_impl.h
@@ -18,13 +18,34 @@ limitations under the License. */

#include "glog/logging.h"

#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/addmm_grad_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/for_range.h"

namespace phi {

template <typename T>
struct CopyOrScaleFunctor {
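// Computes output[i] = scale * x[i]; the multiply is carried out in MPType
// (float for float16/bfloat16) to limit rounding error before casting back.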
CopyOrScaleFunctor(const float scale, const T* x, T* output, int64_t numel)
: scale_(scale), x_(x), output_(output), numel_(numel) {}

HOSTDEVICE void operator()(int64_t idx) const {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
const MPType mp_scale = static_cast<MPType>(scale_);
const MPType mp_x = static_cast<MPType>(x_[idx]);
output_[idx] = static_cast<T>(mp_scale * mp_x);
}

private:
const float scale_;
const T* x_;
T* output_;
int64_t numel_;
};

template <typename T,
size_t D,
int MajorType = Eigen::RowMajor,
@@ -45,6 +66,13 @@ void AddmmGradKernel(const Context& dev_ctx,
DenseTensor* input_grad,
DenseTensor* x_grad,
DenseTensor* y_grad) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
bool is_float16_or_bfloat16 = false;
if (std::is_same<T, phi::dtype::float16>::value ||
std::is_same<T, phi::dtype::bfloat16>::value) {
is_float16_or_bfloat16 = true;
}

auto in_dims = input.dims();
if (input.dims().size() == 1) {
in_dims = {1, input.dims()[0]};
@@ -65,6 +93,7 @@ void AddmmGradKernel(const Context& dev_ctx,
}

auto blas = funcs::GetBlas<Context, T>(dev_ctx);
auto mt_blas = funcs::GetBlas<Context, MPType>(dev_ctx);
if (input_grad) {
dev_ctx.template Alloc<T>(input_grad);
total_elems = in_dims[0] * in_dims[1];
@@ -78,19 +107,60 @@ void AddmmGradKernel(const Context& dev_ctx,
Array2(input_grad->dims()[0], input_grad->dims()[1]);

if (row_compress && col_compress) {
eigen_dinput.device(place) =
eigen_dout.sum().eval().reshape(eigen_dinput_shape);
if (!is_float16_or_bfloat16) {
eigen_dinput.device(place) =
eigen_dout.sum().eval().reshape(eigen_dinput_shape);
} else {
eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
.sum()
.eval()
.reshape(eigen_dinput_shape)
.template cast<T>();
}
} else if (row_compress) {
eigen_dinput.device(place) =
eigen_dout.sum(Array1(0)).eval().reshape(eigen_dinput_shape);
if (!is_float16_or_bfloat16) {
eigen_dinput.device(place) =
eigen_dout.sum(Array1(0)).eval().reshape(eigen_dinput_shape);
} else {
eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
.sum(Array1(0))
.eval()
.reshape(eigen_dinput_shape)
.template cast<T>();
}
} else if (col_compress) {
eigen_dinput.device(place) =
eigen_dout.sum(Array1(1)).eval().reshape(eigen_dinput_shape);
if (!is_float16_or_bfloat16) {
eigen_dinput.device(place) =
eigen_dout.sum(Array1(1)).eval().reshape(eigen_dinput_shape);
} else {
eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
.sum(Array1(1))
.eval()
.reshape(eigen_dinput_shape)
.template cast<T>();
}
} else {
blas.VCOPY(total_elems, out_grad.data<T>(), input_grad->data<T>());
// VCOPY does not support float16 or bfloat16
if (!is_float16_or_bfloat16) {
mt_blas.VCOPY(
total_elems, out_grad.data<MPType>(), input_grad->data<MPType>());
Review comment (Contributor): Since fp16 and bf16 are not supported here anyway, is MPType really necessary?

Reply (Contributor Author): Without MPType the build fails: the compiler does not follow the runtime if, it instantiates the code for type T, so if out_grad.data were used here the compiler would report that the float16 parameter type is not supported. (A minimal standalone sketch of this instantiation issue follows the AddmmGradKernel diff below.)

} else {
phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
CopyOrScaleFunctor<T> functor(
1, out_grad.data<T>(), input_grad->data<T>(), total_elems);
for_range(functor);
}
}

blas.SCAL(total_elems, beta, input_grad->data<T>());
// SCAL does not support float16 or bfloat16
if (!is_float16_or_bfloat16) {
mt_blas.SCAL(total_elems, beta, input_grad->data<MPType>());
} else {
phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
CopyOrScaleFunctor<T> functor(
beta, input_grad->data<T>(), input_grad->data<T>(), total_elems);
for_range(functor);
}

if (input.dims().size() == 1) {
input_grad->Resize(input.dims());
Expand All @@ -101,14 +171,28 @@ void AddmmGradKernel(const Context& dev_ctx,
total_elems = x.dims()[0] * x.dims()[1];
// x_grad = out_grad * y'. x_grad: M x K, out_grad : M x N, y : K x N
blas.MatMul(out_grad, false, y, true, x_grad);
blas.SCAL(total_elems, alpha, x_grad->data<T>());
if (!is_float16_or_bfloat16) {
mt_blas.SCAL(total_elems, alpha, x_grad->data<MPType>());
} else {
phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
CopyOrScaleFunctor<T> functor(
alpha, x_grad->data<T>(), x_grad->data<T>(), total_elems);
for_range(functor);
}
}
if (y_grad) {
dev_ctx.template Alloc<T>(y_grad);
total_elems = x.dims()[1] * y.dims()[1];
// y_grad = x' * out_grad. y_grad K x N, out_grad : M x N, x : M x K
blas.MatMul(x, true, out_grad, false, y_grad);
blas.SCAL(total_elems, alpha, y_grad->data<T>());
if (!is_float16_or_bfloat16) {
mt_blas.SCAL(total_elems, alpha, y_grad->data<MPType>());
} else {
phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
CopyOrScaleFunctor<T> functor(
alpha, y_grad->data<T>(), y_grad->data<T>(), total_elems);
for_range(functor);
}
}
}
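The review exchange above boils down to a C++ template rule: every statement in a plain runtime if must still type-check for every instantiation of T, even when that branch can never execute, which is why the dead branch is routed through MPType (or replaced by the functor). The sketch below is a standalone illustration, not Paddle code — the float16 struct and VCopy helper are made up — and shows the same issue resolved with C++17 if constexpr, which discards the untaken branch at compile time.

#include <cstdio>
#include <type_traits>

struct float16 { unsigned short bits; };  // stand-in for phi::dtype::float16

// Hypothetical BLAS-style copy that only exists for float buffers.
void VCopy(const float* src, float* dst, int n) {
  for (int i = 0; i < n; ++i) dst[i] = src[i];
}

template <typename T>
void CopyTensor(const T* src, T* dst, int n) {
  // A plain `if (std::is_same<T, float>::value)` would not help here: the
  // VCopy call below would still have to compile when T is float16.
  if constexpr (std::is_same<T, float>::value) {
    VCopy(src, dst, n);  // branch discarded at compile time for float16
  } else {
    for (int i = 0; i < n; ++i) dst[i] = src[i];  // element-wise fallback
  }
}

int main() {
  float a[2] = {1.f, 2.f}, b[2];
  CopyTensor(a, b, 2);
  float16 c[2] = {{1}, {2}}, d[2];
  CopyTensor(c, d, 2);  // compiles: the VCopy branch is never instantiated
  std::printf("%g %g\n", b[0], b[1]);
  return 0;
}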

6 changes: 4 additions & 2 deletions paddle/phi/kernels/impl/addmm_kernel_impl.h
@@ -112,17 +112,19 @@ void AddmmKernel(const Context& dev_ctx,
funcs::EigenBroadcast<std::decay_t<decltype(place)>, T, 2>::Eval(
place, eigen_out, eigen_input, bcast_dims);

T t_alpha = static_cast<T>(alpha);
T t_beta = static_cast<T>(beta);
blas.GEMM(false,
false,
x_dims[0],
y_dims[1],
x_dims[1],
alpha,
t_alpha,
x.data<T>(),
x_dims[1],
y.data<T>(),
y_dims[1],
beta,
t_beta,
out->data<T>(),
y_dims[1]);
}
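For reference, the computation performed here is the standard addmm formula: the Eigen broadcast above first writes input into out, and the GEMM then accumulates into that buffer, i.e. out <- alpha * (x y) + beta * out, which gives

out = beta * input + alpha * (x y),  with x of shape M x K, y of shape K x N, and input broadcast to M x N.

Casting alpha and beta to T (t_alpha, t_beta) lets template argument deduction resolve to the T-typed GEMM specializations added in blas_impl.cu.h and blas_impl.hip.h above; a float scalar next to T-typed pointers would make the deduction of T ambiguous for the float16/bfloat16 case.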