diff --git a/HeterogeneousCore/CUDAUtilities/interface/ScopedSetDevice.h b/HeterogeneousCore/CUDAUtilities/interface/ScopedSetDevice.h index 9b296dd390ea3..d11baa0191d4f 100644 --- a/HeterogeneousCore/CUDAUtilities/interface/ScopedSetDevice.h +++ b/HeterogeneousCore/CUDAUtilities/interface/ScopedSetDevice.h @@ -9,20 +9,32 @@ namespace cms { namespace cuda { class ScopedSetDevice { public: - explicit ScopedSetDevice(int newDevice) { - cudaCheck(cudaGetDevice(&prevDevice_)); - cudaCheck(cudaSetDevice(newDevice)); + // Store the original device, without setting a new one + ScopedSetDevice() { + cudaCheck(cudaGetDevice(&originalDevice_)); } + // Store the original device, and set a new current device + explicit ScopedSetDevice(int device) : ScopedSetDevice() { + set(device); + } + + // Restore the original device ~ScopedSetDevice() { // Intentionally don't check the return value to avoid // exceptions to be thrown. If this call fails, the process is // doomed anyway. - cudaSetDevice(prevDevice_); + cudaSetDevice(originalDevice_); + } + + // Set a new current device, without changing the original device + // that will be restored when this object is destroyed + void set(int device) { + cudaCheck(cudaSetDevice(device)); } private: - int prevDevice_; + int originalDevice_; }; } // namespace cuda } // namespace cms