Skip to content

Commit

Permalink
implement parallel_search
Browse files Browse the repository at this point in the history
  • Loading branch information
hjhee committed Aug 29, 2024
1 parent 2b5716e commit 93ade99
Show file tree
Hide file tree
Showing 4 changed files with 291 additions and 243 deletions.
59 changes: 3 additions & 56 deletions src/ATen/native/xpu/Histogram.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,27 +267,7 @@ static Tensor& histogramdd_out(
bin_edges[dim].copy_(bins[dim]);
}

const int64_t N = self.size(-1);
const int64_t M = std::accumulate(
self.sizes().begin(),
self.sizes().end() - 1,
(int64_t)1,
std::multiplies<int64_t>());

const Tensor reshaped_self = self.reshape({M, N});

const auto reshaped_weight = weight.has_value()
? std::optional<Tensor>(weight.value().reshape({M}))
: std::optional<Tensor>();

std::vector<Tensor> bin_edges_contig(bin_edges.size());
for (const auto dim : c10::irange(bin_edges_contig.size())) {
bin_edges_contig[dim] = bin_edges[dim].contiguous();
}
TensorList bin_edges_contig_tl(bin_edges_contig);

at::native::xpu::histogramdd_kernel(
reshaped_self, reshaped_weight, density, hist, bin_edges_contig_tl);
at::native::xpu::histogramdd_kernel(self, weight, density, hist, bin_edges);
return hist;
}

Expand Down Expand Up @@ -378,35 +358,8 @@ static Tensor& histogramdd_out(
bin_edges[dim].copy_(bins[dim]);
}

std::vector<Tensor> bin_edges_contig(bin_edges.size());
for (const auto dim : c10::irange(bin_edges_contig.size())) {
bin_edges_contig[dim] = bin_edges[dim].contiguous();
}
TensorList bin_edges_contig_tl(bin_edges_contig);

const int64_t N = self.size(-1);
const int64_t M = std::accumulate(
self.sizes().begin(),
self.sizes().end() - 1,
(int64_t)1,
std::multiplies<int64_t>());

const Tensor reshaped_self = self.reshape({M, N});

const auto reshaped_weight = weight.has_value()
? std::optional<Tensor>(weight.value().reshape({M}))
: std::optional<Tensor>();

auto outer_bin_edges = select_outer_bin_edges(reshaped_self, range);

at::native::xpu::histogramdd_linear_kernel(
reshaped_self,
reshaped_weight,
density,
hist,
bin_edges_contig_tl,
outer_bin_edges,
true);
self, weight, density, hist, bin_edges, true);
return hist;
}

Expand Down Expand Up @@ -486,13 +439,7 @@ std::tuple<Tensor&, Tensor&> XPUNativeFunctions::histogram_out(
histogramdd_check_inputs(reshaped_self, bins_in, reshaped_weight);

at::native::xpu::histogramdd_linear_kernel(
reshaped_self,
reshaped_weight,
density,
hist,
bin_edges.contiguous(),
outer_bin_edges,
true);
reshaped_self, reshaped_weight, density, hist, bin_edges, true);
return std::forward_as_tuple(hist, bin_edges);
}

Expand Down
1 change: 0 additions & 1 deletion src/ATen/native/xpu/sycl/HistogramKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ void histogramdd_linear_kernel(
bool density,
Tensor& hist,
const TensorList& bin_edges,
const std::pair<std::vector<double>, std::vector<double>>& outer_bin_edges,
bool local_search);

} // namespace at::native::xpu
Loading

0 comments on commit 93ade99

Please sign in to comment.