Skip to content

Commit

Permalink
reduce H2D/D2H copy
Browse files (browse the repository at this point in the history)
  • Loading branch information
hjhee committed Oct 14, 2024
1 parent af97db0 commit 12f7731
Showing 1 changed file with 9 additions and 34 deletions.
43 changes: 9 additions & 34 deletions src/ATen/native/xpu/sycl/HistogramddKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,10 @@ struct HistogramddLinearKernelFunctor {
int64_t hist_idx = 0;
for (int dim = 0; dim < input_dim_; ++dim) {
auto i_value = input_data[ele_idx][dim];
auto leftmost_edge = leftmost_edges_[dim];
auto rightmost_edge = rightmost_edges_[dim];
const scalar_t* bin_edges = bin_edges_list_[dim];
auto bin_size = num_bin_edges_[dim] - 1;
auto leftmost_edge = bin_edges[0];
auto rightmost_edge = bin_edges[bin_size];
if (!(i_value >= leftmost_edge && i_value <= rightmost_edge)) {
return;
}
Expand Down Expand Up @@ -253,9 +254,7 @@ struct HistogramddLinearKernelFunctor {
bool use_weight,
int64_t input_size,
int64_t input_dim,
const int64_t* num_bin_edges,
const scalar_t* leftmost_edges,
const scalar_t* rightmost_edges)
const int64_t* num_bin_edges)
: input_(input),
bin_edges_list_(bin_edges_list),
hist_(hist),
Expand All @@ -264,9 +263,7 @@ struct HistogramddLinearKernelFunctor {
use_weight_(use_weight),
input_size_(input_size),
input_dim_(input_dim),
num_bin_edges_(num_bin_edges),
leftmost_edges_(leftmost_edges),
rightmost_edges_(rightmost_edges) {}
num_bin_edges_(num_bin_edges) {}

private:
const PackedTensorAccessor64<const scalar_t, 2> input_;
Expand All @@ -278,8 +275,6 @@ struct HistogramddLinearKernelFunctor {
int64_t input_size_;
int64_t input_dim_;
const int64_t* num_bin_edges_;
const scalar_t* leftmost_edges_;
const scalar_t* rightmost_edges_;
};

template <typename scalar_t>
Expand All @@ -292,9 +287,7 @@ void histogramdd_linear_template(
bool use_weight,
int64_t input_size,
int64_t input_dim,
const int64_t* num_bin_edges,
const scalar_t* leftmost_edges,
const scalar_t* rightmost_edges) {
const int64_t* num_bin_edges) {
HistogramddLinearKernelFunctor<scalar_t> kfn(
input,
bin_edges_list,
Expand All @@ -304,9 +297,7 @@ void histogramdd_linear_template(
use_weight,
input_size,
input_dim,
num_bin_edges,
leftmost_edges,
rightmost_edges);
num_bin_edges);
const int64_t work_group_size = syclMaxWorkGroupSize(kfn);
const int64_t num_wg = (input_size + work_group_size - 1) / work_group_size;
sycl_kernel_submit(
Expand Down Expand Up @@ -355,12 +346,8 @@ void histogramdd_linear_template(
* returns both the hist and bin_edges tensors as output, so the "local search"
* is needed to keep its output internally consistent.
*
* - PARALLEL_SEARCH: Handles torch.histogram's general case by by searching
* - PARALLEL_SEARCH: Handles torch.histogram's general case by searching
* over the elements of bin_edges.
*
* See discussion at
* https://github.com/pytorch/pytorch/pull/58780#discussion_r648604866 for
* further details on relative performance of the bin selection algorithms.
*/
enum BIN_SELECTION_ALGORITHM {
LINEAR_INTERPOLATION,
Expand Down Expand Up @@ -401,16 +388,12 @@ void histogramdd_xpu_contiguous(

std::vector<int64_t> bin_seq(D);
std::vector<int64_t> num_bin_edges(D);
std::vector<input_t> leftmost_edge(D), rightmost_edge(D);

int64_t total_bin_size = 1;
for (const auto dim : c10::irange(D)) {
const input_t* data_ptr = bin_edges[dim].const_data_ptr<input_t>();
bin_seq[dim] = reinterpret_cast<int64_t>(data_ptr);
num_bin_edges[dim] = bin_edges[dim].numel();
leftmost_edge[dim] = bin_edges[dim][0].item().to<input_t>();
rightmost_edge[dim] =
bin_edges[dim][num_bin_edges[dim] - 1].item().to<input_t>();

total_bin_size += num_bin_edges[dim] - 1;
}
Expand All @@ -420,12 +403,6 @@ void histogramdd_xpu_contiguous(
self.options()
.dtype(c10::kLong)
.memory_format(at::MemoryFormat::Contiguous));
Tensor leftmost_edges_xpu = at::tensor(
leftmost_edge,
self.options().memory_format(at::MemoryFormat::Contiguous));
Tensor rightmost_edges_xpu = at::tensor(
rightmost_edge,
self.options().memory_format(at::MemoryFormat::Contiguous));
Tensor bin_edges_contig_ptr_xpu =
at::tensor(bin_seq, hist_strides_xpu.options());
Tensor num_bin_edges_xpu =
Expand Down Expand Up @@ -457,9 +434,7 @@ void histogramdd_xpu_contiguous(
weight.has_value(),
N,
D,
num_bin_edges_xpu.const_data_ptr<int64_t>(),
leftmost_edges_xpu.const_data_ptr<input_t>(),
rightmost_edges_xpu.const_data_ptr<input_t>());
num_bin_edges_xpu.const_data_ptr<int64_t>());
}
}

Expand Down

0 comments on commit 12f7731

Please sign in to comment.