Skip to content

Commit

Permalink
Tacho - use restrict keyword
Browse files Browse the repository at this point in the history
  • Loading branch information
kyungjoo-kim committed Jul 23, 2020
1 parent 91fddef commit 1531f38
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 40 deletions.
76 changes: 38 additions & 38 deletions packages/shylu/shylu_node/tacho/src/impl/Tacho_Blas_Team.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace Tacho {
void set(MemberType &member,
int m,
const T alpha,
/* */ T *a, int as0) {
/* */ T *__restrict__ a, int as0) {
Kokkos::parallel_for(Kokkos::TeamVectorRange(member,m),[&](const int &i) {
a[i*as0] = alpha;
});
Expand All @@ -30,7 +30,7 @@ namespace Tacho {
void scale(MemberType &member,
int m,
const T alpha,
/* */ T *a, int as0) {
/* */ T *__restrict__ a, int as0) {
Kokkos::parallel_for(Kokkos::TeamVectorRange(member,m),[&](const int &i) {
a[i*as0] *= alpha;
});
Expand All @@ -42,7 +42,7 @@ namespace Tacho {
void set(MemberType &member,
int m, int n,
const T alpha,
/* */ T *a, int as0, int as1) {
/* */ T *__restrict__ a, int as0, int as1) {
if (as0 == 1 || as0 < as1)
Kokkos::parallel_for(Kokkos::TeamThreadRange(member,n),[&](const int &j) {
Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,m),[&](const int &i) {
Expand All @@ -63,7 +63,7 @@ namespace Tacho {
void scale(MemberType &member,
int m, int n,
const T alpha,
/* */ T *a, int as0, int as1) {
/* */ T *__restrict__ a, int as0, int as1) {
if (as0 == 1 || as0 < as1)
Kokkos::parallel_for(Kokkos::TeamThreadRange(member,n),[&](const int &j) {
Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,m),[&](const int &i) {
Expand All @@ -84,7 +84,7 @@ namespace Tacho {
void set_upper(MemberType &member,
int m, int n, int offset,
const T alpha,
/* */ T *a, int as0, int as1) {
/* */ T *__restrict__ a, int as0, int as1) {
Kokkos::parallel_for(Kokkos::TeamThreadRange(member,n),[&](const int &j) {
Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,j+1-offset),[&](const int &i) {
a[i*as0+j*as1] = alpha;
Expand All @@ -98,7 +98,7 @@ namespace Tacho {
void scale_upper(MemberType &member,
int m, int n, int offset,
const T alpha,
/* */ T *a, int as0, int as1) {
/* */ T *__restrict__ a, int as0, int as1) {
Kokkos::parallel_for(Kokkos::TeamThreadRange(member,n),[&](const int &j) {
Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,j+1-offset),[&](const int &i) {
a[i*as0+j*as1] *= alpha;
Expand All @@ -113,7 +113,7 @@ namespace Tacho {
void set_lower(MemberType &member,
int m, int n, int offset,
const T alpha,
/* */ T *a, int as0, int as1) {
/* */ T *__restrict__ a, int as0, int as1) {
Kokkos::parallel_for(Kokkos::TeamThreadRange(member,n),[&](const int &j) {
const int jj = j + offset;
Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,n-j-offset),[&](const int &i) {
Expand All @@ -128,7 +128,7 @@ namespace Tacho {
void scale_lower(MemberType &member,
int m, int n, int offset,
const T alpha,
/* */ T *a, int as0, int as1) {
/* */ T *__restrict__ a, int as0, int as1) {
Kokkos::parallel_for(Kokkos::TeamThreadRange(member,n),[&](const int &j) {
const int jj = j + offset;
Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,n-j-offset),[&](const int &i) {
Expand All @@ -143,10 +143,10 @@ namespace Tacho {
void gemv(MemberType &member, const ConjType &cj,
const int m, const int n,
const T alpha,
const T *A, const int as0, const int as1,
const T *x, const int xs0,
const T *__restrict__ A, const int as0, const int as1,
const T *__restrict__ x, const int xs0,
const T beta,
/* */ T *y, const int ys0) {
/* */ T *__restrict__ y, const int ys0) {
const T one(1), zero(0);

if (beta == zero) set (member, m, zero, y, ys0);
Expand Down Expand Up @@ -184,8 +184,8 @@ namespace Tacho {
void trsv_upper(MemberType &member, const ConjType &cjA,
const char diag,
const int m,
const T *A, const int as0, const int as1,
/* */ T *b, const int bs0) {
const T *__restrict__ A, const int as0, const int as1,
/* */ T *__restrict__ b, const int bs0) {
if (m <= 0) return;

const bool use_unit_diag = diag == 'U'|| diag == 'u';
Expand Down Expand Up @@ -222,8 +222,8 @@ namespace Tacho {
void trsv_lower(MemberType &member, const ConjType &cjA,
const char diag,
const int m,
const T *A, const int as0, const int as1,
/* */ T *b, const int bs0) {
const T *__restrict__ A, const int as0, const int as1,
/* */ T *__restrict__ b, const int bs0) {
if (m <= 0) return;

const bool use_unit_diag = diag == 'U'|| diag == 'u';
Expand Down Expand Up @@ -264,10 +264,10 @@ namespace Tacho {
void gemm(MemberType &member, const ConjTypeA &cjA, const ConjTypeB &cjB,
const int m, const int n, const int k,
const T alpha,
const T *A, const int as0, const int as1,
const T *B, const int bs0, const int bs1,
const T *__restrict__ A, const int as0, const int as1,
const T *__restrict__ B, const int bs0, const int bs1,
const T beta,
/* */ T *C, const int cs0, const int cs1) {
/* */ T *__restrict__ C, const int cs0, const int cs1) {
const T one(1), zero(0);

if (beta == zero) set (member, m, n, zero, C, cs0, cs1);
Expand Down Expand Up @@ -300,9 +300,9 @@ namespace Tacho {
void herk_upper(MemberType &member, const ConjTypeA &cjA, const ConjTypeB &cjB,
const int n, const int k,
const T alpha,
const T *A, const int as0, const int as1,
const T *__restrict__ A, const int as0, const int as1,
const T beta,
/* */ T *C, const int cs0, const int cs1) {
/* */ T *__restrict__ C, const int cs0, const int cs1) {
const T one(1), zero(0);

if (beta == zero) set_upper (member, n, n, 0, zero, C, cs0, cs1);
Expand Down Expand Up @@ -334,9 +334,9 @@ namespace Tacho {
void herk_lower(MemberType &member, const ConjTypeA &cjA, const ConjTypeB &cjB,
const int n, const int k,
const T alpha,
const T *A, const int as0, const int as1,
const T *__restrict__ A, const int as0, const int as1,
const T beta,
/* */ T *C, const int cs0, const int cs1) {
/* */ T *__restrict__ C, const int cs0, const int cs1) {
const T one(1), zero(0);

if (beta == zero) set_lower (member, n, n, 0, zero, C, cs0, cs1);
Expand Down Expand Up @@ -371,8 +371,8 @@ namespace Tacho {
const char diag,
const int m, const int n,
const T alpha,
const T *A, const int as0, const int as1,
/* */ T *B, const int bs0, const int bs1) {
const T *__restrict__ A, const int as0, const int as1,
/* */ T *__restrict__ B, const int bs0, const int bs1) {
const T one(1), zero(0);

if (alpha == zero) set (member, m, n, zero, B, bs0, bs1);
Expand Down Expand Up @@ -417,8 +417,8 @@ namespace Tacho {
const char diag,
const int m, const int n,
const T alpha,
const T *A, const int as0, const int as1,
/* */ T *B, const int bs0, const int bs1) {
const T *__restrict__ A, const int as0, const int as1,
/* */ T *__restrict__ B, const int bs0, const int bs1) {
const T one(1.0), zero(0.0);

// note that parallel range is different ( m*n vs m-1*n);
Expand Down Expand Up @@ -462,10 +462,10 @@ namespace Tacho {
const char trans,
const int m, const int n,
const T alpha,
const T *a, const int lda,
const T *x, const int xs,
const T *__restrict__ a, const int lda,
const T *__restrict__ x, const int xs,
const T beta,
/* */ T *y, const int ys) {
/* */ T *__restrict__ y, const int ys) {
switch (trans) {
case 'N':
case 'n': {
Expand Down Expand Up @@ -515,8 +515,8 @@ namespace Tacho {
void trsv(MemberType &member,
const char uplo, const char trans, const char diag,
const int m,
const T *a, const int lda,
/* */ T *b, const int bs) {
const T *__restrict__ a, const int lda,
/* */ T *__restrict__ b, const int bs) {
if (uplo == 'U' || uplo == 'u') {
switch (trans) {
case 'N':
Expand Down Expand Up @@ -591,10 +591,10 @@ namespace Tacho {
const char transa, const char transb,
const int m, const int n, const int k,
const T alpha,
const T *a, int lda,
const T *b, int ldb,
const T *__restrict__ a, int lda,
const T *__restrict__ b, int ldb,
const T beta,
/* */ T *c, int ldc) {
/* */ T *__restrict__ c, int ldc) {

if (transa == 'N' || transa == 'n') {
const NoConjugate cjA;
Expand Down Expand Up @@ -735,9 +735,9 @@ namespace Tacho {
const char uplo, const char trans,
const int n, const int k,
const T alpha,
const T *a, const int lda,
const T *__restrict__ a, const int lda,
const T beta,
/* */ T *c, const int ldc) {
/* */ T *__restrict__ c, const int ldc) {
if (uplo == 'U' || uplo == 'u')
switch (trans) {
case 'N':
Expand Down Expand Up @@ -808,8 +808,8 @@ namespace Tacho {
const char side, const char uplo, const char trans, const char diag,
const int m, const int n,
const T alpha,
const T *a, const int lda,
/* */ T *b, const int ldb) {
const T *__restrict__ a, const int lda,
/* */ T *__restrict__ b, const int ldb) {
///
/// side left
///
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace Tacho {
KOKKOS_INLINE_FUNCTION
void potrf_upper(MemberType &member,
const int m,
T *A, const int as0, const int as1,
T *__restrict__ A, const int as0, const int as1,
int *info) {
if (m <= 0) return;

Expand Down Expand Up @@ -106,7 +106,7 @@ namespace Tacho {
void potrf(MemberType &member,
const char uplo,
const int m,
/* */ T *A, const int lda,
/* */ T *__restrict__ A, const int lda,
int *info) {
switch (uplo) {
case 'U':
Expand Down

0 comments on commit 1531f38

Please sign in to comment.