diff --git a/paddle/phi/kernels/xpu/index_select_kernel.cc b/paddle/phi/kernels/xpu/index_select_kernel.cc index e3eb974fe172e2..200286804a8956 100644 --- a/paddle/phi/kernels/xpu/index_select_kernel.cc +++ b/paddle/phi/kernels/xpu/index_select_kernel.cc @@ -40,29 +40,33 @@ void IndexSelectKernel(const Context& ctx, index_type, phi::DataType::INT32, phi::DataType::INT64)); - auto* in_data = x.data(); std::vector in_shape = phi::vectorize(input_dim); int index_len = output->dims()[dim]; T* out_data = ctx.template Alloc(output); int r = 0; xpu::ctx_guard RAII_GUARD(ctx.x_context()); - const int8_t* index_ptr = nullptr; + int8_t* index_ptr = nullptr; // temp xpu buffer int byte_times = sizeof(index_type); if (index.place() == CPUPlace()) { index_ptr = RAII_GUARD.alloc_l3_or_gm(byte_times * index.numel()); PADDLE_ENFORCE_XDNN_NOT_NULL(index_ptr); + const void* cpu_idx_data = nullptr; + if (index_type == phi::DataType::INT64) { + cpu_idx_data = reinterpret_cast(index.data()); + } else if (index_type == phi::DataType::INT32) { + cpu_idx_data = reinterpret_cast(index.data()); + } memory_utils::Copy(ctx.GetPlace(), - reinterpret_cast(const_cast(index_ptr)), + reinterpret_cast(index_ptr), CPUPlace(), - reinterpret_cast(index.data()), + cpu_idx_data, byte_times * index.numel()); - } else { - index_ptr = index.template data(); } if (index_type == phi::DataType::INT64) { const int64_t* index_data = - reinterpret_cast(const_cast(index_ptr)); + index_ptr ? reinterpret_cast(index_ptr) + : index.template data(); r = xpu::gather(ctx.x_context(), in_data, index_data, @@ -71,8 +75,8 @@ void IndexSelectKernel(const Context& ctx, index_len, dim); } else { - const int* index_data = - reinterpret_cast(const_cast(index_ptr)); + const int* index_data = index_ptr ? reinterpret_cast(index_ptr) + : index.template data(); r = xpu::gather(ctx.x_context(), in_data, index_data,