
Commit

Optimized indexing related ops performance
Signed-off-by: majing <Jing1.Ma@intel.com>
majing921201 committed Sep 19, 2024
1 parent c134991 commit 3558073
Showing 3 changed files with 65 additions and 37 deletions.
42 changes: 22 additions & 20 deletions src/ATen/native/xpu/sycl/Indexing.cpp
@@ -128,6 +128,8 @@ void index_select_kernel(
     int64_t dim,
     const Tensor& indices,
     const Tensor& dst) {
+  std::cout << "src " << src.sizes() << src.strides() << std::endl;
+  std::cout << "indices " << indices.sizes() << indices.strides() << std::endl;
   at::assert_no_internal_overlap(dst);
   at::assert_no_overlap(dst, src);
   at::assert_no_overlap(dst, indices);
@@ -165,8 +167,8 @@ void index_select_kernel(
       "index_select(): Source and result must have the same scalar type");

   AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "index_select", [&] {
-    TensorInfo<index_t, int64_t> index_info =
-        tensorInfoIfScalar(getTensorInfo<index_t, int64_t>(indices));
+    TensorInfo<index_t, unsigned int> index_info =
+        tensorInfoIfScalar(getTensorInfo<index_t, unsigned int>(indices));
     index_info.collapseDims();

     auto new_size = src.sizes().vec();
@@ -186,15 +188,15 @@
           dst.scalar_type(),
           "index_select_xpu",
           AT_WRAP([&] {
-            TensorInfo<scalar_t, int64_t> dst_info =
-                tensorInfoIfScalar(getTensorInfo<scalar_t, int64_t>(dst));
-            TensorInfo<scalar_t, int64_t> src_info = tensorInfoIfScalar(
-                getTensorInfo<scalar_t, int64_t>(src.contiguous()));
+            TensorInfo<scalar_t, unsigned int> dst_info =
+                tensorInfoIfScalar(getTensorInfo<scalar_t, unsigned int>(dst));
+            TensorInfo<scalar_t, unsigned int> src_info = tensorInfoIfScalar(
+                getTensorInfo<scalar_t, unsigned int>(src.contiguous()));
             int new_indexing_dim = src_info.collapseDims(dim);

-            using SrcInfo = TensorInfo<scalar_t, int64_t>;
-            using DstInfo = TensorInfo<scalar_t, int64_t>;
-            using IdxInfo = TensorInfo<index_t, int64_t>;
+            using SrcInfo = TensorInfo<scalar_t, unsigned int>;
+            using DstInfo = TensorInfo<scalar_t, unsigned int>;
+            using IdxInfo = TensorInfo<index_t, unsigned int>;

             // Improve efficiency of generated native instructions for contiguous.
             // See comm/TensorInfo.h
@@ -400,15 +402,15 @@ void index_add_kernel(
       "index_add_xpu",
       [&] {
         AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_xpu", [&]() {
-          TensorInfo<index_t, int64_t> index_info =
-              getTensorInfo<index_t, int64_t>(index);
+          TensorInfo<index_t, unsigned int> index_info =
+              getTensorInfo<index_t, unsigned int>(index);
           index_info.collapseDims();

-          TensorInfo<scalar_t, int64_t> src_info =
-              getTensorInfo<scalar_t, int64_t>(source_);
+          TensorInfo<scalar_t, unsigned int> src_info =
+              getTensorInfo<scalar_t, unsigned int>(source_);

-          TensorInfo<scalar_t, int64_t> dst_info =
-              getTensorInfo<scalar_t, int64_t>(self_);
+          TensorInfo<scalar_t, unsigned int> dst_info =
+              getTensorInfo<scalar_t, unsigned int>(self_);
           int new_indexing_dim = dst_info.collapseDims(dim);

           using IdxConfig = IndexKernelConfig<
@@ -472,16 +474,16 @@ void index_fill_kernel(
       self.scalar_type(),
       "index_fill_xpu",
       [&] {
-        TensorInfo<int64_t, int64_t> index_info =
-            getTensorInfo<int64_t, int64_t>(index);
+        TensorInfo<int64_t, unsigned int> index_info =
+            getTensorInfo<int64_t, unsigned int>(index);
         index_info.collapseDims();

-        TensorInfo<scalar_t, int64_t> dst_info =
-            getTensorInfo<scalar_t, int64_t>(self);
+        TensorInfo<scalar_t, unsigned int> dst_info =
+            getTensorInfo<scalar_t, unsigned int>(self);
         int new_indexing_dim = dst_info.collapseDims(dim);

         // Not used in the index kernel frame for index_fill.
-        auto src_info = TensorInfo<scalar_t, int64_t>();
+        auto src_info = TensorInfo<scalar_t, unsigned int>();

         using IdxConfig = IndexKernelConfig<
             decltype(src_info),
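Note on the pattern above: every TensorInfo instantiation in this file narrows its offset type from int64_t to unsigned int, moving the kernels' per-element address arithmetic to 32-bit integers, which GPUs execute much more cheaply than 64-bit multiplies and divides. The narrowing is only safe while the largest reachable linear offset fits in 32 bits; below is a minimal sketch of the host-side guard such a change relies on (the helper name fits_in_32bit is illustrative, not this repository's API — PyTorch's CUDA backend exposes analogous logic as canUse32BitIndexMath):

#include <cstdint>
#include <ATen/core/Tensor.h>

// Illustrative guard: 32-bit offsets are valid only if the linear offset
// of the element at the maximal index in every dimension still fits in
// an unsigned 32-bit integer.
static bool fits_in_32bit(const at::Tensor& t) {
  int64_t max_offset = 1; // one past the largest reachable offset
  for (int64_t d = 0; d < t.dim(); ++d) {
    max_offset += (t.size(d) - 1) * t.stride(d);
  }
  return max_offset <= static_cast<int64_t>(UINT32_MAX);
}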
48 changes: 37 additions & 11 deletions src/ATen/native/xpu/sycl/Indexing.h
@@ -6,7 +6,20 @@

 using namespace at::xpu::detail;
 using namespace at::xpu;
-
+#if defined(__SYCL_DEVICE_ONLY__)
+#define DPCPP_CONSTANT __attribute__((opencl_constant))
+#else
+#define DPCPP_CONSTANT
+#endif
+
+#define DPCPP_KER_STRING(var, str) static const DPCPP_CONSTANT char var[] = str;
+#define DPCPP_KER_PRINTF sycl::ext::oneapi::experimental::printf
+
+#define DPCPP_K_PRINT(fmt_str, ...)           \
+  {                                           \
+    DPCPP_KER_STRING(fmt_var, fmt_str);       \
+    DPCPP_KER_PRINTF(fmt_var, ##__VA_ARGS__); \
+  }
 namespace at::native::xpu {

 template <int N>
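The macros added above give device code a printf-style escape hatch: on device, the format string must live in the opencl_constant address space, which DPCPP_KER_STRING arranges before handing it to sycl::ext::oneapi::experimental::printf. A minimal usage sketch, assuming a SYCL kernel body (the helper below is illustrative, not part of this header):

#include <sycl/sycl.hpp>

// Print from a single work-item so the message is not repeated across
// the whole ND-range.
inline void debug_first_item(sycl::nd_item<1> item, int value) {
  if (item.get_global_linear_id() == 0) {
    DPCPP_K_PRINT("value on device: %d\n", value);
  }
}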
@@ -209,10 +222,10 @@ class IndexKernel {
       if constexpr (TrivialOffCal) {
         idx_off = idx_logical_off;
       } else {
-        idx_off = IndexToOffset<IdxType, int64_t>::get(
+        idx_off = IndexToOffset<IdxType, unsigned int>::get(
             idx_logical_off,
             cfg_.iinfo_,
-            IndexToOffset<IdxType, int64_t>::NON_STRICT_CONTIGUOUS);
+            IndexToOffset<IdxType, unsigned int>::NON_STRICT_CONTIGUOUS);
       }
       glb_batch_group = id.glb_batch / cfg_.index_num_;
       glb_batch_group_loc_off = cfg_.iinfo_.data[idx_off];
@@ -320,26 +333,26 @@ class IndexKernel {
       } else {
         if (cfg_.indexing_dst_) {
           // index_copy, index_add, index_fill
-          dst_off = IndexToOffset<ValType, int64_t>::get(
+          dst_off = IndexToOffset<ValType, unsigned int>::get(
               glb_indexing_logical_off,
               cfg_.dinfo_,
-              IndexToOffset<ValType, int64_t>::NON_STRICT_CONTIGUOUS);
+              IndexToOffset<ValType, unsigned int>::NON_STRICT_CONTIGUOUS);
           if (cfg_.sinfo_.data != nullptr) {
-            src_off = IndexToOffset<ValType, int64_t>::get(
+            src_off = IndexToOffset<ValType, unsigned int>::get(
                 glb_fixing_logical_off,
                 cfg_.sinfo_,
-                IndexToOffset<ValType, int64_t>::NON_STRICT_CONTIGUOUS);
+                IndexToOffset<ValType, unsigned int>::NON_STRICT_CONTIGUOUS);
           }
         } else {
           // index_select
-          src_off = IndexToOffset<ValType, int64_t>::get(
+          src_off = IndexToOffset<ValType, unsigned int>::get(
               glb_indexing_logical_off,
               cfg_.sinfo_,
-              IndexToOffset<ValType, int64_t>::NON_STRICT_CONTIGUOUS);
-          dst_off = IndexToOffset<ValType, int64_t>::get(
+              IndexToOffset<ValType, unsigned int>::NON_STRICT_CONTIGUOUS);
+          dst_off = IndexToOffset<ValType, unsigned int>::get(
               glb_fixing_logical_off,
               cfg_.dinfo_,
-              IndexToOffset<ValType, int64_t>::NON_STRICT_CONTIGUOUS);
+              IndexToOffset<ValType, unsigned int>::NON_STRICT_CONTIGUOUS);
         }
       }
       cfg_.func_(
@@ -616,6 +629,13 @@ struct IndexKernelFunctor {
         offset += index * strides_[i];
       } else {
         int64_t index = *(int64_t*)(index_ptrs_[i] + offsets[2]);
+        if (linear_idx == 960) {
+          DPCPP_K_PRINT(
+              "2, index is %ld, size is %ld, offset is %d\n",
+              index,
+              sizes_[i],
+              offsets[2]);
+        }
         SYCL_KERNEL_ASSERT(
             index >= -sizes_[i] && index < sizes_[i] && "index out of bounds");
         if (index < 0) {
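The assert above admits indices in [-sizes_[i], sizes_[i]), i.e. Python-style negative indexing, and the if (index < 0) branch cut off at the hunk boundary presumably wraps negatives back into range. The semantics as a standalone sketch:

#include <cstdint>

// Python-style normalization: -1 refers to the last element.
// Valid inputs lie in [-size, size).
inline int64_t normalize_index(int64_t index, int64_t size) {
  return index < 0 ? index + size : index;
}
// normalize_index(-1, 6) == 5; normalize_index(2, 6) == 2.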
@@ -665,11 +685,15 @@ void index_kernel_impl(
     IntArrayRef index_stride,
     const func_t f) {
   size_t num_indices = index_size.size();
+  std::cout << "num_indices = " << num_indices << std::endl;
   auto numel = iter.numel();
   at::detail::Array<int64_t, XPU_MAX_TENSORINFO_DIMS> sizes(0);
   at::detail::Array<int64_t, XPU_MAX_TENSORINFO_DIMS> strides(0);
+  std::cout << "input size" << iter.tensor(1).sizes() << std::endl;
+  std::cout << "output size" << iter.tensor(0).sizes() << std::endl;
   for (size_t i = 0; i < num_indices; i++) {
     sizes[i] = index_size[i];
+    std::cout << sizes[i] << std::endl;
     strides[i] = index_stride[i];
   }

@@ -682,6 +706,8 @@
   at::detail::Array<index_buf_type, XPU_MAX_TENSORINFO_DIMS> index_ptrs;
   for (size_t i = 0; i < num_indices; i++) {
     index_ptrs[i] = (char*)iter.data_ptr(i + 2);
+    std::cout << "index tensor " << iter.tensor(i + 2).sizes()
+              << iter.tensor(i + 2).strides() << std::endl;
   }

   auto offset_calc = make_offset_calculator<3>(iter);
12 changes: 6 additions & 6 deletions src/comm/TensorInfo.h
@@ -138,15 +138,15 @@ struct IndexToOffset {
     }

 #pragma unroll
-    for (int dim = XPU_MAX_TENSORINFO_DIMS - 1; dim >= 0; --dim) {
+    for (int dim = XPU_MAX_TENSORINFO_DIMS - 1; dim > 0; --dim) {
       if (dim < info.dims) {
-        auto divider = at::detail::IntDivider<IndexType>(info.sizes[dim]);
-        auto divmod = divider.divmod(linearId);
-        linearId = divmod.div;
-        offset += divmod.mod * info.strides[dim];
+        IndexType curDimIndex = linearId % info.sizes[dim];
+        IndexType curDimOffset = curDimIndex * info.strides[dim];
+        offset += curDimOffset;
+        linearId /= info.sizes[dim];
       }
     }
-    return offset;
+    return offset + linearId * info.strides[0];
   }
 };
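The rewrite above is the standard linear-index decomposition with the outermost dimension peeled off: once the loop has divided out sizes[dims-1] through sizes[1], the remaining quotient is already the coordinate along dim 0, so the final divide/modulo pair collapses into one multiply. It also drops the per-element construction of at::detail::IntDivider, whose magic-number setup itself costs a division and only pays off when precomputed once on the host — with IndexType now unsigned int, the plain 32-bit % and / are cheap. A self-contained worked example for a contiguous 2x3x4 tensor (sizes and strides below are illustrative):

#include <cstdio>

// Decompose a linear id into an element offset, peeling dim 0 out of
// the loop exactly as in the new IndexToOffset::get.
unsigned int index_to_offset(unsigned int linearId) {
  const unsigned int sizes[3] = {2, 3, 4};
  const unsigned int strides[3] = {12, 4, 1}; // contiguous layout
  unsigned int offset = 0;
  for (int dim = 2; dim > 0; --dim) { // dims 2 and 1 only
    offset += (linearId % sizes[dim]) * strides[dim];
    linearId /= sizes[dim];
  }
  // The leftover quotient is the dim-0 coordinate.
  return offset + linearId * strides[0];
}

int main() {
  // linearId 17 -> coordinates (1, 1, 1) -> 1*12 + 1*4 + 1*1 = 17.
  std::printf("%u\n", index_to_offset(17));
  return 0;
}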
