Skip to content

Commit

Permalink
rename kernel functor
Browse files Browse the repository at this point in the history
  • Loading branch information
hjhee committed Aug 30, 2024
1 parent b7228a8 commit 5e30eaf
Showing 1 changed file with 32 additions and 28 deletions.
60 changes: 32 additions & 28 deletions src/ATen/native/xpu/sycl/TensorModeKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ void mode_fused_impl(
}

template <typename scalar_t>
struct ModeXpuKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
struct ModeKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
void operator()(sycl::nd_item<1> item) const {
mode_impl(
problem_values_ptr_,
Expand All @@ -545,7 +545,13 @@ struct ModeXpuKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
problem_upper_limit_,
item);
}
ModeXpuKernelFunctor(

void sycl_ker_config_convention(sycl::handler& cgh) {
// SLM (group size) is used for comparing adjacent elements
slm_ = sycl_local_acc_t<scalar_t, 1>(group_size_, cgh);
}

ModeKernelFunctor(
scalar_t* problem_values_ptr,
int64_t* problem_indices_ptr,
TensorInfo<scalar_t, int64_t> values_info,
Expand All @@ -569,28 +575,24 @@ struct ModeXpuKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
group_size_(group_size),
problem_upper_limit_(problem_upper_limit) {}

void sycl_ker_config_convention(sycl::handler& cgh) {
// SLM (group size) is used for comparing adjacent elements
slm_ = sycl_local_acc_t<scalar_t, 1>(group_size_, cgh);
}

private:
scalar_t* problem_values_ptr_;
int64_t* problem_indices_ptr_;
TensorInfo<scalar_t, int64_t> values_info_;
TensorInfo<int64_t, int64_t> indices_info_;
sycl_local_acc_t<scalar_t, 1> slm_;
int64_t* scratch_status_ptr_;
int64_t* scratch_value_ptr_;
int64_t problem_time_;
int64_t problem_size_;
int64_t group_number_;
int64_t group_size_;
int64_t problem_upper_limit_;

sycl_local_acc_t<scalar_t, 1> slm_;
};

template <typename scalar_t>
struct ModeXpuKernelFunctor2 : public __SYCL_KER_CONFIG_CONVENTION__ {
struct ModeFusedKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
void operator()(sycl::nd_item<1> item) const {
mode_fused_impl(
problem_values_ptr_,
Expand All @@ -613,7 +615,21 @@ struct ModeXpuKernelFunctor2 : public __SYCL_KER_CONFIG_CONVENTION__ {
group_size_,
item);
}
ModeXpuKernelFunctor2(

void sycl_ker_config_convention(sycl::handler& cgh) {
// SLM used to record status for mode
slm_helper_ = sycl_local_acc_t<ModeOpHelper, 1>(group_size_, cgh);

// SLM used to store each value and its associated index
slm_value_indice_ =
sycl_local_acc_t<ModeOpValueIndex<scalar_t>, 1>(group_size_, cgh);

// SLM used for sort
sort_scratch_ =
sycl_local_acc_t<std::byte, 1>(sort_scratch_memory_size_, cgh);
}

ModeFusedKernelFunctor(
scalar_t* problem_values_ptr,
TensorInfo<scalar_t, int64_t> values_info,
TensorInfo<int64_t, int64_t> indices_info,
Expand All @@ -631,31 +647,19 @@ struct ModeXpuKernelFunctor2 : public __SYCL_KER_CONFIG_CONVENTION__ {
group_number_(group_number),
group_size_(group_size) {}

void sycl_ker_config_convention(sycl::handler& cgh) {
// SLM used to record status for mode
slm_helper_ = sycl_local_acc_t<ModeOpHelper, 1>(group_size_, cgh);

// SLM used to store each value and its associated index
slm_value_indice_ =
sycl_local_acc_t<ModeOpValueIndex<scalar_t>, 1>(group_size_, cgh);

// SLM used for sort
sort_scratch_ =
sycl_local_acc_t<std::byte, 1>(sort_scratch_memory_size_, cgh);
}

private:
scalar_t* problem_values_ptr_;
TensorInfo<scalar_t, int64_t> values_info_;
TensorInfo<int64_t, int64_t> indices_info_;
sycl_local_acc_t<ModeOpHelper, 1> slm_helper_;
sycl_local_acc_t<ModeOpValueIndex<scalar_t>, 1> slm_value_indice_;
sycl_local_acc_t<std::byte, 1> sort_scratch_;
int64_t sort_scratch_memory_size_;
int64_t problem_time_;
int64_t problem_size_;
int64_t group_number_;
int64_t group_size_;

sycl_local_acc_t<ModeOpHelper, 1> slm_helper_;
sycl_local_acc_t<ModeOpValueIndex<scalar_t>, 1> slm_value_indice_;
sycl_local_acc_t<std::byte, 1> sort_scratch_;
};

/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
Expand Down Expand Up @@ -764,7 +768,7 @@ void mode_kernel_impl(
auto problem_indices_ptr = problem_indices.data_ptr<int64_t>();
auto scratch_status_ptr = scratch_status_tensor.data_ptr<int64_t>();
auto scratch_value_ptr = scratch_value_tensor.data_ptr<int64_t>();
ModeXpuKernelFunctor<scalar_t> kfn(
ModeKernelFunctor<scalar_t> kfn(
problem_values_ptr,
problem_indices_ptr,
values_info,
Expand Down Expand Up @@ -793,7 +797,7 @@ void mode_kernel_impl(
auto indices_info = getTensorInfo<int64_t, int64_t>(indices_transposed);

auto problem_values_ptr = contiguous.data_ptr<scalar_t>();
ModeXpuKernelFunctor2<scalar_t> kfn(
ModeFusedKernelFunctor<scalar_t> kfn(
problem_values_ptr,
values_info,
indices_info,
Expand Down

0 comments on commit 5e30eaf

Please sign in to comment.