diff --git a/perf_test/blas/blas3/KokkosBlas3_gemm_perf_test.hpp b/perf_test/blas/blas3/KokkosBlas3_gemm_perf_test.hpp index bd36cdd955..37c7505374 100644 --- a/perf_test/blas/blas3/KokkosBlas3_gemm_perf_test.hpp +++ b/perf_test/blas/blas3/KokkosBlas3_gemm_perf_test.hpp @@ -416,7 +416,7 @@ void __do_gemm_serial_batched_template(options_t options, C = Kokkos::subview(_gemm_args.C, Kokkos::ALL(), Kokkos::ALL(), j); } - SerialGemm::invoke( + KokkosBatched::SerialGemm::invoke( _gemm_args.alpha, A, B, _gemm_args.beta, C); } } @@ -445,9 +445,9 @@ template void __do_gemm_parallel_batched_heuristic_template(options_t options, gemm_args_t gemm_args) { - BatchedGemmHandle batchedGemmHandle(BaseHeuristicAlgos::SQUARE); + KokkosBatched::BatchedGemmHandle batchedGemmHandle( + KokkosBatched::BaseHeuristicAlgos::SQUARE); char a = toupper(gemm_args.transA); char b = toupper(gemm_args.transB); - using N = Trans::NoTranspose; - using T = Trans::Transpose; - // using C = Trans::ConjTranspose; + using N = KokkosBatched::Trans::NoTranspose; + using T = KokkosBatched::Trans::Transpose; + // using C = KokkosBatched::Trans::ConjTranspose; + using KokkosBatched::BatchLayout; STATUS; if (a == 'N' && b == 'N') { @@ -950,9 +952,9 @@ template ::value ? 'N' : 'T'; - char transb = std::is_same::value ? 'N' : 'T'; + char transa = + std::is_same::value ? 'N' + : 'T'; + char transb = + std::is_same::value ? 'N' + : 'T'; if (!std::is_same::value) FATAL_ERROR("only double scalars are supported!"); @@ -2229,18 +2235,20 @@ void do_gemm_serial_blas(options_t options) { void do_gemm_serial_batched(options_t options) { STATUS; __do_loop_and_invoke( - options, __do_gemm_serial_batched); + options, + __do_gemm_serial_batched); return; } void do_gemm_serial_batched_blocked(options_t options) { STATUS; __do_loop_and_invoke( - options, __do_gemm_serial_batched); + options, + __do_gemm_serial_batched); return; } @@ -2263,12 +2271,14 @@ void do_gemm_serial_batched_parallel(options_t options) { if (options.blas_args.batch_size_last_dim) __do_loop_and_invoke( options, - __do_gemm_parallel_batched); else __do_loop_and_invoke( - options, __do_gemm_parallel_batched); + options, + __do_gemm_parallel_batched< + SerialTag, KokkosBatched::Algo::Gemm::Unblocked, default_device>); return; } @@ -2276,13 +2286,14 @@ void do_gemm_serial_batched_blocked_parallel(options_t options) { STATUS; if (options.blas_args.batch_size_last_dim) __do_loop_and_invoke( - options, - __do_gemm_parallel_batched); + options, __do_gemm_parallel_batched); else __do_loop_and_invoke( - options, __do_gemm_parallel_batched); + options, + __do_gemm_parallel_batched< + SerialTag, KokkosBatched::Algo::Gemm::Blocked, default_device>); return; } @@ -2293,13 +2304,14 @@ void do_gemm_serial_simd_batched_parallel(options_t options) { options.use_simd = true; if (options.blas_args.batch_size_last_dim) __do_loop_and_invoke( - options, - __do_gemm_parallel_batched); + options, __do_gemm_parallel_batched< + TeamSimdBatchDim4Tag, KokkosBatched::Algo::Gemm::Unblocked, + default_device, KokkosBatched::Mode::Serial>); else - __do_loop_and_invoke( - options, __do_gemm_parallel_batched); + __do_loop_and_invoke(options, + __do_gemm_parallel_batched< + TeamSimdTag, KokkosBatched::Algo::Gemm::Unblocked, + default_device, KokkosBatched::Mode::Serial>); return; } @@ -2310,13 +2322,14 @@ void do_gemm_serial_simd_batched_blocked_parallel(options_t options) { options.use_simd = true; if (options.blas_args.batch_size_last_dim) __do_loop_and_invoke( - options, - __do_gemm_parallel_batched); + options, __do_gemm_parallel_batched< + TeamSimdBatchDim4Tag, KokkosBatched::Algo::Gemm::Blocked, + default_device, KokkosBatched::Mode::Serial>); else - __do_loop_and_invoke( - options, __do_gemm_parallel_batched); + __do_loop_and_invoke(options, + __do_gemm_parallel_batched< + TeamSimdTag, KokkosBatched::Algo::Gemm::Blocked, + default_device, KokkosBatched::Mode::Serial>); return; } @@ -2329,11 +2342,13 @@ void do_gemm_serial_batched_compact_mkl_parallel(options_t options) { __do_loop_and_invoke( options, __do_gemm_parallel_batched); + KokkosBatched::Algo::Gemm::CompactMKL, + default_device>); else __do_loop_and_invoke( options, - __do_gemm_parallel_batched); return; } @@ -2367,12 +2382,14 @@ void do_gemm_team_batched_parallel(options_t options) { if (options.blas_args.batch_size_last_dim) __do_loop_and_invoke( options, - __do_gemm_parallel_batched); else __do_loop_and_invoke( - options, __do_gemm_parallel_batched); + options, + __do_gemm_parallel_batched< + TeamTag, KokkosBatched::Algo::Gemm::Unblocked, default_device>); return; } @@ -2380,13 +2397,14 @@ void do_gemm_team_batched_blocked_parallel(options_t options) { STATUS; if (options.blas_args.batch_size_last_dim) __do_loop_and_invoke( - options, - __do_gemm_parallel_batched); + options, __do_gemm_parallel_batched); else __do_loop_and_invoke( - options, __do_gemm_parallel_batched); + options, + __do_gemm_parallel_batched); return; } @@ -2396,11 +2414,13 @@ void do_gemm_team_vector_batched_parallel(options_t options) { __do_loop_and_invoke( options, __do_gemm_parallel_batched); + KokkosBatched::Algo::Gemm::Unblocked, + default_device>); else __do_loop_and_invoke( options, - __do_gemm_parallel_batched); return; } @@ -2411,12 +2431,15 @@ void do_gemm_team_simd_batched_parallel(options_t options) { if (options.blas_args.batch_size_last_dim) __do_loop_and_invoke( options, - __do_gemm_parallel_batched); + __do_gemm_parallel_batched); else __do_loop_and_invoke( - options, __do_gemm_parallel_batched); + options, + __do_gemm_parallel_batched); return; } @@ -2426,12 +2449,15 @@ void do_gemm_team_simd_batched_blocked_parallel(options_t options) { if (options.blas_args.batch_size_last_dim) __do_loop_and_invoke( options, - __do_gemm_parallel_batched); + __do_gemm_parallel_batched); else __do_loop_and_invoke( - options, __do_gemm_parallel_batched); + options, + __do_gemm_parallel_batched); return; } @@ -2439,15 +2465,15 @@ void do_gemm_team_simd_batched_blocked_parallel(options_t options) { /* void do_gemm_team_vector_batched_blocked_parallel(options_t options) { STATUS; __do_loop_and_invoke( - options, __do_gemm_parallel_batched); return; + options, __do_gemm_parallel_batched); return; } */ void do_gemm_experiment_parallel(options_t options) { STATUS; - using TransAType = Trans::NoTranspose; - using TransBType = Trans::NoTranspose; - using BlockingType = Algo::Gemm::Unblocked; + using TransAType = KokkosBatched::Trans::NoTranspose; + using TransBType = KokkosBatched::Trans::NoTranspose; + using BlockingType = KokkosBatched::Algo::Gemm::Unblocked; // __do_loop_and_invoke( // options, __do_gemm_parallel_experiment1::invoke(trmm_args.alpha, A, B); + KokkosBatched::SerialTrmm::invoke( + trmm_args.alpha, A, B); } // Fence after submitting each batch operation Kokkos::fence(); @@ -291,7 +292,8 @@ void __do_trmm_serial_batched_template(options_t options, auto A = Kokkos::subview(trmm_args.A, i, Kokkos::ALL(), Kokkos::ALL()); auto B = Kokkos::subview(trmm_args.B, i, Kokkos::ALL(), Kokkos::ALL()); - SerialTrmm::invoke(trmm_args.alpha, A, B); + KokkosBatched::SerialTrmm::invoke( + trmm_args.alpha, A, B); } // Fence after submitting each batch operation Kokkos::fence(); @@ -315,6 +317,11 @@ void __do_trmm_serial_batched(options_t options, trmm_args_t trmm_args) { __trans = tolower(trmm_args.trans); //__diag = tolower(diag[0]); + using KokkosBatched::Diag; + using KokkosBatched::Side; + using KokkosBatched::Trans; + using KokkosBatched::Uplo; + STATUS; //// Lower non-transpose //// @@ -480,8 +487,8 @@ struct parallel_batched_trmm { auto svA = Kokkos::subview(trmm_args_.A, i, Kokkos::ALL(), Kokkos::ALL()); auto svB = Kokkos::subview(trmm_args_.B, i, Kokkos::ALL(), Kokkos::ALL()); - SerialTrmm::invoke(trmm_args_.alpha, svA, - svB); + KokkosBatched::SerialTrmm::invoke( + trmm_args_.alpha, svA, svB); } }; @@ -491,7 +498,7 @@ void __do_trmm_parallel_batched_template(options_t options, uint32_t warm_up_n = options.warm_up_n; uint32_t n = options.n; Kokkos::Timer timer; - using tag = Algo::Trmm::Unblocked; + using tag = KokkosBatched::Algo::Trmm::Unblocked; using execution_space = typename device_type::execution_space; using functor_type = parallel_batched_trmm; @@ -528,6 +535,11 @@ void __do_trmm_parallel_batched(options_t options, trmm_args_t trmm_args) { __trans = tolower(trmm_args.trans); //__diag = tolower(diag[0]); + using KokkosBatched::Diag; + using KokkosBatched::Side; + using KokkosBatched::Trans; + using KokkosBatched::Uplo; + STATUS; //// Lower non-transpose //// diff --git a/perf_test/blas/blas3/KokkosBlas_trtri_perf_test.hpp b/perf_test/blas/blas3/KokkosBlas_trtri_perf_test.hpp index 06b6d51455..f7bd1ef4f6 100644 --- a/perf_test/blas/blas3/KokkosBlas_trtri_perf_test.hpp +++ b/perf_test/blas/blas3/KokkosBlas_trtri_perf_test.hpp @@ -249,13 +249,13 @@ void __do_trtri_serial_batched_template(options_t options, uint32_t warm_up_n = options.warm_up_n; uint32_t n = options.n; Kokkos::Timer timer; - using tag = Algo::Trtri::Unblocked; + using tag = KokkosBatched::Algo::Trtri::Unblocked; for (uint32_t j = 0; j < warm_up_n; ++j) { for (int i = 0; i < options.start.a.k; ++i) { auto A = Kokkos::subview(trtri_args.A, i, Kokkos::ALL(), Kokkos::ALL()); - SerialTrtri::invoke(A); + KokkosBatched::SerialTrtri::invoke(A); } // Fence after each batch operation Kokkos::fence(); @@ -266,7 +266,7 @@ void __do_trtri_serial_batched_template(options_t options, for (int i = 0; i < options.start.a.k; ++i) { auto A = Kokkos::subview(trtri_args.A, i, Kokkos::ALL(), Kokkos::ALL()); - SerialTrtri::invoke(A); + KokkosBatched::SerialTrtri::invoke(A); } // Fence after each batch operation Kokkos::fence(); @@ -284,6 +284,9 @@ void __do_trtri_serial_batched_template(options_t /*options*/, template void __do_trtri_serial_batched(options_t options, trtri_args_t trtri_args) { + using KokkosBatched::Diag; + using KokkosBatched::Uplo; + char __uplo = tolower(trtri_args.uplo), __diag = tolower(trtri_args.diag); STATUS; @@ -382,7 +385,7 @@ struct parallel_batched_trtri { void operator()(const int& i) const { auto svA = Kokkos::subview(trtri_args_.A, i, Kokkos::ALL(), Kokkos::ALL()); - SerialTrtri::invoke(svA); + KokkosBatched::SerialTrtri::invoke(svA); } }; @@ -392,7 +395,7 @@ void __do_trtri_parallel_batched_template(options_t options, uint32_t warm_up_n = options.warm_up_n; uint32_t n = options.n; Kokkos::Timer timer; - using tag = Algo::Trtri::Unblocked; + using tag = KokkosBatched::Algo::Trtri::Unblocked; using execution_space = typename device_type::execution_space; using functor_type = parallel_batched_trtri; functor_type parallel_batched_trtri_functor(trtri_args); @@ -425,6 +428,9 @@ void __do_trtri_parallel_batched_template(options_t options, template void __do_trtri_parallel_batched(options_t options, trtri_args_t trtri_args) { + using KokkosBatched::Diag; + using KokkosBatched::Uplo; + char __uplo = tolower(trtri_args.uplo), __diag = tolower(trtri_args.diag); STATUS; diff --git a/src/blas/impl/KokkosBlas3_trmm_impl.hpp b/src/blas/impl/KokkosBlas3_trmm_impl.hpp index ce47023061..56bc2ba806 100644 --- a/src/blas/impl/KokkosBlas3_trmm_impl.hpp +++ b/src/blas/impl/KokkosBlas3_trmm_impl.hpp @@ -58,8 +58,6 @@ #include "KokkosBatched_Trmm_Decl.hpp" #include "KokkosBatched_Trmm_Serial_Impl.hpp" -using namespace KokkosBatched; - namespace KokkosBlas { namespace Impl { @@ -68,6 +66,13 @@ void SerialTrmm_Invoke(const char side[], const char uplo[], const char trans[], const char /*diag*/[], typename BViewType::const_value_type& alpha, const AViewType& A, const BViewType& B) { + using KokkosBatched::Algo; + using KokkosBatched::Diag; + using KokkosBatched::SerialTrmmInternalLeftLower; + using KokkosBatched::SerialTrmmInternalLeftUpper; + using KokkosBatched::SerialTrmmInternalRightLower; + using KokkosBatched::SerialTrmmInternalRightUpper; + char __side = tolower(side[0]), __uplo = tolower(uplo[0]), __trans = tolower(trans[0]); //__diag = tolower(diag[0]); diff --git a/src/blas/impl/KokkosBlas3_trsm_impl.hpp b/src/blas/impl/KokkosBlas3_trsm_impl.hpp index 64d35343b3..b215633093 100644 --- a/src/blas/impl/KokkosBlas3_trsm_impl.hpp +++ b/src/blas/impl/KokkosBlas3_trsm_impl.hpp @@ -57,8 +57,6 @@ #include "KokkosBatched_Trsm_Decl.hpp" #include "KokkosBatched_Trsm_Serial_Impl.hpp" -using namespace KokkosBatched; - namespace KokkosBlas { namespace Impl { @@ -74,9 +72,10 @@ int SerialTrsmInternalLeftLowerConj(const bool use_unit_diag, const int m, const ScalarType one(1.0), zero(0.0); if (alpha == zero) - SerialSetInternal ::invoke(m, n, zero, B, bs0, bs1); + KokkosBatched::SerialSetInternal ::invoke(m, n, zero, B, bs0, bs1); else { - if (alpha != one) SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1); + if (alpha != one) + KokkosBatched::SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1); if (m <= 0 || n <= 0) return 0; for (int p = 0; p < m; ++p) { @@ -112,9 +111,10 @@ int SerialTrsmInternalLeftUpperConj(const bool use_unit_diag, const int m, const ScalarType one(1.0), zero(0.0); if (alpha == zero) - SerialSetInternal ::invoke(m, n, zero, B, bs0, bs1); + KokkosBatched::SerialSetInternal ::invoke(m, n, zero, B, bs0, bs1); else { - if (alpha != one) SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1); + if (alpha != one) + KokkosBatched::SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1); if (m <= 0 || n <= 0) return 0; ValueType* KOKKOS_RESTRICT B0 = B; @@ -145,19 +145,22 @@ void SerialTrsm_Invoke(const char side[], const char uplo[], const char trans[], const char diag[], typename BViewType::const_value_type& alpha, const AViewType& A, const BViewType& B) { + using KokkosBatched::Algo; + using KokkosBatched::Diag; + // Side::Left, Uplo::Lower, Trans::NoTranspose if (((side[0] == 'L') || (side[0] == 'l')) && ((uplo[0] == 'L') || (uplo[0] == 'l')) && ((trans[0] == 'N') || (trans[0] == 'n')) && ((diag[0] == 'U') || (diag[0] == 'u'))) - SerialTrsmInternalLeftLower::invoke( + KokkosBatched::SerialTrsmInternalLeftLower::invoke( Diag::Unit::use_unit_diag, B.extent(0), B.extent(1), alpha, A.data(), A.stride(0), A.stride(1), B.data(), B.stride(0), B.stride(1)); if (((side[0] == 'L') || (side[0] == 'l')) && ((uplo[0] == 'L') || (uplo[0] == 'l')) && ((trans[0] == 'N') || (trans[0] == 'n')) && ((diag[0] == 'N') || (diag[0] == 'n'))) - SerialTrsmInternalLeftLower::invoke( + KokkosBatched::SerialTrsmInternalLeftLower::invoke( Diag::NonUnit::use_unit_diag, B.extent(0), B.extent(1), alpha, A.data(), A.stride(0), A.stride(1), B.data(), B.stride(0), B.stride(1)); @@ -166,14 +169,14 @@ void SerialTrsm_Invoke(const char side[], const char uplo[], const char trans[], ((uplo[0] == 'L') || (uplo[0] == 'l')) && ((trans[0] == 'T') || (trans[0] == 't')) && ((diag[0] == 'U') || (diag[0] == 'u'))) - SerialTrsmInternalLeftUpper::invoke( + KokkosBatched::SerialTrsmInternalLeftUpper::invoke( Diag::Unit::use_unit_diag, B.extent(0), B.extent(1), alpha, A.data(), A.stride(1), A.stride(0), B.data(), B.stride(0), B.stride(1)); if (((side[0] == 'L') || (side[0] == 'l')) && ((uplo[0] == 'L') || (uplo[0] == 'l')) && ((trans[0] == 'T') || (trans[0] == 't')) && ((diag[0] == 'N') || (diag[0] == 'n'))) - SerialTrsmInternalLeftUpper::invoke( + KokkosBatched::SerialTrsmInternalLeftUpper::invoke( Diag::NonUnit::use_unit_diag, B.extent(0), B.extent(1), alpha, A.data(), A.stride(1), A.stride(0), B.data(), B.stride(0), B.stride(1)); @@ -198,14 +201,14 @@ void SerialTrsm_Invoke(const char side[], const char uplo[], const char trans[], ((uplo[0] == 'U') || (uplo[0] == 'u')) && ((trans[0] == 'N') || (trans[0] == 'n')) && ((diag[0] == 'U') || (diag[0] == 'u'))) - SerialTrsmInternalLeftUpper::invoke( + KokkosBatched::SerialTrsmInternalLeftUpper::invoke( Diag::Unit::use_unit_diag, B.extent(0), B.extent(1), alpha, A.data(), A.stride(0), A.stride(1), B.data(), B.stride(0), B.stride(1)); if (((side[0] == 'L') || (side[0] == 'l')) && ((uplo[0] == 'U') || (uplo[0] == 'u')) && ((trans[0] == 'N') || (trans[0] == 'n')) && ((diag[0] == 'N') || (diag[0] == 'n'))) - SerialTrsmInternalLeftUpper::invoke( + KokkosBatched::SerialTrsmInternalLeftUpper::invoke( Diag::NonUnit::use_unit_diag, B.extent(0), B.extent(1), alpha, A.data(), A.stride(0), A.stride(1), B.data(), B.stride(0), B.stride(1)); @@ -214,14 +217,14 @@ void SerialTrsm_Invoke(const char side[], const char uplo[], const char trans[], ((uplo[0] == 'U') || (uplo[0] == 'u')) && ((trans[0] == 'T') || (trans[0] == 't')) && ((diag[0] == 'U') || (diag[0] == 'u'))) - SerialTrsmInternalLeftLower::invoke( + KokkosBatched::SerialTrsmInternalLeftLower::invoke( Diag::Unit::use_unit_diag, B.extent(0), B.extent(1), alpha, A.data(), A.stride(1), A.stride(0), B.data(), B.stride(0), B.stride(1)); if (((side[0] == 'L') || (side[0] == 'l')) && ((uplo[0] == 'U') || (uplo[0] == 'u')) && ((trans[0] == 'T') || (trans[0] == 't')) && ((diag[0] == 'N') || (diag[0] == 'n'))) - SerialTrsmInternalLeftLower::invoke( + KokkosBatched::SerialTrsmInternalLeftLower::invoke( Diag::NonUnit::use_unit_diag, B.extent(0), B.extent(1), alpha, A.data(), A.stride(1), A.stride(0), B.data(), B.stride(0), B.stride(1)); @@ -246,14 +249,14 @@ void SerialTrsm_Invoke(const char side[], const char uplo[], const char trans[], ((uplo[0] == 'L') || (uplo[0] == 'l')) && ((trans[0] == 'N') || (trans[0] == 'n')) && ((diag[0] == 'U') || (diag[0] == 'u'))) - SerialTrsmInternalLeftUpper::invoke( + KokkosBatched::SerialTrsmInternalLeftUpper::invoke( Diag::Unit::use_unit_diag, B.extent(1), B.extent(0), alpha, A.data(), A.stride(1), A.stride(0), B.data(), B.stride(1), B.stride(0)); if (((side[0] == 'R') || (side[0] == 'r')) && ((uplo[0] == 'L') || (uplo[0] == 'l')) && ((trans[0] == 'N') || (trans[0] == 'n')) && ((diag[0] == 'N') || (diag[0] == 'n'))) - SerialTrsmInternalLeftUpper::invoke( + KokkosBatched::SerialTrsmInternalLeftUpper::invoke( Diag::NonUnit::use_unit_diag, B.extent(1), B.extent(0), alpha, A.data(), A.stride(1), A.stride(0), B.data(), B.stride(1), B.stride(0)); @@ -262,14 +265,14 @@ void SerialTrsm_Invoke(const char side[], const char uplo[], const char trans[], ((uplo[0] == 'L') || (uplo[0] == 'l')) && ((trans[0] == 'T') || (trans[0] == 't')) && ((diag[0] == 'U') || (diag[0] == 'u'))) - SerialTrsmInternalLeftLower::invoke( + KokkosBatched::SerialTrsmInternalLeftLower::invoke( Diag::Unit::use_unit_diag, B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), A.stride(1), B.data(), B.stride(1), B.stride(0)); if (((side[0] == 'R') || (side[0] == 'r')) && ((uplo[0] == 'L') || (uplo[0] == 'l')) && ((trans[0] == 'T') || (trans[0] == 't')) && ((diag[0] == 'N') || (diag[0] == 'n'))) - SerialTrsmInternalLeftLower::invoke( + KokkosBatched::SerialTrsmInternalLeftLower::invoke( Diag::NonUnit::use_unit_diag, B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), A.stride(1), B.data(), B.stride(1), B.stride(0)); @@ -294,14 +297,14 @@ void SerialTrsm_Invoke(const char side[], const char uplo[], const char trans[], ((uplo[0] == 'U') || (uplo[0] == 'u')) && ((trans[0] == 'N') || (trans[0] == 'n')) && ((diag[0] == 'U') || (diag[0] == 'u'))) - SerialTrsmInternalLeftLower::invoke( + KokkosBatched::SerialTrsmInternalLeftLower::invoke( Diag::Unit::use_unit_diag, B.extent(1), B.extent(0), alpha, A.data(), A.stride(1), A.stride(0), B.data(), B.stride(1), B.stride(0)); if (((side[0] == 'R') || (side[0] == 'r')) && ((uplo[0] == 'U') || (uplo[0] == 'u')) && ((trans[0] == 'N') || (trans[0] == 'n')) && ((diag[0] == 'N') || (diag[0] == 'n'))) - SerialTrsmInternalLeftLower::invoke( + KokkosBatched::SerialTrsmInternalLeftLower::invoke( Diag::NonUnit::use_unit_diag, B.extent(1), B.extent(0), alpha, A.data(), A.stride(1), A.stride(0), B.data(), B.stride(1), B.stride(0)); @@ -310,14 +313,14 @@ void SerialTrsm_Invoke(const char side[], const char uplo[], const char trans[], ((uplo[0] == 'U') || (uplo[0] == 'u')) && ((trans[0] == 'T') || (trans[0] == 't')) && ((diag[0] == 'U') || (diag[0] == 'u'))) - SerialTrsmInternalLeftUpper::invoke( + KokkosBatched::SerialTrsmInternalLeftUpper::invoke( Diag::Unit::use_unit_diag, B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), A.stride(1), B.data(), B.stride(1), B.stride(0)); if (((side[0] == 'R') || (side[0] == 'r')) && ((uplo[0] == 'U') || (uplo[0] == 'u')) && ((trans[0] == 'T') || (trans[0] == 't')) && ((diag[0] == 'N') || (diag[0] == 'n'))) - SerialTrsmInternalLeftUpper::invoke( + KokkosBatched::SerialTrsmInternalLeftUpper::invoke( Diag::NonUnit::use_unit_diag, B.extent(1), B.extent(0), alpha, A.data(), A.stride(0), A.stride(1), B.data(), B.stride(1), B.stride(0)); diff --git a/src/blas/impl/KokkosBlas_trtri_impl.hpp b/src/blas/impl/KokkosBlas_trtri_impl.hpp index 18ef37a253..5aa82e6480 100644 --- a/src/blas/impl/KokkosBlas_trtri_impl.hpp +++ b/src/blas/impl/KokkosBlas_trtri_impl.hpp @@ -55,14 +55,17 @@ #include "KokkosBatched_Trtri_Decl.hpp" #include "KokkosBatched_Trtri_Serial_Impl.hpp" -using namespace KokkosBatched; - namespace KokkosBlas { namespace Impl { template void SerialTrtri_Invoke(const RViewType &R, const char uplo[], const char diag[], const AViewType &A) { + using KokkosBatched::Algo; + using KokkosBatched::Diag; + using KokkosBatched::SerialTrtriInternalLower; + using KokkosBatched::SerialTrtriInternalUpper; + char __uplo = tolower(uplo[0]), __diag = tolower(diag[0]); //// Lower ////