Skip to content

Commit

Permalink
[XPU][PHI Kernels] use ctx_guard to allocate temporary buffer (#59334)
Browse files Browse the repository at this point in the history
  • Loading branch information
dynamicheart authored Nov 27, 2023
1 parent 32af85e commit 24f25c8
Show file tree
Hide file tree
Showing 10 changed files with 177 additions and 202 deletions.
13 changes: 4 additions & 9 deletions paddle/fluid/operators/affine_channel_op_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,10 @@ class AffineChannelGradXPUKernel : public framework::OpKernel<T> {
"The reduce_sum XPU OP return wrong value[%d %s]",
r,
XPUAPIErrorMsg[r]));
T* tmp = nullptr;
r = xpu_malloc(reinterpret_cast<void**>(&tmp), dy->numel() * sizeof(T));
PADDLE_ENFORCE_EQ(r,
xpu::Error_t::SUCCESS,
platform::errors::External("no enough memory in xpu"));
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
T* tmp = RAII_GUARD.alloc_l3_or_gm<T>(dy->numel());
PADDLE_ENFORCE_NOT_NULL(
tmp, platform::errors::External("XPU has no enough memory"));

r = xpu::mul<T>(
dev_ctx.x_context(), dy_d, x->data<T>(), tmp, dy->numel());
Expand All @@ -166,10 +165,6 @@ class AffineChannelGradXPUKernel : public framework::OpKernel<T> {
"The reduce_sum XPU OP return wrong value[%d %s]",
r,
XPUAPIErrorMsg[r]));
if (dev_ctx.x_context()->xpu_stream) {
dev_ctx.Wait();
}
xpu_free(tmp);
}
if (dx_d) {
r = xpu::broadcast_mul(
Expand Down
33 changes: 17 additions & 16 deletions paddle/phi/kernels/funcs/adam_functors.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,17 @@ using float16 = dtype::float16;
#ifdef PADDLE_WITH_XPU

template <typename Context, typename T1, typename T2>
static int ConvertDataByType(
const T1* x, T2** y, int len, bool allocateFlag, const Context& dev_ctx) {
static int ConvertDataByType(const T1* x,
T2** y,
int len,
bool allocateFlag,
const Context& dev_ctx,
xpu::ctx_guard* ctx_guard) {
if (nullptr == x || nullptr == y || len <= 0)
return xpu::Error_t::INVALID_PARAM;
int r = 0;
if (allocateFlag) {
r = xpu_malloc(reinterpret_cast<void**>(y), sizeof(T2) * len);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam");
*y = ctx_guard->alloc_l3_or_gm<T2>(len);
PADDLE_ENFORCE_XDNN_NOT_NULL(*y);
}

T1* cpu_data = reinterpret_cast<T1*>(malloc(sizeof(T1) * len));
Expand All @@ -62,13 +65,14 @@ static int ConvertDataByType(
template <typename Context, typename T>
static void GetDataPointer(const phi::DenseTensor& tensorData,
T** result,
const Context& dev_ctx) {
const Context& dev_ctx,
xpu::ctx_guard* ctx_guard) {
if (tensorData.dtype() == DataType::FLOAT16) {
const float16* real_data = tensorData.template data<float16>();
int len = tensorData.numel();

int r = ConvertDataByType<Context, float16, T>(
real_data, result, len, true, dev_ctx);
real_data, result, len, true, dev_ctx, ctx_guard);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam");
}
}
Expand All @@ -88,23 +92,19 @@ static void GetOutDataPointer(DenseTensor* tensorData,
template <typename Context, typename T>
static void CopyOutData(const DenseTensor& srcTensor,
phi::DenseTensor* dstTensor,
const Context& dev_ctx) {
const Context& dev_ctx,
xpu::ctx_guard* ctx_guard) {
if (dstTensor->dtype() == DataType::FLOAT16) {
const T* xpu_out_data = srcTensor.template data<T>();
float16* out_data = dev_ctx.template Alloc<float16>(dstTensor);
int len = srcTensor.numel();

int r = ConvertDataByType<Context, T, float16>(
xpu_out_data, &out_data, len, false, dev_ctx);
xpu_out_data, &out_data, len, false, dev_ctx, ctx_guard);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam");
}
}

// Releases a temporary XPU device buffer that was allocated (via xpu_malloc)
// only when the source tensor was FLOAT16; for any other dtype no buffer was
// created, so this is a no-op.
// NOTE(review): this helper appears to be deleted by this commit — temporary
// buffers are now owned by xpu::ctx_guard (RAII), making manual xpu_free
// unnecessary; confirm against the applied diff.
template <typename T>
static void FreeData(const phi::DenseTensor& tensorData, T* dataPtr) {
  if (tensorData.dtype() == DataType::FLOAT16) xpu_free(dataPtr);
}

template <typename Context, typename T>
static void SetBetaData(const phi::DenseTensor& beta_pow,
phi::DenseTensor* beta_pow_out,
Expand All @@ -125,7 +125,8 @@ static void Scale(phi::DenseTensor* beta_pow_out,
const phi::DenseTensor& beta_pow,
T* beta_pow_ptr,
const T& beta,
const Context& dev_ctx) {
const Context& dev_ctx,
xpu::ctx_guard* ctx_guard) {
float16* beta_pow_out_p2 = dev_ctx.template Alloc<float16>(beta_pow_out);

DenseTensor xpu_beta_pow_out;
Expand All @@ -149,7 +150,7 @@ static void Scale(phi::DenseTensor* beta_pow_out,
int len = xpu_beta_pow_out.numel();

r = ConvertDataByType<Context, T, float16>(
xpu_beta_pow_out_data, &beta_pow_out_p2, len, false, dev_ctx);
xpu_beta_pow_out_data, &beta_pow_out_p2, len, false, dev_ctx, ctx_guard);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam");
}
#endif
Expand Down
54 changes: 33 additions & 21 deletions paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,23 @@ void AdamDenseParamSparseGradKernel(
DenseTensor* beta2_pow_out,
DenseTensor* master_param_outs) {
using XPUType = typename XPUTypeTrait<T>::Type;
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
float* param_ptr = nullptr;
funcs::GetDataPointer<Context, float>(param, &param_ptr, dev_ctx);
funcs::GetDataPointer<Context, float>(
param, &param_ptr, dev_ctx, &RAII_GUARD);

float* mom1_ptr = nullptr;
funcs::GetDataPointer<Context, float>(moment1, &mom1_ptr, dev_ctx);
funcs::GetDataPointer<Context, float>(
moment1, &mom1_ptr, dev_ctx, &RAII_GUARD);

float* mom2_ptr = nullptr;
funcs::GetDataPointer<Context, float>(moment2, &mom2_ptr, dev_ctx);
funcs::GetDataPointer<Context, float>(
moment2, &mom2_ptr, dev_ctx, &RAII_GUARD);

float* lr_ptr = nullptr;
funcs::GetDataPointer<Context, float>(learning_rate, &lr_ptr, dev_ctx);
funcs::GetDataPointer<Context, float>(
learning_rate, &lr_ptr, dev_ctx, &RAII_GUARD);

xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
float* beta1_pow_ptr = nullptr;
const float* beta1_const_pow_ptr = nullptr;

Expand Down Expand Up @@ -92,7 +96,8 @@ void AdamDenseParamSparseGradKernel(

} else {
if (beta1_pow.dtype() == DataType::FLOAT16)
funcs::GetDataPointer<Context, float>(beta1_pow, &beta1_pow_ptr, dev_ctx);
funcs::GetDataPointer<Context, float>(
beta1_pow, &beta1_pow_ptr, dev_ctx, &RAII_GUARD);
else
beta1_const_pow_ptr = beta1_pow.template data<float>();
}
Expand Down Expand Up @@ -123,7 +128,8 @@ void AdamDenseParamSparseGradKernel(
}
} else {
if (beta2_pow.dtype() == DataType::FLOAT16)
funcs::GetDataPointer<Context, float>(beta2_pow, &beta2_pow_ptr, dev_ctx);
funcs::GetDataPointer<Context, float>(
beta2_pow, &beta2_pow_ptr, dev_ctx, &RAII_GUARD);
else
beta2_const_pow_ptr = beta2_pow.template data<float>();
}
Expand Down Expand Up @@ -225,7 +231,8 @@ void AdamDenseParamSparseGradKernel(
auto& grad_merge = *grad_merge_ptr;
auto& grad_tensor = grad_merge.value();

funcs::GetDataPointer<Context, float>(grad_tensor, &grad_c, dev_ctx);
funcs::GetDataPointer<Context, float>(
grad_tensor, &grad_c, dev_ctx, &RAII_GUARD);

int row_count = grad_merge.rows().size();
std::vector<int> rows(row_count);
Expand Down Expand Up @@ -267,11 +274,12 @@ void AdamDenseParamSparseGradKernel(

PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam");

funcs::FreeData<float>(grad_tensor, grad_c);

funcs::CopyOutData<Context, float>(xpu_mom1_out, moment1_out, dev_ctx);
funcs::CopyOutData<Context, float>(xpu_mom2_out, moment1_out, dev_ctx);
funcs::CopyOutData<Context, float>(xpu_param_out, moment1_out, dev_ctx);
funcs::CopyOutData<Context, float>(
xpu_mom1_out, moment1_out, dev_ctx, &RAII_GUARD);
funcs::CopyOutData<Context, float>(
xpu_mom2_out, moment1_out, dev_ctx, &RAII_GUARD);
funcs::CopyOutData<Context, float>(
xpu_param_out, moment1_out, dev_ctx, &RAII_GUARD);

if (!use_global_beta_pow) {
// update in cpu and then copy to xpu
Expand All @@ -285,8 +293,12 @@ void AdamDenseParamSparseGradKernel(
float* beta1_pow_out_p1 = nullptr;

if (beta1_pow_out->dtype() == DataType::FLOAT16) {
funcs::Scale<Context, float>(
beta1_pow_out, beta1_pow, beta1_pow_ptr, beta1_, dev_ctx);
funcs::Scale<Context, float>(beta1_pow_out,
beta1_pow,
beta1_pow_ptr,
beta1_,
dev_ctx,
&RAII_GUARD);
} else {
const float* beta1_pow_data = beta1_pow.template data<float>();
beta1_pow_out_p1 = dev_ctx.template Alloc<float>(beta1_pow_out);
Expand All @@ -303,8 +315,12 @@ void AdamDenseParamSparseGradKernel(

float* beta2_pow_out_p1 = nullptr;
if (beta2_pow_out->dtype() == DataType::FLOAT16) {
funcs::Scale<Context, float>(
beta2_pow_out, beta2_pow, beta2_pow_ptr, beta2_, dev_ctx);
funcs::Scale<Context, float>(beta2_pow_out,
beta2_pow,
beta2_pow_ptr,
beta2_,
dev_ctx,
&RAII_GUARD);
} else {
const float* beta2_pow_data = beta2_pow.template data<float>();
beta2_pow_out_p1 = dev_ctx.template Alloc<float>(beta2_pow_out);
Expand All @@ -320,10 +336,6 @@ void AdamDenseParamSparseGradKernel(
}
}
}
funcs::FreeData<float>(param, param_ptr);
funcs::FreeData<float>(moment1, mom1_ptr);
funcs::FreeData<float>(moment2, mom2_ptr);
funcs::FreeData<float>(learning_rate, lr_ptr);
}
} // namespace sr
} // namespace phi
Expand Down
56 changes: 34 additions & 22 deletions paddle/phi/kernels/xpu/adam_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,22 @@ void AdamDenseKernel(const Context& dev_ctx,
DenseTensor* beta1_pow_out,
DenseTensor* beta2_pow_out,
DenseTensor* master_param_outs) {
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
float* param_ptr = nullptr;
funcs::GetDataPointer<Context, float>(param, &param_ptr, dev_ctx);
funcs::GetDataPointer<Context, float>(
param, &param_ptr, dev_ctx, &RAII_GUARD);

float* mom1_ptr = nullptr;
funcs::GetDataPointer<Context, float>(moment1, &mom1_ptr, dev_ctx);
funcs::GetDataPointer<Context, float>(
moment1, &mom1_ptr, dev_ctx, &RAII_GUARD);

float* mom2_ptr = nullptr;
funcs::GetDataPointer<Context, float>(moment2, &mom2_ptr, dev_ctx);
funcs::GetDataPointer<Context, float>(
moment2, &mom2_ptr, dev_ctx, &RAII_GUARD);

float* lr_ptr = nullptr;
funcs::GetDataPointer<Context, float>(learning_rate, &lr_ptr, dev_ctx);
funcs::GetDataPointer<Context, float>(
learning_rate, &lr_ptr, dev_ctx, &RAII_GUARD);

float* beta1_pow_ptr = nullptr;
const float* beta1_const_pow_ptr = nullptr;
Expand All @@ -68,12 +73,13 @@ void AdamDenseKernel(const Context& dev_ctx,
phi::Copy(dev_ctx, beta1_pow, dev_ctx.GetPlace(), false, &xpu_beta1_pow);
if (xpu_beta1_pow.dtype() == DataType::FLOAT16)
funcs::GetDataPointer<Context, float>(
xpu_beta1_pow, &beta1_pow_ptr, dev_ctx);
xpu_beta1_pow, &beta1_pow_ptr, dev_ctx, &RAII_GUARD);
else
beta1_const_pow_ptr = xpu_beta1_pow.template data<float>();
} else {
if (beta1_pow.dtype() == DataType::FLOAT16)
funcs::GetDataPointer<Context, float>(beta1_pow, &beta1_pow_ptr, dev_ctx);
funcs::GetDataPointer<Context, float>(
beta1_pow, &beta1_pow_ptr, dev_ctx, &RAII_GUARD);
else
beta1_const_pow_ptr = beta1_pow.template data<float>();
}
Expand All @@ -85,12 +91,13 @@ void AdamDenseKernel(const Context& dev_ctx,
phi::Copy(dev_ctx, beta2_pow, dev_ctx.GetPlace(), false, &xpu_beta2_pow);
if (xpu_beta2_pow.dtype() == DataType::FLOAT16)
funcs::GetDataPointer<Context, float>(
xpu_beta2_pow, &beta2_pow_ptr, dev_ctx);
xpu_beta2_pow, &beta2_pow_ptr, dev_ctx, &RAII_GUARD);
else
beta2_const_pow_ptr = xpu_beta2_pow.template data<float>();
} else {
if (beta2_pow.dtype() == DataType::FLOAT16)
funcs::GetDataPointer<Context, float>(beta2_pow, &beta2_pow_ptr, dev_ctx);
funcs::GetDataPointer<Context, float>(
beta2_pow, &beta2_pow_ptr, dev_ctx, &RAII_GUARD);
else
beta2_const_pow_ptr = beta2_pow.template data<float>();
}
Expand Down Expand Up @@ -163,7 +170,7 @@ void AdamDenseKernel(const Context& dev_ctx,
auto epsilon_ = epsilon.to<float>();

float* grad_c = nullptr;
funcs::GetDataPointer<Context, float>(grad, &grad_c, dev_ctx);
funcs::GetDataPointer<Context, float>(grad, &grad_c, dev_ctx, &RAII_GUARD);

int r = xpu::adam(
dev_ctx.x_context(),
Expand All @@ -184,11 +191,12 @@ void AdamDenseKernel(const Context& dev_ctx,

PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam");

funcs::FreeData<float>(grad, grad_c);

funcs::CopyOutData<Context, float>(xpu_mom1_out, moment1_out, dev_ctx);
funcs::CopyOutData<Context, float>(xpu_mom2_out, moment2_out, dev_ctx);
funcs::CopyOutData<Context, float>(xpu_param_out, param_out, dev_ctx);
funcs::CopyOutData<Context, float>(
xpu_mom1_out, moment1_out, dev_ctx, &RAII_GUARD);
funcs::CopyOutData<Context, float>(
xpu_mom2_out, moment2_out, dev_ctx, &RAII_GUARD);
funcs::CopyOutData<Context, float>(
xpu_param_out, param_out, dev_ctx, &RAII_GUARD);

if (!use_global_beta_pow) {
// update in cpu and then copy to xpu
Expand All @@ -202,8 +210,12 @@ void AdamDenseKernel(const Context& dev_ctx,
float* beta1_pow_out_p1 = nullptr;

if (beta1_pow_out->dtype() == DataType::FLOAT16) {
funcs::Scale<Context, float>(
beta1_pow_out, beta1_pow, beta1_pow_ptr, beta1_, dev_ctx);
funcs::Scale<Context, float>(beta1_pow_out,
beta1_pow,
beta1_pow_ptr,
beta1_,
dev_ctx,
&RAII_GUARD);
} else {
const float* beta1_pow_data = beta1_pow.template data<float>();
beta1_pow_out_p1 = dev_ctx.template Alloc<float>(beta1_pow_out);
Expand All @@ -219,8 +231,12 @@ void AdamDenseKernel(const Context& dev_ctx,

float* beta2_pow_out_p1 = nullptr;
if (beta2_pow_out->dtype() == DataType::FLOAT16) {
funcs::Scale<Context, float>(
beta2_pow_out, beta2_pow, beta2_pow_ptr, beta2_, dev_ctx);
funcs::Scale<Context, float>(beta2_pow_out,
beta2_pow,
beta2_pow_ptr,
beta2_,
dev_ctx,
&RAII_GUARD);
} else {
const float* beta2_pow_data = beta2_pow.template data<float>();
beta2_pow_out_p1 = dev_ctx.template Alloc<float>(beta2_pow_out);
Expand All @@ -235,10 +251,6 @@ void AdamDenseKernel(const Context& dev_ctx,
}
}
}
funcs::FreeData<float>(param, param_ptr);
funcs::FreeData<float>(moment1, mom1_ptr);
funcs::FreeData<float>(moment2, mom2_ptr);
funcs::FreeData<float>(learning_rate, lr_ptr);
}

template <typename T, typename Context>
Expand Down
Loading

0 comments on commit 24f25c8

Please sign in to comment.