Skip to content

Commit

Permalink
Fixes #669: Can now obtain the number of kernels in a module and a co…
Browse files Browse the repository at this point in the history
…ntainer of `kernel_t` wrappers for them
  • Loading branch information
eyalroz committed Oct 26, 2024
1 parent 4654f95 commit 71472dc
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions src/cuda/api/module.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,16 @@ ::std::string identify(const module_t &module);

inline void destroy(handle_t handle, context::handle_t context_handle, device::id_t device_id);

#if CUDA_VERSION >= 12040
/**
 * Collects the raw driver handles of the kernels found in a module.
 *
 * @param module_handle raw handle of the module whose kernels are enumerated
 * @param num_kernels number of handle slots to allocate and to pass on to
 * `cuModuleEnumerateFunctions()`; callers in this file obtain it from
 * `cuModuleGetFunctionCount()`
 * @return an owning span with @p num_kernels kernel handle slots, as filled
 * in by the driver
 *
 * @note A non-success status from the driver call is reported through
 * `throw_if_error_lazy`, with a message identifying the module.
 */
inline unique_span<kernel::handle_t> get_kernel_handles(handle_t module_handle, size_t num_kernels)
{
	auto kernel_handles = make_unique_span<kernel::handle_t>(num_kernels);
	auto enumeration_status = cuModuleEnumerateFunctions(
		kernel_handles.data(), num_kernels, module_handle);
	throw_if_error_lazy(enumeration_status,
		"Failed enumerating the kernels in " + module::detail_::identify(module_handle));
	return kernel_handles;
}
#endif

} // namespace detail_

/**
Expand Down Expand Up @@ -153,6 +163,25 @@ class module_t {
return { memory::as_pointer(dptr), size };
}

#if CUDA_VERSION >= 12040
/**
 * Queries the driver for the number of functions (kernels) in this module.
 *
 * @return the count written by `cuModuleGetFunctionCount()`, widened to `size_t`
 *
 * @note A non-success status from the driver call is reported through
 * `throw_if_error_lazy`, with a message identifying this module.
 */
size_t get_num_kernels() const
{
	// The driver API writes the count through an `unsigned` out-parameter
	unsigned kernel_count;
	auto count_status = cuModuleGetFunctionCount(&kernel_count, handle_);
	throw_if_error_lazy(count_status,
		"Failed determining function count for " + module::detail_::identify(*this));
	return kernel_count;
}

/**
 * Produces wrapper objects for all of the kernels in this module.
 *
 * The kernel count is obtained first, then the raw kernel handles are
 * enumerated, and finally each handle is wrapped together with this
 * module's device ID and context handle.
 *
 * @return an owning span holding one `kernel_t` per enumerated kernel handle
 */
unique_span<kernel_t> get_kernels() const
{
	// A kernel count of 0 is acceptable here - no special-casing is needed
	auto kernel_handles = module::detail_::get_kernel_handles(handle_, get_num_kernels());
	auto wrap_nth_handle = [&](size_t index) {
		return kernel::wrap(device_id_, context_handle_, kernel_handles[index]);
	};
	return generate_unique_span<kernel_t>(kernel_handles.size(), wrap_nth_handle);
}
#endif // CUDA_VERSION >= 12040

// TODO: Implement a surface reference and texture reference class rather than these raw pointers.

#if CUDA_VERSION < 12000
Expand Down

0 comments on commit 71472dc

Please sign in to comment.