Skip to content

Commit

Permalink
[CustomDevice] add data transform support
Browse files Browse the repository at this point in the history
  • Loading branch information
ronny1996 committed Aug 24, 2023
1 parent 1c0db09 commit bb73324
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions paddle/phi/api/lib/data_transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,26 @@ inline phi::DenseTensor TransDataType(const phi::DenseTensor& tensor,
} else if (tensor.place().GetType() == phi::AllocationType::GPU) {
auto* dev_ctx = static_cast<phi::GPUContext*>(pool.Get(tensor.place()));
return CastDataType(*dev_ctx, tensor, dtype);
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
} else if (tensor.place().GetType() == phi::AllocationType::CUSTOM) {
phi::DenseTensor out;
out.Resize(tensor.dims());
auto* dev_ctx = static_cast<phi::CustomDevice*>(pool.Get(tensor.place()));
auto kernel_result =
phi::KernelFactory::Instance().SelectKernelOrThrowError(
"cast",
{phi::TransToPhiBackend(tensor.place()),
phi::DataLayout::ALL_LAYOUT,
tensor.dtype()});
using kernel_signature = void (*)(const phi::DeviceContext&,
const phi::DenseTensor&,
phi::DataType,
phi::DenseTensor*);
const auto& kernel = kernel_result.kernel;
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, tensor, dtype, &out);
return out;
#endif
} else {
PADDLE_THROW(phi::errors::Unimplemented(
Expand Down

0 comments on commit bb73324

Please sign in to comment.