Skip to content

Commit

Permalink
[XPU] get data() pointer with right type;test=develop
Browse files Browse the repository at this point in the history
  • Loading branch information
newway committed Nov 1, 2023
1 parent e3e9e5c commit 46d4c6f
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions paddle/phi/kernels/xpu/index_select_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,29 +40,33 @@ void IndexSelectKernel(const Context& ctx,
index_type,
phi::DataType::INT32,
phi::DataType::INT64));

auto* in_data = x.data<T>();
std::vector<int> in_shape = phi::vectorize<int>(input_dim);
int index_len = output->dims()[dim];
T* out_data = ctx.template Alloc<T>(output);
int r = 0;
xpu::ctx_guard RAII_GUARD(ctx.x_context());
const int8_t* index_ptr = nullptr;
int8_t* index_ptr = nullptr; // temp xpu buffer
int byte_times = sizeof(index_type);
if (index.place() == CPUPlace()) {
index_ptr = RAII_GUARD.alloc_l3_or_gm<int8_t>(byte_times * index.numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(index_ptr);
const void* cpu_idx_data = nullptr;
if (index_type == phi::DataType::INT64) {
cpu_idx_data = reinterpret_cast<const void*>(index.data<int64_t>());
} else if (index_type == phi::DataType::INT32) {
cpu_idx_data = reinterpret_cast<const void*>(index.data<int>());
}
memory_utils::Copy(ctx.GetPlace(),
reinterpret_cast<void*>(const_cast<int8_t*>(index_ptr)),
reinterpret_cast<void*>(index_ptr),
CPUPlace(),
reinterpret_cast<const void*>(index.data<int>()),
cpu_idx_data,
byte_times * index.numel());
} else {
index_ptr = index.template data<int8_t>();
}
if (index_type == phi::DataType::INT64) {
const int64_t* index_data =
reinterpret_cast<const int64_t*>(const_cast<int8_t*>(index_ptr));
index_ptr ? reinterpret_cast<const int64_t*>(index_ptr)
: index.template data<int64_t>();
r = xpu::gather<T, int64_t>(ctx.x_context(),
in_data,
index_data,
Expand All @@ -71,8 +75,8 @@ void IndexSelectKernel(const Context& ctx,
index_len,
dim);
} else {
const int* index_data =
reinterpret_cast<const int*>(const_cast<int8_t*>(index_ptr));
const int* index_data = index_ptr ? reinterpret_cast<const int*>(index_ptr)
: index.template data<int>();
r = xpu::gather<T, int>(ctx.x_context(),
in_data,
index_data,
Expand Down

0 comments on commit 46d4c6f

Please sign in to comment.