Skip to content

Commit

Permalink
Fix nullptr to TestFuseGemmEpilogueReluBWDFP* (#48997)
Browse files Browse the repository at this point in the history
  • Loading branch information
mingxu1067 authored Dec 14, 2022
1 parent 1a32448 commit e61df28
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
5 changes: 2 additions & 3 deletions paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,8 @@ class FusedGemmEpilogueOp : public framework::OperatorWithKernel {
}

ctx->SetOutputDim("Out", phi::make_ddim(out_dims));
// Note (Ming Huang): Reserve space of relu is a bit-mask,
// which cannot pass nan_and_inf checking if shape is set.
if (activation == "gelu" && ctx->HasOutput("ReserveSpace")) {

if (ctx->HasOutput("ReserveSpace")) {
ctx->SetOutputDim("ReserveSpace", phi::make_ddim(out_dims));
}
}
Expand Down
21 changes: 13 additions & 8 deletions paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,21 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
sizeof(bias_data)));

if (enable_auxiliary && activation != "none") {
size_t reserve_space_size = 0;
// Note (Ming Huang): The initialization of ReserveSpace happens in
// dev_ctx.Alloc. Therefore, we set the real data type up here.
if (activation == "relu") {
// Count in bits.
reserve_space_size = phi::product(out->dims()) / 8;
paddle::experimental::DataType rs_type =
paddle::experimental::DataType::BOOL;
size_t reserve_space_size =
phi::product(reserve_space->dims()) * SizeOf(rs_type);
dev_ctx.Alloc(reserve_space, rs_type, reserve_space_size);
} else {
reserve_space_size = phi::product(out->dims()) * sizeof(T);
size_t reserve_space_size =
phi::product(reserve_space->dims()) * sizeof(T);
dev_ctx.Alloc<T>(reserve_space, reserve_space_size);
}
dev_ctx.Alloc(reserve_space, out->type(), reserve_space_size);
void* aux_data = reinterpret_cast<void*>(reserve_space->data<T>());

void* aux_data = reserve_space->data();

PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
Expand Down Expand Up @@ -184,7 +190,6 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
stream,
workspace->ptr(),
workspace_size);

PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmul(lt_handle,
operation_desc,
Expand Down Expand Up @@ -478,7 +483,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
sizeof(epiloque_func_for_dx)));

if (activation_grad != "none") {
auto* aux_data = reserve_space->data<T>();
auto* aux_data = reserve_space->data();
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc,
Expand Down

0 comments on commit e61df28

Please sign in to comment.