diff --git a/src/ATen/native/xpu/sycl/HistogramddKernels.cpp b/src/ATen/native/xpu/sycl/HistogramddKernels.cpp index fbb5d1c34..443783766 100644 --- a/src/ATen/native/xpu/sycl/HistogramddKernels.cpp +++ b/src/ATen/native/xpu/sycl/HistogramddKernels.cpp @@ -222,9 +222,10 @@ struct HistogramddLinearKernelFunctor { int64_t hist_idx = 0; for (int dim = 0; dim < input_dim_; ++dim) { auto i_value = input_data[ele_idx][dim]; - auto leftmost_edge = leftmost_edges_[dim]; - auto rightmost_edge = rightmost_edges_[dim]; + const scalar_t* bin_edges = bin_edges_list_[dim]; auto bin_size = num_bin_edges_[dim] - 1; + auto leftmost_edge = bin_edges[0]; + auto rightmost_edge = bin_edges[bin_size]; if (!(i_value >= leftmost_edge && i_value <= rightmost_edge)) { return; } @@ -253,9 +254,7 @@ struct HistogramddLinearKernelFunctor { bool use_weight, int64_t input_size, int64_t input_dim, - const int64_t* num_bin_edges, - const scalar_t* leftmost_edges, - const scalar_t* rightmost_edges) + const int64_t* num_bin_edges) : input_(input), bin_edges_list_(bin_edges_list), hist_(hist), @@ -264,9 +263,7 @@ struct HistogramddLinearKernelFunctor { use_weight_(use_weight), input_size_(input_size), input_dim_(input_dim), - num_bin_edges_(num_bin_edges), - leftmost_edges_(leftmost_edges), - rightmost_edges_(rightmost_edges) {} + num_bin_edges_(num_bin_edges) {} private: const PackedTensorAccessor64 input_; @@ -278,8 +275,6 @@ struct HistogramddLinearKernelFunctor { int64_t input_size_; int64_t input_dim_; const int64_t* num_bin_edges_; - const scalar_t* leftmost_edges_; - const scalar_t* rightmost_edges_; }; template @@ -292,9 +287,7 @@ void histogramdd_linear_template( bool use_weight, int64_t input_size, int64_t input_dim, - const int64_t* num_bin_edges, - const scalar_t* leftmost_edges, - const scalar_t* rightmost_edges) { + const int64_t* num_bin_edges) { HistogramddLinearKernelFunctor kfn( input, bin_edges_list, @@ -304,9 +297,7 @@ void histogramdd_linear_template( use_weight, input_size, input_dim, - num_bin_edges, - leftmost_edges, - rightmost_edges); + num_bin_edges); const int64_t work_group_size = syclMaxWorkGroupSize(kfn); const int64_t num_wg = (input_size + work_group_size - 1) / work_group_size; sycl_kernel_submit( @@ -355,12 +346,8 @@ void histogramdd_linear_template( * returns both the hist and bin_edges tensors as output, so the "local search" * is needed to keep its output internally consistent. * - * - PARALLEL_SEARCH: Handles torch.histogram's general case by by searching + * - PARALLEL_SEARCH: Handles torch.histogram's general case by searching * over the elements of bin_edges. - * - * See discussion at - * https://github.com/pytorch/pytorch/pull/58780#discussion_r648604866 for - * further details on relative performance of the bin selection algorithms. */ enum BIN_SELECTION_ALGORITHM { LINEAR_INTERPOLATION, @@ -401,16 +388,12 @@ void histogramdd_xpu_contiguous( std::vector bin_seq(D); std::vector num_bin_edges(D); - std::vector leftmost_edge(D), rightmost_edge(D); int64_t total_bin_size = 1; for (const auto dim : c10::irange(D)) { const input_t* data_ptr = bin_edges[dim].const_data_ptr(); bin_seq[dim] = reinterpret_cast(data_ptr); num_bin_edges[dim] = bin_edges[dim].numel(); - leftmost_edge[dim] = bin_edges[dim][0].item().to(); - rightmost_edge[dim] = - bin_edges[dim][num_bin_edges[dim] - 1].item().to(); total_bin_size += num_bin_edges[dim] - 1; } @@ -420,12 +403,6 @@ void histogramdd_xpu_contiguous( self.options() .dtype(c10::kLong) .memory_format(at::MemoryFormat::Contiguous)); - Tensor leftmost_edges_xpu = at::tensor( - leftmost_edge, - self.options().memory_format(at::MemoryFormat::Contiguous)); - Tensor rightmost_edges_xpu = at::tensor( - rightmost_edge, - self.options().memory_format(at::MemoryFormat::Contiguous)); Tensor bin_edges_contig_ptr_xpu = at::tensor(bin_seq, hist_strides_xpu.options()); Tensor num_bin_edges_xpu = @@ -457,9 +434,7 @@ void histogramdd_xpu_contiguous( weight.has_value(), N, D, - num_bin_edges_xpu.const_data_ptr(), - leftmost_edges_xpu.const_data_ptr(), - rightmost_edges_xpu.const_data_ptr()); + num_bin_edges_xpu.const_data_ptr()); } }