Skip to content

Commit

Permalink
Fix pdist unit test failure on client GPUs (#4361)(#4415)
Browse files Browse the repository at this point in the history
* Limit work-group size in pdist forward kernel on client GPUs

* Replace hardcoded change with a general fix to reduce_agg

* Update reduce_agg for consistency

* Unroll the last iteration of the reduction loop to avoid a regression on working cases; reduce the local memory allocation in the pdist kernel

* Remove condition and format the file
  • Loading branch information
Kanya-Mo authored Jul 12, 2024
1 parent 550fd76 commit 00f9449
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions csrc/gpu/aten/operators/Distance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,17 +219,29 @@ static inline scalar_t reduce_agg(
const int local_id = item.get_local_id(0);
const int lane_id = local_id % sg_size;
const int sg_id = local_id / sg_size;
int reduce_num;
agg = subgroup_reduce_agg<scalar_t, F, nd_item>(item, agg, sg_size);
item.barrier(dpcpp_local_fence);
if (0 == lane_id) {
local_shared_mem[sg_id] = agg;
}
item.barrier(dpcpp_local_fence);
agg = (local_id < sg_num) ? local_shared_mem[lane_id] : (scalar_t)0.0f;
if (0 == sg_id) {

for (reduce_num = sg_num; reduce_num > sg_size;
reduce_num = (reduce_num + sg_size - 1) / sg_size) {
agg = (local_id < reduce_num) ? local_shared_mem[local_id] : (scalar_t)0.0f;
agg = subgroup_reduce_agg<scalar_t, F, nd_item>(item, agg, sg_size);
item.barrier(dpcpp_local_fence);
if (0 == lane_id && local_id < reduce_num) {
local_shared_mem[sg_id] = agg;
}
item.barrier(dpcpp_local_fence);
}

agg = (local_id < reduce_num) ? local_shared_mem[lane_id] : (scalar_t)0.0f;
if (0 == sg_id) {
agg = subgroup_reduce_agg<scalar_t, F, nd_item>(item, agg, sg_size);
}
return agg;
}

Expand Down Expand Up @@ -333,6 +345,7 @@ static void pdist_kernel_impl(
const auto ngroups = result.numel();
auto& dpcpp_queue = dpcppGetCurrentQueue();
auto dev_id = dpcppGetDeviceIdOfCurrentQueue();
auto min_sg_size = dpcppMinSubGroupSize(dev_id);
auto wgroup_size = dpcppMaxWorkGroupSize(dev_id);
while (wgroup_size >> 1 >= m && wgroup_size >> 1 >= 32 /* sg_size */) {
wgroup_size >>= 1;
Expand All @@ -346,7 +359,7 @@ static void pdist_kernel_impl(
auto out_data = result.data_ptr<scalar_t>();
auto in_data = self.data_ptr<scalar_t>();
// Create the local shared memory for reducing
dpcpp_local_acc_t<scalar_t, 1> shared(wgroup_size, __cgh);
dpcpp_local_acc_t<scalar_t, 1> shared(wgroup_size / min_sg_size, __cgh);

PdistKernelImplFunctor<scalar_t, F, p_tpye, accscalar_t> kfn(
n, m, p_val, n2_val, n2_squared_minus_1_val, out_data, in_data, shared);
Expand Down

0 comments on commit 00f9449

Please sign in to comment.