diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 268ec1b6b0ac53..f6acc15dde8a21 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -157,6 +157,26 @@ inline phi::DenseTensor TransDataType(const phi::DenseTensor& tensor, } else if (tensor.place().GetType() == phi::AllocationType::GPU) { auto* dev_ctx = static_cast(pool.Get(tensor.place())); return CastDataType(*dev_ctx, tensor, dtype); +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + } else if (tensor.place().GetType() == phi::AllocationType::CUSTOM) { + phi::DenseTensor out; + out.Resize(tensor.dims()); + auto* dev_ctx = static_cast(pool.Get(tensor.place())); + auto kernel_result = + phi::KernelFactory::Instance().SelectKernelOrThrowError( + "cast", + {phi::TransToPhiBackend(tensor.place()), + phi::DataLayout::ALL_LAYOUT, + tensor.dtype()}); + using kernel_signature = void (*)(const phi::DeviceContext&, + const phi::DenseTensor&, + phi::DataType, + phi::DenseTensor*); + const auto& kernel = kernel_result.kernel; + auto* kernel_fn = kernel.GetVariadicKernelFn(); + (*kernel_fn)(*dev_ctx, tensor, dtype, &out); + return out; #endif } else { PADDLE_THROW(phi::errors::Unimplemented(