Skip to content

Commit

Permalink
Cleanup csr2csr compress (#743) (#382)
Browse files Browse the repository at this point in the history
* Testing

* Fix access to warp start array

* Testing

* Testing

* Testing

* Add asserts and remove comments

* Clang formatting

* PR fixes

* Fix to bsrmm

* PR fixes

---------

Co-authored-by: jsandham <james.sandham@amd.com>
  • Loading branch information
jsandham and jsandham authored Apr 3, 2024
1 parent 3f96fa5 commit edb2770
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 88 deletions.
8 changes: 4 additions & 4 deletions clients/tests/test_csr2csr_compress.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ########################################################################
# Copyright (C) 2020-2023 Advanced Micro Devices, Inc. All rights Reserved.
# Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -95,7 +95,7 @@ Tests:
- name: csr2csr_compress_file
category: quick
function: csr2csr_compress
precision: *single_double_precisions
precision: *single_double_precisions_complex_real
M: 1
N: 1
alpha_alphai: *tol_range_quick
Expand All @@ -110,7 +110,7 @@ Tests:
- name: csr2csr_compress_file
category: pre_checkin
function: csr2csr_compress
precision: *single_double_precisions
precision: *single_double_precisions_complex_real
M: 1
N: 1
alpha_alphai: *tol_range_checkin
Expand All @@ -127,7 +127,7 @@ Tests:
- name: csr2csr_compress_file
category: nightly
function: csr2csr_compress
precision: *single_double_precisions
precision: *single_double_precisions_complex_real
M: 1
N: 1
alpha_alphai: *tol_range_nightly
Expand Down
164 changes: 90 additions & 74 deletions library/src/conversion/csr2csr_compress_device.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*! \file */
/* ************************************************************************
* Copyright (C) 2020-2023 Advanced Micro Devices, Inc. All rights Reserved.
* Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights Reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -24,20 +24,21 @@

#pragma once

#include <assert.h>
#include <limits>

#include "common.h"

template <rocsparse_int BLOCK_SIZE, rocsparse_int WF_SIZE, rocsparse_int LOOPS, typename T>
ROCSPARSE_KERNEL(BLOCK_SIZE)
template <uint32_t BLOCKSIZE, uint32_t WF_SIZE, uint32_t LOOPS, typename T>
ROCSPARSE_KERNEL(BLOCKSIZE)
void csr2csr_compress_fill_warp_start_device(rocsparse_int nnz_A,
const T* __restrict__ csr_val_A,
int* __restrict__ warp_start,
T tol)
const T* __restrict__ csr_val_A,
uint32_t* __restrict__ warp_start,
T tol)
{
rocsparse_int tid = hipThreadIdx_x;
rocsparse_int bid = hipBlockIdx_x;
rocsparse_int gid = tid + LOOPS * BLOCK_SIZE * bid;
rocsparse_int gid = tid + LOOPS * BLOCKSIZE * bid;

rocsparse_int wid = tid / WF_SIZE;

Expand All @@ -46,87 +47,100 @@ void csr2csr_compress_fill_warp_start_device(rocsparse_int nnz_A,
warp_start[0] = 0;
}

for(rocsparse_int i = 0; i < LOOPS; i++)
for(uint32_t i = 0; i < LOOPS; i++)
{
T value = (gid < nnz_A) ? rocsparse_nontemporal_load(csr_val_A + gid) : static_cast<T>(0);
if(gid < nnz_A)
{
const T value = rocsparse_nontemporal_load(csr_val_A + gid);

// Check if value in matrix will be kept
const bool predicate
= (rocsparse_abs(value) > rocsparse_real(tol)
&& rocsparse_abs(value) > std::numeric_limits<float>::min());

// Inactive threads in warp set their lane to zero in mask
const uint64_t wavefront_mask = __ballot(predicate);

// Check if value in matrix will be kept
bool predicate = rocsparse_abs(value) > rocsparse_real(tol)
&& rocsparse_abs(value) > std::numeric_limits<float>::min()
? true
: false;
// Get the number of retained matrix entries in this warp
const uint32_t count_nnzs = __popcll(wavefront_mask);

const uint64_t wavefront_mask = __ballot(predicate);
const int warp_index
= (LOOPS * (BLOCKSIZE / WF_SIZE) * bid + (BLOCKSIZE / WF_SIZE) * i + wid);

// Get the number of retained matrix entries in this warp
const uint64_t count_nnzs = __popcll(wavefront_mask);
assert(warp_index < ((nnz_A - 1) / WF_SIZE + 1) && "Warp index out of bounds");

warp_start[LOOPS * (BLOCK_SIZE / WF_SIZE) * bid + (BLOCK_SIZE / WF_SIZE) * i + wid + 1]
= static_cast<int>(count_nnzs);
warp_start[warp_index + 1] = count_nnzs;
}

gid += BLOCK_SIZE;
gid += BLOCKSIZE;
}
}

template <rocsparse_int BLOCK_SIZE, rocsparse_int WF_SIZE, rocsparse_int LOOPS, typename T>
ROCSPARSE_KERNEL(BLOCK_SIZE)
template <uint32_t BLOCKSIZE, uint32_t WF_SIZE, uint32_t LOOPS, typename T>
ROCSPARSE_KERNEL(BLOCKSIZE)
void csr2csr_compress_use_warp_start_device(rocsparse_int nnz_A,
rocsparse_index_base idx_base_A,
const T* __restrict__ csr_val_A,
const rocsparse_int* __restrict__ csr_col_ind_A,
rocsparse_index_base idx_base_C,
T* __restrict__ csr_val_C,
rocsparse_int* __restrict__ csr_col_ind_C,
int* __restrict__ warp_start,
const uint32_t* __restrict__ warp_start,
T tol)
{
rocsparse_int tid = hipThreadIdx_x;
rocsparse_int bid = hipBlockIdx_x;
rocsparse_int gid = tid + LOOPS * BLOCK_SIZE * bid;
rocsparse_int gid = tid + LOOPS * BLOCKSIZE * bid;

rocsparse_int lid = tid & (WF_SIZE - 1);
rocsparse_int wid = tid / WF_SIZE;

const uint64_t filter_mask = (0xffffffffffffffff >> (63 - lid));

for(rocsparse_int i = 0; i < LOOPS; i++)
for(uint32_t i = 0; i < LOOPS; i++)
{
int start
= warp_start[LOOPS * (BLOCK_SIZE / WF_SIZE) * bid + (BLOCK_SIZE / WF_SIZE) * i + wid];
if(gid < nnz_A)
{
const T value = rocsparse_nontemporal_load(csr_val_A + gid);

T value = (gid < nnz_A) ? rocsparse_nontemporal_load(csr_val_A + gid) : static_cast<T>(0);
// Check if value in matrix will be kept
const bool predicate
= (rocsparse_abs(value) > rocsparse_real(tol)
&& rocsparse_abs(value) > std::numeric_limits<float>::min());

// Check if value in matrix will be kept
bool predicate = rocsparse_abs(value) > rocsparse_real(tol)
&& rocsparse_abs(value) > std::numeric_limits<float>::min()
? true
: false;
// Inactive threads in warp set their lane to zero in mask
const uint64_t wavefront_mask = __ballot(predicate);

const uint64_t wavefront_mask = __ballot(predicate);
// Get the number of retained matrix entries in this warp
const uint32_t count_previous_nnzs = __popcll(wavefront_mask & filter_mask);

// Get the number of retained matrix entries in this warp
const uint64_t count_previous_nnzs = __popcll(wavefront_mask & filter_mask);
// If we are keeping the matrix entry, insert it into the compressed CSR matrix
if(predicate)
{
assert(count_previous_nnzs > 0
&& "When predicate is true, non-zero count cannot be zero.");

// If we are keeping the matrix entry, insert it into the compressed CSR matrix
if(predicate)
{
csr_val_C[start + count_previous_nnzs - 1] = value;
csr_col_ind_C[start + count_previous_nnzs - 1]
= csr_col_ind_A[gid] - idx_base_A + idx_base_C;
const uint32_t start = warp_start[LOOPS * (BLOCKSIZE / WF_SIZE) * bid
+ (BLOCKSIZE / WF_SIZE) * i + wid];

csr_val_C[start + count_previous_nnzs - 1] = value;
csr_col_ind_C[start + count_previous_nnzs - 1]
= csr_col_ind_A[gid] - idx_base_A + idx_base_C;
}
}

gid += BLOCK_SIZE;
gid += BLOCKSIZE;
}
}

template <rocsparse_int BLOCK_SIZE>
ROCSPARSE_KERNEL(BLOCK_SIZE)
template <uint32_t BLOCKSIZE>
ROCSPARSE_KERNEL(BLOCKSIZE)
void fill_row_ptr_device(rocsparse_int m,
rocsparse_index_base idx_base_C,
const rocsparse_int* __restrict__ nnz_per_row,
rocsparse_int* __restrict__ csr_row_ptr_C)
rocsparse_index_base idx_base_C,
const rocsparse_int* __restrict__ nnz_per_row,
rocsparse_int* __restrict__ csr_row_ptr_C)
{
rocsparse_int tid = hipThreadIdx_x + hipBlockIdx_x * BLOCK_SIZE;
rocsparse_int tid = hipThreadIdx_x + hipBlockIdx_x * BLOCKSIZE;

if(tid >= m)
{
Expand All @@ -141,31 +155,33 @@ void fill_row_ptr_device(rocsparse_int m,
}
}

template <rocsparse_int BLOCK_SIZE,
rocsparse_int SEGMENTS_PER_BLOCK,
rocsparse_int SEGMENT_SIZE,
rocsparse_int WF_SIZE,
typename T>
ROCSPARSE_DEVICE_ILF void csr2csr_compress_device(rocsparse_int m,
rocsparse_int n,
rocsparse_index_base idx_base_A,
const T* __restrict__ csr_val_A,
const rocsparse_int* __restrict__ csr_row_ptr_A,
const rocsparse_int* __restrict__ csr_col_ind_A,
rocsparse_int nnz_A,
rocsparse_index_base idx_base_C,
T* __restrict__ csr_val_C,
const rocsparse_int* __restrict__ csr_row_ptr_C,
rocsparse_int* __restrict__ csr_col_ind_C,
T tol)
template <uint32_t BLOCKSIZE,
uint32_t SEGMENTS_PER_BLOCK,
uint32_t SEGMENT_SIZE,
uint32_t WF_SIZE,
typename T>
ROCSPARSE_DEVICE_ILF void
csr2csr_compress_device(rocsparse_int m,
rocsparse_int n,
rocsparse_index_base idx_base_A,
const T* __restrict__ csr_val_A,
const rocsparse_int* __restrict__ csr_row_ptr_A,
const rocsparse_int* __restrict__ csr_col_ind_A,
rocsparse_int nnz_A,
rocsparse_index_base idx_base_C,
T* __restrict__ csr_val_C,
const rocsparse_int* __restrict__ csr_row_ptr_C,
rocsparse_int* __restrict__ csr_col_ind_C,
T tol)
{
const rocsparse_int segment_id = hipThreadIdx_x / SEGMENT_SIZE;
const rocsparse_int segment_lane_id = hipThreadIdx_x % SEGMENT_SIZE;

const rocsparse_int id_of_segment_within_warp = segment_id % (WF_SIZE / SEGMENT_SIZE);

const uint64_t filter_mask = (0xffffffffffffffff >> (63 - segment_lane_id));
const uint64_t shifted_filter_mask = filter_mask << (SEGMENT_SIZE * id_of_segment_within_warp);
const uint64_t shifted_filter_mask = filter_mask
<< (SEGMENT_SIZE * id_of_segment_within_warp);

const rocsparse_int row_index = SEGMENTS_PER_BLOCK * hipBlockIdx_x + segment_id;

Expand All @@ -182,11 +198,11 @@ ROCSPARSE_DEVICE_ILF void csr2csr_compress_device(rocsparse_int m,
const T value = csr_val_A[i];

// Check if value in matrix will be kept
const bool predicate
const int predicate
= rocsparse_abs(value) > rocsparse_real(tol)
&& rocsparse_abs(value) > std::numeric_limits<float>::min()
? true
: false;
&& rocsparse_abs(value) > std::numeric_limits<float>::min()
? 1
: 0;

// Ballot operates on an entire warp (32 or 64 threads). Therefore the computed
// wavefront_mask may contain information for multiple rows if the segment size is
Expand All @@ -209,8 +225,8 @@ ROCSPARSE_DEVICE_ILF void csr2csr_compress_device(rocsparse_int m,
// Broadcast the update of the start_C to all threads in the seegment. Choose the last
// segment lane since that it contains the number of entries in the compressed sparse
// row (even if its predicate is false).
start_C
+= __shfl(static_cast<int>(count_previous_nnzs), SEGMENT_SIZE - 1, SEGMENT_SIZE);
start_C += __shfl(
static_cast<int>(count_previous_nnzs), SEGMENT_SIZE - 1, SEGMENT_SIZE);
}
}
}
}
22 changes: 12 additions & 10 deletions library/src/conversion/rocsparse_csr2csr_compress.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,28 +86,30 @@ rocsparse_status rocsparse_csr2csr_compress_template(rocsparse_handle h

// Stream
hipStream_t stream = handle->stream;

// Compute required temporary storage buffer size
size_t nwarps = (nnz_A - 1) / handle->wavefront_size + 1;
size_t temp_storage_size_bytes1 = sizeof(int) * (nwarps / 256 + 1) * 256;
size_t temp_storage_size_bytes1 = sizeof(uint32_t) * (nwarps / 256 + 1) * 256;

auto op = rocprim::plus<rocsparse_int>();
// Compute buffer size for inclusive scan on csr_row_ptr_C
size_t temp_storage_size_bytes2;
RETURN_IF_HIP_ERROR(rocprim::inclusive_scan(nullptr,
temp_storage_size_bytes2,
(rocsparse_int*)nullptr,
(rocsparse_int*)nullptr,
m + 1,
op,
rocprim::plus<rocsparse_int>(),
stream));
temp_storage_size_bytes2 = ((temp_storage_size_bytes2 - 1) / 256 + 1) * 256;

// Compute buffer size for inclusive scan on warp_start
size_t temp_storage_size_bytes3;
RETURN_IF_HIP_ERROR(rocprim::inclusive_scan(nullptr,
temp_storage_size_bytes3,
(rocsparse_int*)nullptr,
(rocsparse_int*)nullptr,
(uint32_t*)nullptr,
(uint32_t*)nullptr,
nwarps + 1,
op,
rocprim::plus<uint32_t>(),
stream));
temp_storage_size_bytes3 = ((temp_storage_size_bytes3 - 1) / 256 + 1) * 256;

Expand All @@ -128,8 +130,8 @@ rocsparse_status rocsparse_csr2csr_compress_template(rocsparse_handle h
temp_alloc = true;
}

char* ptr = reinterpret_cast<char*>(temp_storage_ptr);
int* warp_start = reinterpret_cast<int*>(ptr);
char* ptr = reinterpret_cast<char*>(temp_storage_ptr);
uint32_t* warp_start = reinterpret_cast<uint32_t*>(ptr);
ptr += temp_storage_size_bytes1;
void* temp_storage_buffer2 = ptr;
ptr += temp_storage_size_bytes2;
Expand All @@ -153,7 +155,7 @@ rocsparse_status rocsparse_csr2csr_compress_template(rocsparse_handle h
csr_row_ptr_C,
csr_row_ptr_C,
m + 1,
op,
rocprim::plus<rocsparse_int>(),
stream));

if(csr_val_C == nullptr || csr_col_ind_C == nullptr)
Expand Down Expand Up @@ -214,7 +216,7 @@ rocsparse_status rocsparse_csr2csr_compress_template(rocsparse_handle h
warp_start,
warp_start,
nwarps + 1,
op,
rocprim::plus<uint32_t>(),
stream));

if(handle->wavefront_size == 32)
Expand Down

0 comments on commit edb2770

Please sign in to comment.