Skip to content

Commit

Permalink
TPL singletons: allow query of whether initialized
Browse files Browse the repository at this point in the history
And test KokkosKernels::eager_initialize() using this
  • Loading branch information
brian-kelley committed Sep 28, 2024
1 parent 6b656c8 commit 6ea703c
Show file tree
Hide file tree
Showing 16 changed files with 256 additions and 56 deletions.
18 changes: 15 additions & 3 deletions blas/tpls/KokkosBlas_Cuda_tpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,24 @@ namespace Impl {
CudaBlasSingleton::CudaBlasSingleton() {
cublasStatus_t stat = cublasCreate(&handle);
if (stat != CUBLAS_STATUS_SUCCESS) Kokkos::abort("CUBLAS initialization failed\n");

Kokkos::push_finalize_hook([&]() { cublasDestroy(handle); });
}

CudaBlasSingleton& CudaBlasSingleton::singleton() {
static CudaBlasSingleton s;
std::unique_ptr<CudaBlasSingleton>& instance = get_instance();
if (!instance) {
instance = std::make_unique<CudaBlasSingleton>();
Kokkos::push_finalize_hook([&]() {
cublasDestroy(instance->handle);
instance.reset();
});
}
return *instance;
}

bool CudaBlasSingleton::is_initialized() { return get_instance() != nullptr; }

std::unique_ptr<CudaBlasSingleton>& CudaBlasSingleton::get_instance() {
static std::unique_ptr<CudaBlasSingleton> s;
return s;
}

Expand Down
18 changes: 15 additions & 3 deletions blas/tpls/KokkosBlas_Magma_tpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,24 @@ namespace Impl {
MagmaSingleton::MagmaSingleton() {
magma_int_t stat = magma_init();
if (stat != MAGMA_SUCCESS) Kokkos::abort("MAGMA initialization failed\n");

Kokkos::push_finalize_hook([&]() { magma_finalize(); });
}

MagmaSingleton& MagmaSingleton::singleton() {
static MagmaSingleton s;
std::unique_ptr<MagmaSingleton>& instance = get_instance();
if (!instance) {
instance = std::make_unique<MagmaSingleton>();
Kokkos::push_finalize_hook([&]() {
magma_finalize();
instance.reset();
});
}
return *instance;
}

bool MagmaSingleton::is_initialized() { return get_instance() != nullptr; }

std::unique_ptr<MagmaSingleton>& MagmaSingleton::get_instance() {
static std::unique_ptr<MagmaSingleton> s;
return s;
}

Expand Down
20 changes: 15 additions & 5 deletions blas/tpls/KokkosBlas_Rocm_tpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,24 @@
namespace KokkosBlas {
namespace Impl {

RocBlasSingleton::RocBlasSingleton() {
KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocblas_create_handle(&handle));
RocBlasSingleton::RocBlasSingleton() { KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocblas_create_handle(&handle)); }

Kokkos::push_finalize_hook([&]() { KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocblas_destroy_handle(handle)); });
RocBlasSingleton& RocBlasSingleton::singleton() {
std::unique_ptr<RocBlasSingleton>& instance = get_instance();
if (!instance) {
instance = std::make_unique<RocBlasSingleton>();
Kokkos::push_finalize_hook([&]() {
KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocblas_destroy_handle(instance->handle));
instance.reset();
});
}
return *instance;
}

RocBlasSingleton& RocBlasSingleton::singleton() {
static RocBlasSingleton s;
bool RocBlasSingleton::is_initialized() { return get_instance() != nullptr; }

std::unique_ptr<RocBlasSingleton>& RocBlasSingleton::get_instance() {
static std::unique_ptr<RocBlasSingleton> s;
return s;
}

Expand Down
4 changes: 4 additions & 0 deletions blas/tpls/KokkosBlas_magma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ namespace Impl {
struct MagmaSingleton {
MagmaSingleton();

static bool is_initialized();
static MagmaSingleton& singleton();

private:
static std::unique_ptr<MagmaSingleton>& get_instance();
};

} // namespace Impl
Expand Down
9 changes: 9 additions & 0 deletions blas/tpls/KokkosBlas_tpl_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ struct CudaBlasSingleton {

CudaBlasSingleton();

static bool is_initialized();
static CudaBlasSingleton& singleton();

private:
static std::unique_ptr<CudaBlasSingleton>& get_instance();
};

inline void cublas_internal_error_throw(cublasStatus_t cublasState, const char* name, const char* file,
Expand Down Expand Up @@ -111,7 +115,12 @@ struct RocBlasSingleton {

RocBlasSingleton();

static bool is_initialized();

static RocBlasSingleton& singleton();

private:
static std::unique_ptr<RocBlasSingleton>& get_instance();
};

inline void rocblas_internal_error_throw(rocblas_status rocblasState, const char* name, const char* file,
Expand Down
1 change: 0 additions & 1 deletion common/src/KokkosKernels_EagerInitialize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ void eager_initialize() {
(void)KokkosBlas::Impl::RocBlasSingleton::singleton();
#endif
#ifdef KOKKOSKERNELS_ENABLE_TPL_MAGMA
#include "KokkosBlas_Magma_tpl.hpp"
(void)KokkosBlas::Impl::MagmaSingleton::singleton();
#endif
#endif
Expand Down
7 changes: 7 additions & 0 deletions common/unit_test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,10 @@ IF (KOKKOS_ENABLE_THREADS)
)
ENDIF ()

# Add eager_initialize test, which is not backend-specific
KOKKOSKERNELS_ADD_UNIT_TEST(
common_eager_initialize
SOURCES Test_Common_EagerInitialize.cpp
COMPONENTS common
)

1 change: 0 additions & 1 deletion common/unit_test/Test_Common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,5 @@
#include <Test_Common_Iota.hpp>
#include <Test_Common_LowerBound.hpp>
#include <Test_Common_UpperBound.hpp>
#include <Test_Common_EagerInitialize.hpp>

#endif // TEST_COMMON_HPP
113 changes: 113 additions & 0 deletions common/unit_test/Test_Common_EagerInitialize.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#ifndef KK_EAGERINIT_TEST_HPP
#define KK_EAGERINIT_TEST_HPP

#include <iostream>
#include "Kokkos_Core.hpp"
#include "KokkosKernels_config.h"
#include "KokkosKernels_EagerInitialize.hpp"

#ifdef KOKKOSKERNELS_ENABLE_COMPONENT_BLAS
#include "KokkosBlas_tpl_spec.hpp" //cuBLAS, rocBLAS
#ifdef KOKKOSKERNELS_ENABLE_TPL_MAGMA
#include "KokkosBlas_Magma_tpl.hpp"
#endif
#endif

#ifdef KOKKOSKERNELS_ENABLE_COMPONENT_SPARSE
// note: this file declares both cuSPARSE and rocSPARSE singletons
#include "KokkosKernels_tpl_handles_decl.hpp"
#endif

#ifdef KOKKOSKERNELS_ENABLE_COMPONENT_LAPACK
#ifdef KOKKOSKERNELS_ENABLE_TPL_CUSOLVER
#include "KokkosLapack_cusolver.hpp"
#endif
#ifdef KOKKOSKERNELS_ENABLE_TPL_MAGMA
#include "KokkosLapack_magma.hpp"
#endif
#endif

// Count the number of singletons which are currently initialized,
// and the numInitialized number of singleton classes that are currently enabled
// (based on which TPLs and components were enabled at configure-time)
void countSingletons(int& numInitialized, int& numEnabled) {
numInitialized = 0;
numEnabled = 0;
#ifdef KOKKOSKERNELS_ENABLE_COMPONENT_BLAS
#ifdef KOKKOSKERNELS_ENABLE_TPL_CUBLAS
numEnabled++;
if (KokkosBlas::Impl::CudaBlasSingleton::is_initialized()) numInitialized++;
#endif
#ifdef KOKKOSKERNELS_ENABLE_TPL_ROCBLAS
numEnabled++;
if (KokkosBlas::Impl::RocBlasSingleton::is_initialized()) numInitialized++;
#endif
#ifdef KOKKOSKERNELS_ENABLE_TPL_MAGMA
numEnabled++;
if (KokkosBlas::Impl::MagmaSingleton::is_initialized()) numInitialized++;
#endif
#endif

#ifdef KOKKOSKERNELS_ENABLE_COMPONENT_SPARSE
#ifdef KOKKOSKERNELS_ENABLE_TPL_CUSPARSE
numEnabled++;
if (KokkosKernels::Impl::CusparseSingleton::is_initialized()) numInitialized++;
#endif
#ifdef KOKKOSKERNELS_ENABLE_TPL_ROCSPARSE
numEnabled++;
if (KokkosKernels::Impl::RocsparseSingleton::is_initialized()) numInitialized++;
#endif
#endif

#ifdef KOKKOSKERNELS_ENABLE_COMPONENT_LAPACK
#ifdef KOKKOSKERNELS_ENABLE_TPL_CUSOLVER
numEnabled++;
if (KokkosLapack::Impl::CudaLapackSingleton::is_initialized()) numInitialized++;
#endif
#ifdef KOKKOSKERNELS_ENABLE_TPL_MAGMA
numEnabled++;
if (KokkosLapack::Impl::MagmaSingleton::is_initialized()) numInitialized++;
#endif
#endif
}

int main() {
int numInitialized, numEnabled;
Kokkos::initialize();
{
// Check that no singletons are already initialized.
countSingletons(numInitialized, numEnabled);
if (numInitialized != 0)
throw std::runtime_error("At least one singleton was initialized before it should have been");
KokkosKernels::eager_initialize();
// Check that all singletons are now initialized.
countSingletons(numInitialized, numEnabled);
std::cout << "Kokkos::eager_initialize() set up " << numInitialized << " of " << numEnabled << " TPL singletons.\n";
if (numInitialized != numEnabled)
throw std::runtime_error("At least one singleton was not initialized by eager_initialize()");
}
Kokkos::finalize();
// Finally, make sure that all singletons were finalized during Kokkos::finalize().
countSingletons(numInitialized, numEnabled);
if (numInitialized != 0)
throw std::runtime_error("At least one singleton was not correctly finalized by Kokkos::finalize()");
return 0;
}

#endif
27 changes: 0 additions & 27 deletions common/unit_test/Test_Common_EagerInitialize.hpp

This file was deleted.

18 changes: 15 additions & 3 deletions lapack/tpls/KokkosLapack_Cuda_tpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,24 @@ namespace Impl {
CudaLapackSingleton::CudaLapackSingleton() {
cusolverStatus_t stat = cusolverDnCreate(&handle);
if (stat != CUSOLVER_STATUS_SUCCESS) Kokkos::abort("CUSOLVER initialization failed\n");

Kokkos::push_finalize_hook([&]() { cusolverDnDestroy(handle); });
}

CudaLapackSingleton& CudaLapackSingleton::singleton() {
static CudaLapackSingleton s;
std::unique_ptr<CudaLapackSingleton>& instance = get_instance();
if (!instance) {
instance = std::make_unique<CudaLapackSingleton>();
Kokkos::push_finalize_hook([&]() {
cusolverDnDestroy(instance->handle);
instance.reset();
});
}
return *instance;
}

bool CudaLapackSingleton::is_initialized() { return get_instance() != nullptr; }

std::unique_ptr<CudaLapackSingleton>& CudaLapackSingleton::get_instance() {
static std::unique_ptr<CudaLapackSingleton> s;
return s;
}

Expand Down
18 changes: 15 additions & 3 deletions lapack/tpls/KokkosLapack_Magma_tpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,24 @@ namespace Impl {
MagmaSingleton::MagmaSingleton() {
magma_int_t stat = magma_init();
if (stat != MAGMA_SUCCESS) Kokkos::abort("MAGMA initialization failed\n");

Kokkos::push_finalize_hook([&]() { magma_finalize(); });
}

MagmaSingleton& MagmaSingleton::singleton() {
static MagmaSingleton s;
std::unique_ptr<MagmaSingleton>& instance = get_instance();
if (!instance) {
instance = std::make_unique<MagmaSingleton>();
Kokkos::push_finalize_hook([&]() {
magma_finalize();
instance.reset();
});
}
return *instance;
}

bool MagmaSingleton::is_initialized() { return get_instance() != nullptr; }

std::unique_ptr<MagmaSingleton>& MagmaSingleton::get_instance() {
static std::unique_ptr<MagmaSingleton> s;
return s;
}

Expand Down
5 changes: 5 additions & 0 deletions lapack/tpls/KokkosLapack_cusolver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ struct CudaLapackSingleton {
CudaLapackSingleton();

static CudaLapackSingleton& singleton();

static bool is_initialized();

private:
static std::unique_ptr<CudaLapackSingleton>& get_instance();
};

inline void cusolver_internal_error_throw(cusolverStatus_t cusolverStatus, const char* name, const char* file,
Expand Down
5 changes: 5 additions & 0 deletions lapack/tpls/KokkosLapack_magma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ struct MagmaSingleton {
MagmaSingleton();

static MagmaSingleton& singleton();

static bool is_initialized();

private:
static std::unique_ptr<MagmaSingleton>& get_instance();
};

} // namespace Impl
Expand Down
8 changes: 8 additions & 0 deletions sparse/tpls/KokkosKernels_tpl_handles_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ struct CusparseSingleton {

CusparseSingleton();

static bool is_initialized();
static CusparseSingleton& singleton();

private:
static std::unique_ptr<CusparseSingleton>& get_instance();
};

} // namespace Impl
Expand All @@ -48,7 +52,11 @@ struct RocsparseSingleton {

RocsparseSingleton();

static bool is_initialized();
static RocsparseSingleton& singleton();

private:
static std::unique_ptr<RocsparseSingleton>& get_instance();
};

} // namespace Impl
Expand Down
Loading

0 comments on commit 6ea703c

Please sign in to comment.