Skip to content

Commit

Permalink
shared_external mermory add xpu
Browse files Browse the repository at this point in the history
  • Loading branch information
csy0225 committed Apr 23, 2023
1 parent b1d3ec1 commit 6f64981
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
9 changes: 8 additions & 1 deletion paddle/fluid/inference/api/details/zero_copy_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,16 @@ void Tensor::ShareExternalData(const T *data,
const_cast<T *>(data), size, paddle::platform::CUDAPlace(device_)),
meta);
*tensor = std::move(dtensor);
} else if (place == PlaceType::kXPU) {
phi::DenseTensor dtensor(
std::make_shared<phi::Allocation>(
const_cast<T *>(data), size, paddle::platform::XPUPlace(device_)),
meta);
*tensor = std::move(dtensor);
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"PlaceType must be PlaceType::kCPU or PlaceType::kGPU."));
"PlaceType must be one of [PlaceType::kCPU, PlaceType::kGPU, "
"PlaceType::kXPU]."));
}
}

Expand Down
6 changes: 5 additions & 1 deletion paddle/phi/kernels/xpu/linspace_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,8 @@ void LinspaceKernel(const Context& ctx,
} // namespace phi

PD_REGISTER_KERNEL(
linspace, XPU, ALL_LAYOUT, phi::LinspaceKernel, float, int32_t) {}
linspace, XPU, ALL_LAYOUT, phi::LinspaceKernel, float, int32_t) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
}

0 comments on commit 6f64981

Please sign in to comment.