-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
19 changed files
with
845 additions
and
414 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
// | ||
// Distributed Linear Algebra with Future (DLAF) | ||
// | ||
// Copyright (c) 2018-2021, ETH Zurich | ||
// All rights reserved. | ||
// | ||
// Please, refer to the LICENSE file in the root directory. | ||
// SPDX-License-Identifier: BSD-3-Clause | ||
// | ||
|
||
#pragma once | ||
|
||
/// @file | ||
|
||
#ifdef DLAF_WITH_CUDA | ||
|
||
#include <cstddef> | ||
#include <memory> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include <cublas_v2.h> | ||
#include <cuda_runtime.h> | ||
|
||
#include <hpx/local/runtime.hpp> | ||
|
||
#include "dlaf/common/assert.h" | ||
#include "dlaf/cublas/error.h" | ||
#include "dlaf/cuda/error.h" | ||
#include "dlaf/cuda/executor.h" | ||
|
||
namespace dlaf { | ||
namespace cublas { | ||
namespace internal { | ||
class HandlePoolImpl { | ||
int device_; | ||
std::size_t num_worker_threads_ = hpx::get_num_worker_threads(); | ||
std::vector<cublasHandle_t> handles_; | ||
cublasPointerMode_t ptr_mode_; | ||
|
||
public: | ||
HandlePoolImpl(int device, cublasPointerMode_t ptr_mode) | ||
: device_(device), handles_(num_worker_threads_), ptr_mode_(ptr_mode) { | ||
DLAF_CUDA_CALL(cudaSetDevice(device_)); | ||
|
||
for (auto& h : handles_) { | ||
DLAF_CUBLAS_CALL(cublasCreate(&h)); | ||
} | ||
} | ||
|
||
HandlePoolImpl& operator=(HandlePoolImpl&&) = default; | ||
HandlePoolImpl(HandlePoolImpl&&) = default; | ||
HandlePoolImpl(const HandlePoolImpl&) = delete; | ||
HandlePoolImpl& operator=(const HandlePoolImpl&) = delete; | ||
|
||
~HandlePoolImpl() { | ||
for (auto& h : handles_) { | ||
DLAF_CUBLAS_CALL(cublasDestroy(h)); | ||
} | ||
} | ||
|
||
cublasHandle_t getNextHandle(cudaStream_t stream) { | ||
cublasHandle_t handle = handles_[hpx::get_worker_thread_num()]; | ||
DLAF_CUDA_CALL(cudaSetDevice(device_)); | ||
DLAF_CUBLAS_CALL(cublasSetStream(handle, stream)); | ||
DLAF_CUBLAS_CALL(cublasSetPointerMode(handle, ptr_mode_)); | ||
return handle; | ||
} | ||
|
||
int getDevice() { | ||
return device_; | ||
} | ||
}; | ||
} | ||
|
||
/// A pool of cuBLAS handles with reference semantics (copying points to the | ||
/// same underlying cuBLAS handles, last reference destroys the references). | ||
/// Allows access to cuBLAS handles associated with a particular stream. The | ||
/// user must ensure that the handle pool and the stream use the same device. | ||
/// Each HPX worker thread is assigned thread local cuBLAS handle. | ||
class HandlePool { | ||
std::shared_ptr<internal::HandlePoolImpl> handles_ptr_; | ||
|
||
public: | ||
HandlePool(int device = 0, cublasPointerMode_t ptr_mode = CUBLAS_POINTER_MODE_HOST) | ||
: handles_ptr_(std::make_shared<internal::HandlePoolImpl>(device, ptr_mode)) {} | ||
|
||
cublasHandle_t getNextHandle(cudaStream_t stream) { | ||
DLAF_ASSERT(bool(handles_ptr_), ""); | ||
return handles_ptr_->getNextHandle(stream); | ||
} | ||
|
||
int getDevice() { | ||
DLAF_ASSERT(bool(handles_ptr_), ""); | ||
return handles_ptr_->getDevice(); | ||
} | ||
|
||
bool operator==(HandlePool const& rhs) const noexcept { | ||
return handles_ptr_ == rhs.handles_ptr_; | ||
} | ||
|
||
bool operator!=(HandlePool const& rhs) const noexcept { | ||
return !(*this == rhs); | ||
} | ||
}; | ||
} | ||
} | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.