Commit

Add addmm tests
co63oc committed Apr 26, 2023
1 parent ed45ecc commit 688b0d9
Showing 8 changed files with 312 additions and 22 deletions.
68 changes: 68 additions & 0 deletions paddle/phi/kernels/funcs/blas/blas_impl.cu.h
@@ -1316,6 +1316,74 @@ inline void Blas<phi::GPUContext>::GEMM(bool transA,
});
}

template <>
template <>
inline void Blas<phi::GPUContext>::GEMM(bool transA,
bool transB,
int M,
int N,
int K,
phi::dtype::bfloat16 alpha,
const phi::dtype::bfloat16 *A,
int lda,
const phi::dtype::bfloat16 *B,
int ldb,
phi::dtype::bfloat16 beta,
phi::dtype::bfloat16 *C,
int ldc) const {
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;

PADDLE_ENFORCE_GE(
context_.GetComputeCapability(),
80,
phi::errors::InvalidArgument(
"cublas bf16 gemm requires GPU compute capability >= 80,"
"but received %d",
context_.GetComputeCapability()));

float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);

cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");

context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
cuTransB,
cuTransA,
N,
M,
K,
&h_alpha,
B,
CUDA_R_16BF,
ldb,
A,
CUDA_R_16BF,
lda,
&h_beta,
C,
CUDA_R_16BF,
ldc,
CUDA_R_32F,
algo));
});
#else
// raise error
PADDLE_THROW(phi::errors::Unimplemented(
"cublasGemmEx with bfloat16 requires CUDA 11.0 or later"));

#endif // CUDA_VERSION >= 11000
}

template <>
template <typename T>
void Blas<phi::GPUContext>::AXPY(int n, T alpha, const T *x, T *y) const {
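cuBLAS assumes column-major (Fortran) storage while the tensors here are row-major, which is why the call above passes B before A and swaps M and N: the column-major routine then computes C^T = B^T * A^T, whose memory layout is exactly the row-major C = A * B. A minimal host-side sketch of the same trick (illustrative only, not part of this commit), with a naive loop standing in for cublasGemmEx:

#include <cstdio>

// Naive column-major GEMM: element (i, j) of a matrix lives at [i + j * ld].
void gemm_colmajor(int M, int N, int K, float alpha, const float *A, int lda,
                   const float *B, int ldb, float beta, float *C, int ldc) {
  for (int n = 0; n < N; ++n) {
    for (int m = 0; m < M; ++m) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) acc += A[m + k * lda] * B[k + n * ldb];
      C[m + n * ldc] = alpha * acc + beta * C[m + n * ldc];
    }
  }
}

int main() {
  // Row-major A (2x3), B (3x2), C (2x2).
  float A[6] = {1, 2, 3, 4, 5, 6};
  float B[6] = {7, 8, 9, 10, 11, 12};
  float C[4] = {0, 0, 0, 0};
  // Swapped call: column-major C^T = B^T * A^T, i.e. row-major C = A * B.
  // M and N trade places, and each leading dimension is the row-major row
  // length of its matrix (2 for B and C, 3 for A), mirroring the GEMM above.
  gemm_colmajor(/*M=*/2, /*N=*/2, /*K=*/3, 1.f, B, 2, A, 3, 0.f, C, 2);
  printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]);  // prints 58 64 / 139 154
  return 0;
}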
64 changes: 64 additions & 0 deletions paddle/phi/kernels/funcs/blas/blas_impl.hip.h
@@ -982,6 +982,70 @@ inline void Blas<phi::GPUContext>::GEMM(bool transA,
});
}

template <>
template <>
inline void Blas<phi::GPUContext>::GEMM(bool transA,
bool transB,
int M,
int N,
int K,
phi::dtype::bfloat16 alpha,
const phi::dtype::bfloat16 *A,
int lda,
const phi::dtype::bfloat16 *B,
int ldb,
phi::dtype::bfloat16 beta,
phi::dtype::bfloat16 *C,
int ldc UNUSED) const {
// Note that rocblas follows fortran order, so the order is different from
// the cblas convention.
rocblas_operation cuTransA =
transA ? rocblas_operation_transpose : rocblas_operation_none;
rocblas_operation cuTransB =
transB ? rocblas_operation_transpose : rocblas_operation_none;
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(),
80,
phi::errors::InvalidArgument(
"rocblas fp16 gemm requires GPU compute capability >= 80,"
"but received %d",
context_.GetComputeCapability()));

float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
rocblas_gemm_algo algo = rocblas_gemm_algo_standard;

context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::rocblas_gemm_ex(handle,
cuTransB,
cuTransA,
N,
M,
K,
&h_alpha,
B,
rocblas_datatype_bf16_r,
ldb,
A,
rocblas_datatype_bf16_r,
lda,
&h_beta,
C,
rocblas_datatype_bf16_r,
N,
C,
rocblas_datatype_bf16_r,
N,
rocblas_datatype_f32_r,
algo,
0,
0));
});
}

template <>
template <typename T>
void Blas<phi::GPUContext>::AXPY(int n, T alpha, const T *x, T *y) const {
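Both the CUDA and HIP paths keep A, B, and C in bfloat16 but request a float compute type (CUDA_R_32F / rocblas_datatype_f32_r), so products are accumulated at full precision and alpha/beta are promoted to float as well. A rough standalone sketch of why f32 accumulation matters, with bf16 emulated by mantissa truncation (the hardware actually rounds to nearest even):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Crude bf16 emulation: keep the sign, exponent, and top 7 mantissa bits.
float to_bf16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFF0000u;
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

int main() {
  // Sum 1000 small bf16 products: once in an f32 accumulator, once in bf16.
  float acc_f32 = 0.f, acc_bf16 = 0.f;
  for (int i = 0; i < 1000; ++i) {
    float prod = to_bf16(0.01f) * to_bf16(0.01f);
    acc_f32 += prod;                      // keeps every small contribution
    acc_bf16 = to_bf16(acc_bf16 + prod);  // stalls once prod < one bf16 ulp
  }
  printf("f32 accumulate:  %g\n", acc_f32);   // ~0.099
  printf("bf16 accumulate: %g\n", acc_bf16);  // noticeably smaller
  return 0;
}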
10 changes: 8 additions & 2 deletions paddle/phi/kernels/gpu/addmm_grad_kernel.cu
@@ -18,5 +18,11 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/addmm_grad_kernel_impl.h"

PD_REGISTER_KERNEL(
addmm_grad, GPU, ALL_LAYOUT, phi::AddmmGradKernel, float, double) {}
PD_REGISTER_KERNEL(addmm_grad,
GPU,
ALL_LAYOUT,
phi::AddmmGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
9 changes: 8 additions & 1 deletion paddle/phi/kernels/gpu/addmm_kernel.cu
@@ -18,4 +18,11 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/addmm_kernel_impl.h"

PD_REGISTER_KERNEL(addmm, GPU, ALL_LAYOUT, phi::AddmmKernel, float, double) {}
PD_REGISTER_KERNEL(addmm,
GPU,
ALL_LAYOUT,
phi::AddmmKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
104 changes: 94 additions & 10 deletions paddle/phi/kernels/impl/addmm_grad_kernel_impl.h
@@ -18,13 +18,34 @@

#include "glog/logging.h"

#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/addmm_grad_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/for_range.h"

namespace phi {

template <typename T>
struct CopyOrScaleFunctor {
CopyOrScaleFunctor(const float scale, const T* x, T* output, int64_t numel)
: scale_(scale), x_(x), output_(output), numel_(numel) {}

HOSTDEVICE void operator()(int64_t idx) const {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
const MPType mp_scale = static_cast<MPType>(scale_);
const MPType mp_x = static_cast<MPType>(x_[idx]);
output_[idx] = static_cast<T>(mp_scale * mp_x);
}

private:
const float scale_;
const T* x_;
T* output_;
int64_t numel_;
};
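phi::funcs::ForRange applies such a functor once per index (on the GPU, roughly one thread per element). A standalone sketch of the pattern, with a plain loop standing in for ForRange and float standing in for the fp16/bf16 storage type:

#include <cstdint>
#include <cstdio>

// Simplified stand-in for CopyOrScaleFunctor: the multiply happens in float
// (the "MPType") and the result is cast back to the storage type T.
template <typename T>
struct CopyOrScale {
  CopyOrScale(float scale, const T *x, T *out)
      : scale_(scale), x_(x), out_(out) {}
  void operator()(int64_t idx) const {
    out_[idx] = static_cast<T>(scale_ * static_cast<float>(x_[idx]));
  }
  float scale_;
  const T *x_;
  T *out_;
};

int main() {
  float in[4] = {1.f, 2.f, 3.f, 4.f};
  float out[4];
  CopyOrScale<float> functor(0.5f, in, out);
  for (int64_t i = 0; i < 4; ++i) functor(i);  // ForRange would run this on-device
  for (float v : out) printf("%g ", v);        // prints 0.5 1 1.5 2
  printf("\n");
  return 0;
}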

template <typename T,
size_t D,
int MajorType = Eigen::RowMajor,
@@ -45,6 +66,13 @@ void AddmmGradKernel(const Context& dev_ctx,
DenseTensor* input_grad,
DenseTensor* x_grad,
DenseTensor* y_grad) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
constexpr bool is_float16_or_bfloat16 =
std::is_same<T, phi::dtype::float16>::value ||
std::is_same<T, phi::dtype::bfloat16>::value;

auto in_dims = input.dims();
if (input.dims().size() == 1) {
in_dims = {1, input.dims()[0]};
@@ -65,6 +93,7 @@
}

auto blas = funcs::GetBlas<Context, T>(dev_ctx);
auto mt_blas = funcs::GetBlas<Context, MPType>(dev_ctx);
if (input_grad) {
dev_ctx.template Alloc<T>(input_grad);
total_elems = in_dims[0] * in_dims[1];
@@ -78,19 +107,60 @@
Array2(input_grad->dims()[0], input_grad->dims()[1]);

if (row_compress && col_compress) {
eigen_dinput.device(place) =
eigen_dout.sum().eval().reshape(eigen_dinput_shape);
if (!is_float16_or_bfloat16) {
eigen_dinput.device(place) =
eigen_dout.sum().eval().reshape(eigen_dinput_shape);
} else {
eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
.sum()
.eval()
.reshape(eigen_dinput_shape)
.template cast<T>();
}
} else if (row_compress) {
eigen_dinput.device(place) =
eigen_dout.sum(Array1(0)).eval().reshape(eigen_dinput_shape);
if (!is_float16_or_bfloat16) {
eigen_dinput.device(place) =
eigen_dout.sum(Array1(0)).eval().reshape(eigen_dinput_shape);
} else {
eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
.sum(Array1(0))
.eval()
.reshape(eigen_dinput_shape)
.template cast<T>();
}
} else if (col_compress) {
eigen_dinput.device(place) =
eigen_dout.sum(Array1(1)).eval().reshape(eigen_dinput_shape);
if (!is_float16_or_bfloat16) {
eigen_dinput.device(place) =
eigen_dout.sum(Array1(1)).eval().reshape(eigen_dinput_shape);
} else {
eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
.sum(Array1(1))
.eval()
.reshape(eigen_dinput_shape)
.template cast<T>();
}
} else {
blas.VCOPY(total_elems, out_grad.data<T>(), input_grad->data<T>());
// VCOPY does not support float16 or bfloat16, so copy via the functor
// (scale = 1) instead.
if (!is_float16_or_bfloat16) {
mt_blas.VCOPY(
total_elems, out_grad.data<MPType>(), input_grad->data<MPType>());
} else {
phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
CopyOrScaleFunctor<T> functor(
1, out_grad.data<T>(), input_grad->data<T>(), total_elems);
for_range(functor);
}
}

blas.SCAL(total_elems, beta, input_grad->data<T>());
// SCAL does not support float16 or bfloat16, so scale via the functor instead.
if (!is_float16_or_bfloat16) {
mt_blas.SCAL(total_elems, beta, input_grad->data<MPType>());
} else {
phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
CopyOrScaleFunctor<T> functor(
beta, input_grad->data<T>(), input_grad->data<T>(), total_elems);
for_range(functor);
}

if (input.dims().size() == 1) {
input_grad->Resize(input.dims());
@@ -101,14 +171,28 @@
total_elems = x.dims()[0] * x.dims()[1];
// x_grad = out_grad * y'. x_grad: M x K, out_grad : M x N, y : K x N
blas.MatMul(out_grad, false, y, true, x_grad);
blas.SCAL(total_elems, alpha, x_grad->data<T>());
if (!is_float16_or_bfloat16) {
mt_blas.SCAL(total_elems, alpha, x_grad->data<MPType>());
} else {
phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
CopyOrScaleFunctor<T> functor(
alpha, x_grad->data<T>(), x_grad->data<T>(), total_elems);
for_range(functor);
}
}
if (y_grad) {
dev_ctx.template Alloc<T>(y_grad);
total_elems = x.dims()[1] * y.dims()[1];
// y_grad = x' * out_grad. y_grad K x N, out_grad : M x N, x : M x K
blas.MatMul(x, true, out_grad, false, y_grad);
blas.SCAL(total_elems, alpha, y_grad->data<T>());
if (!is_float16_or_bfloat16) {
mt_blas.SCAL(total_elems, alpha, y_grad->data<MPType>());
} else {
phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
CopyOrScaleFunctor<T> functor(
alpha, y_grad->data<T>(), y_grad->data<T>(), total_elems);
for_range(functor);
}
}
}

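The MatMul/scale pairs above implement the standard addmm gradients. With out = beta * input + alpha * (X Y), the chain rule gives

\frac{\partial L}{\partial X} = \alpha \,\frac{\partial L}{\partial \mathrm{out}}\, Y^{\top}, \qquad
\frac{\partial L}{\partial Y} = \alpha \, X^{\top} \frac{\partial L}{\partial \mathrm{out}}, \qquad
\frac{\partial L}{\partial \mathrm{input}} = \beta \sum_{\text{broadcast dims}} \frac{\partial L}{\partial \mathrm{out}},

which is exactly what the kernel computes: MatMul(out_grad, false, y, true) scaled by alpha for x_grad, MatMul(x, true, out_grad, false) scaled by alpha for y_grad, and a (possibly row- or column-reduced) copy of out_grad scaled by beta for input_grad.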
6 changes: 4 additions & 2 deletions paddle/phi/kernels/impl/addmm_kernel_impl.h
@@ -112,17 +112,19 @@ void AddmmKernel(const Context& dev_ctx,
funcs::EigenBroadcast<std::decay_t<decltype(place)>, T, 2>::Eval(
place, eigen_out, eigen_input, bcast_dims);

T t_alpha = static_cast<T>(alpha);
T t_beta = static_cast<T>(beta);
blas.GEMM(false,
false,
x_dims[0],
y_dims[1],
x_dims[1],
alpha,
t_alpha,
x.data<T>(),
x_dims[1],
y.data<T>(),
y_dims[1],
beta,
t_beta,
out->data<T>(),
y_dims[1]);
}
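For reference, the forward kernel computes

\mathrm{out} = \beta \cdot \mathrm{input} + \alpha \,(x\,y),

where the Eigen broadcast first materializes input into out, and the GEMM then accumulates alpha * (x y) on top, with t_beta scaling the broadcast contents already stored in out.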