Skip to content

Commit

Permalink
ensure block jac diag is not zero
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed May 16, 2024
1 parent dee1062 commit 37e71d8
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
3 changes: 2 additions & 1 deletion common/cuda_hip/preconditioner/batch_block_jacobi.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ public:
sum += block_val * r[dense_block_col + idx_start];
}

// reduction
// reduction
#pragma unroll
for (int i = static_cast<int>(tile_size) / 2; i > 0; i /= 2) {
sum += subwarp_grp.shfl_down(sum, i);
}
Expand Down
8 changes: 5 additions & 3 deletions common/cuda_hip/preconditioner/batch_jacobi_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ __global__ __launch_bounds__(default_block_size) void find_row_block_map_kernel(
block_idx += blockDim.x * gridDim.x) {
for (int i = block_pointers[block_idx];
i < block_pointers[block_idx + 1]; i++) {
map_block_to_row[i] = block_idx; // uncoalesced
// accesses
map_block_to_row[i] = block_idx; // uncoalesced accesses
}
}
}
Expand Down Expand Up @@ -126,7 +125,10 @@ __device__ __forceinline__ void invert_dense_block(Group subwarp_grp,
if (subwarp_grp.thread_rank() == ipiv) {
perm = k;
}
const ValueType d = subwarp_grp.shfl(block_row[k], ipiv);
const ValueType d =
(subwarp_grp.shfl(block_row[k], ipiv) == zero<ValueType>())
? one<ValueType>()
: subwarp_grp.shfl(block_row[k], ipiv);
// scale kth col
block_row[k] /= -d;
if (subwarp_grp.thread_rank() == ipiv) {
Expand Down
6 changes: 3 additions & 3 deletions dpcpp/preconditioner/batch_jacobi_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ void find_row_block_map(std::shared_ptr<const DefaultExecutor> exec,
const IndexType* const block_pointers,
IndexType* const map_block_to_row)
{
(exec->get_queue())->submit([&](sycl::handler& cgh) {
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(num_blocks, [=](auto id) {
for (int i = block_pointers[id]; i < block_pointers[id + 1]; i++)
map_block_to_row[i] = id;
Expand Down Expand Up @@ -97,7 +97,7 @@ void extract_common_blocks_pattern(
const auto row_ptrs = first_sys_csr->get_const_row_ptrs();
const auto col_idxs = first_sys_csr->get_const_col_idxs();

(exec->get_queue())->submit([&](sycl::handler& cgh) {
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(subgroup_size)]] {
Expand Down Expand Up @@ -143,7 +143,7 @@ void compute_block_jacobi_helper(
dim3 block(group_size);
dim3 grid(ceildiv(num_blocks * nbatch * subgroup_size, group_size));

(exec->get_queue())->submit([&](sycl::handler& cgh) {
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1)
[[intel::reqd_sub_group_size(subgroup_size)]] {
Expand Down

0 comments on commit 37e71d8

Please sign in to comment.