Skip to content

Commit

Permalink
Introduce CUDA Device Assertions Infrastructure (pytorch#84609)
Browse files Browse the repository at this point in the history
Summary:
This diff introduces a set of changes that makes it possible for the host to get assertions from CUDA devices. This includes the introduction of

**`CUDA_KERNEL_ASSERT2`**

A preprocessor macro to be used within a CUDA kernel that, upon an assertion failure, writes the assertion message, file, line number, and possibly other information to UVM (Managed memory). Once this is done, the original assertion is triggered, which places the GPU in a Bad State requiring recovery. In my tests, data written to UVM appears there before the GPU reaches the Bad State and is still accessible from the host after the GPU is in this state.

Messages are written to a multi-message buffer which can, in theory, hold many assertion failures. I've done this as a precaution in case there are several, but I don't actually know whether that is possible and a simpler design which holds only a single message may well be all that is necessary.

**`TORCH_DSA_KERNEL_ARGS`**

This preprocess macro is added as an _argument_ to a kernel function's signature. It expands to supply the standardized names of all the arguments needed by `C10_CUDA_COMMUNICATING_KERNEL_ASSERTION` to handle device-side assertions. This includes, eg, the name of the pointer to the UVM memory the assertion would be written to. This macro abstracts the arguments so there is a single point of change if the system needs to be modified.

**`c10::cuda::get_global_cuda_kernel_launch_registry()`**

This host-side function returns a singleton object that manages the host's part of the device-side assertions. Upon allocation, the singleton allocates sufficient UVM (Managed) memory to hold information about several device-side assertion failures. The singleton also provides methods for getting the current traceback (used to identify when a kernel was launched). To avoid consuming all the host's memory the singleton stores launches in a circular buffer; a unique "generation number" is used to ensure that kernel launch failures map to their actual launch points (in the case that the circular buffer wraps before the failure is detected).

**`TORCH_DSA_KERNEL_LAUNCH`**

This host-side preprocessor macro replaces the standard
```
kernel_name<<<blocks, threads, shmem, stream>>>(args)
```
invocation with
```
TORCH_DSA_KERNEL_LAUNCH(blocks, threads, shmem, stream, args);
```
Internally, it fetches the UVM (Managed) pointer and generation number from the singleton and append these to the standard argument list. It also checks to ensure the kernel launches correctly. This abstraction on kernel launches can be modified to provide additional safety/logging.

**`c10::cuda::c10_retrieve_device_side_assertion_info`**
This host-side function checks, when called, that no kernel assertions have occurred. If one has. It then raises an exception with:
1. Information (file, line number) of what kernel was launched.
2. Information (file, line number, message) about the device-side assertion
3. Information (file, line number) about where the failure was detected.

**Checking for device-side assertions**

Device-side assertions are most likely to be noticed by the host when a CUDA API call such as `cudaDeviceSynchronize` is made and fails with a `cudaError_t` indicating
> CUDA error: device-side assert triggered CUDA kernel errors

Therefore, we rewrite `C10_CUDA_CHECK()` to include a call to `c10_retrieve_device_side_assertion_info()`. To make the code cleaner, most of the logic of `C10_CUDA_CHECK()` is now contained within a new function `c10_cuda_check_implementation()` to which `C10_CUDA_CHECK` passes the preprocessor information about filenames, function names, and line numbers. (In C++20 we can use `std::source_location` to eliminate macros entirely!)

# Notes on special cases

* Multiple assertions from the same block are recorded
* Multiple assertions from different blocks are recorded
* Launching kernels from many threads on many streams seems to be handled correctly
* If two process are using the same GPU and one of the processes fails with a device-side assertion the other process continues without issue
* X Multiple assertions from separate kernels on different streams seem to be recorded, but we can't reproduce the test condition
* X Multiple assertions from separate devices should be all be shown upon exit, but we've been unable to generate a test that produces this condition

Pull Request resolved: pytorch#84609

Reviewed By: ezyang

Differential Revision: D37621532

Pulled By: r-barnes

fbshipit-source-id: eacd53618c190f6d76caf2ab3928dfd68d92a85e
  • Loading branch information
r-barnes authored and facebook-github-bot committed Dec 6, 2022
1 parent 97e47a5 commit d782613
Show file tree
Hide file tree
Showing 18 changed files with 1,371 additions and 22 deletions.
7 changes: 4 additions & 3 deletions c10/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,25 @@ configure_file(
# and headers you add
set(C10_CUDA_SRCS
CUDACachingAllocator.cpp
CUDADeviceAssertionHost.cpp
CUDAException.cpp
CUDAFunctions.cpp
CUDAMallocAsyncAllocator.cpp
CUDAMiscFunctions.cpp
CUDAStream.cpp
CUDACachingAllocator.cpp
CUDAMallocAsyncAllocator.cpp
impl/CUDAGuardImpl.cpp
impl/CUDATest.cpp
)
set(C10_CUDA_HEADERS
CUDACachingAllocator.h
CUDADeviceAssertionHost.h
CUDAException.h
CUDAFunctions.h
CUDAGuard.h
CUDAMacros.h
CUDAMathCompat.h
CUDAMiscFunctions.h
CUDAStream.h
CUDACachingAllocator.h
impl/CUDAGuardImpl.h
impl/CUDATest.h
)
Expand Down
98 changes: 98 additions & 0 deletions c10/cuda/CUDADeviceAssertion.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#pragma once

#include <c10/cuda/CUDAException.h>
#include <c10/macros/Macros.h>

namespace c10 {
namespace cuda {

#ifdef TORCH_USE_CUDA_DSA
// Copy string from `src` to `dst`
static __device__ void dstrcpy(char* dst, const char* src) {
int i = 0;
// Copy string from source to destination, ensuring that it
// isn't longer than `C10_CUDA_DSA_MAX_STR_LEN-1`
while (*src != '\0' && i++ < C10_CUDA_DSA_MAX_STR_LEN - 1) {
*dst++ = *src++;
}
*dst = '\0';
}

__device__ __noinline__ void dsa_add_new_assertion_failure(
DeviceAssertionsData* assertions_data,
const char* assertion_msg,
const char* filename,
const char* function_name,
const int line_number,
const uint32_t caller,
const dim3 block_id,
const dim3 thread_id) {
// `assertions_data` may be nullptr if device-side assertion checking
// is disabled at run-time. If it is disabled at compile time this
// function will never be called
if (!assertions_data) {
return;
}

// Atomically increment so other threads can fail at the same time
// Note that incrementing this means that the CPU can observe that
// a failure has happened and can begin to respond before we've
// written information about that failure out to the buffer.
const auto nid = atomicAdd(&(assertions_data->assertion_count), 1);

if (nid >= C10_CUDA_DSA_ASSERTION_COUNT) {
// At this point we're ran out of assertion buffer space.
// We could print a message about this, but that'd get
// spammy if a lot of threads did it, so we just silently
// ignore any other assertion failures. In most cases the
// failures will all probably be analogous anyway.
return;
}

// Write information about the assertion failure to memory.
// Note that this occurs only after the `assertion_count`
// increment broadcasts that there's been a problem.
auto& self = assertions_data->assertions[nid];
dstrcpy(self.assertion_msg, assertion_msg);
dstrcpy(self.filename, filename);
dstrcpy(self.function_name, function_name);
self.line_number = line_number;
self.caller = caller;
self.block_id[0] = block_id.x;
self.block_id[1] = block_id.y;
self.block_id[2] = block_id.z;
self.thread_id[0] = thread_id.x;
self.thread_id[1] = thread_id.y;
self.thread_id[2] = thread_id.z;
}

// Emulates a kernel assertion. The assertion won't stop the kernel's progress,
// so you should assume everything the kernel produces is garbage if there's an
// assertion failure.
// NOTE: This assumes that `assertions_data` and `assertion_caller_id` are
// arguments of the kernel and therefore accessible.
#define CUDA_KERNEL_ASSERT2(condition) \
do { \
if (C10_UNLIKELY(!(condition))) { \
/* Has an atomic element so threads can fail at the same time */ \
c10::cuda::dsa_add_new_assertion_failure( \
assertions_data, \
C10_STRINGIZE(condition), \
__FILE__, \
__FUNCTION__, \
__LINE__, \
assertion_caller_id, \
blockIdx, \
threadIdx); \
/* Now that the kernel has failed we early exit the kernel, but */ \
/* otherwise keep going and rely on the host to check UVM and */ \
/* determine we've had a problem */ \
return; \
} \
} while (false)
#else
#define CUDA_KERNEL_ASSERT2(condition) assert(condition)
#endif

} // namespace cuda
} // namespace c10
Loading

0 comments on commit d782613

Please sign in to comment.