diff --git a/Changelog.md b/Changelog.md index fac6f5fbc..b0654193e 100644 --- a/Changelog.md +++ b/Changelog.md @@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `VK_EXT_swapchain_maintenance1` device extension (#786) - Added `VK_NV_low_latency2` device extension (#802) - Added `VK_EXT_hdr_metadata` device extension (#804) +- Added `VK_NV_cuda_kernel_launch` device extension (#805) ### Changed diff --git a/ash/src/device.rs b/ash/src/device.rs index 817235f03..ecac1334d 100644 --- a/ash/src/device.rs +++ b/ash/src/device.rs @@ -2232,12 +2232,12 @@ impl Device { &self, pipeline_cache: vk::PipelineCache, ) -> VkResult> { - read_into_uninitialized_vector(|count, data| { + read_into_uninitialized_vector(|count, data: *mut u8| { (self.device_fn_1_0.get_pipeline_cache_data)( self.handle(), pipeline_cache, count, - data as _, + data.cast(), ) }) } diff --git a/ash/src/extensions/nv/cuda_kernel_launch.rs b/ash/src/extensions/nv/cuda_kernel_launch.rs new file mode 100644 index 000000000..cfc3227e7 --- /dev/null +++ b/ash/src/extensions/nv/cuda_kernel_launch.rs @@ -0,0 +1,107 @@ +use crate::prelude::*; +use crate::vk; +use crate::RawPtr; +use crate::{Device, Instance}; +use std::ffi::CStr; +use std::mem; + +/// +#[derive(Clone)] +pub struct CudaKernelLaunch { + handle: vk::Device, + fp: vk::NvCudaKernelLaunchFn, +} + +impl CudaKernelLaunch { + pub fn new(instance: &Instance, device: &Device) -> Self { + let handle = device.handle(); + let fp = vk::NvCudaKernelLaunchFn::load(|name| unsafe { + mem::transmute(instance.get_device_proc_addr(handle, name.as_ptr())) + }); + Self { handle, fp } + } + + /// + #[inline] + pub unsafe fn create_cuda_module( + &self, + create_info: &vk::CudaModuleCreateInfoNV, + allocator: Option<&vk::AllocationCallbacks>, + ) -> VkResult { + let mut module = mem::MaybeUninit::uninit(); + (self.fp.create_cuda_module_nv)( + self.handle, + create_info, + allocator.as_raw_ptr(), + module.as_mut_ptr(), + ) + .assume_init_on_success(module) + } + + /// + #[inline] + pub unsafe fn get_cuda_module_cache(&self, module: vk::CudaModuleNV) -> VkResult> { + read_into_uninitialized_vector(|cache_size, cache_data: *mut u8| { + (self.fp.get_cuda_module_cache_nv)(self.handle, module, cache_size, cache_data.cast()) + }) + } + + /// + #[inline] + pub unsafe fn create_cuda_function( + &self, + create_info: &vk::CudaFunctionCreateInfoNV, + allocator: Option<&vk::AllocationCallbacks>, + ) -> VkResult { + let mut function = mem::MaybeUninit::uninit(); + (self.fp.create_cuda_function_nv)( + self.handle, + create_info, + allocator.as_raw_ptr(), + function.as_mut_ptr(), + ) + .assume_init_on_success(function) + } + + /// + #[inline] + pub unsafe fn destroy_cuda_module( + &self, + module: vk::CudaModuleNV, + allocator: Option<&vk::AllocationCallbacks>, + ) { + (self.fp.destroy_cuda_module_nv)(self.handle, module, allocator.as_raw_ptr()) + } + + /// + #[inline] + pub unsafe fn destroy_cuda_function( + &self, + function: vk::CudaFunctionNV, + allocator: Option<&vk::AllocationCallbacks>, + ) { + (self.fp.destroy_cuda_function_nv)(self.handle, function, allocator.as_raw_ptr()) + } + + /// + #[inline] + pub unsafe fn cmd_cuda_launch_kernel( + &self, + command_buffer: vk::CommandBuffer, + launch_info: &vk::CudaLaunchInfoNV, + ) { + (self.fp.cmd_cuda_launch_kernel_nv)(command_buffer, launch_info) + } + + pub const NAME: &'static CStr = vk::NvCudaKernelLaunchFn::NAME; + + #[inline] + pub fn fp(&self) -> &vk::NvCudaKernelLaunchFn { + &self.fp + } + + #[inline] + pub fn device(&self) -> vk::Device { + self.handle + } +} diff --git a/ash/src/extensions/nv/mod.rs b/ash/src/extensions/nv/mod.rs index 3b08ae3b3..1a4f8dbdc 100644 --- a/ash/src/extensions/nv/mod.rs +++ b/ash/src/extensions/nv/mod.rs @@ -1,4 +1,5 @@ pub use self::coverage_reduction_mode::CoverageReductionMode; +pub use self::cuda_kernel_launch::CudaKernelLaunch; pub use self::device_diagnostic_checkpoints::DeviceDiagnosticCheckpoints; pub use self::device_generated_commands_compute::DeviceGeneratedCommandsCompute; pub use self::low_latency2::LowLatency2; @@ -7,6 +8,7 @@ pub use self::mesh_shader::MeshShader; pub use self::ray_tracing::RayTracing; mod coverage_reduction_mode; +mod cuda_kernel_launch; mod device_diagnostic_checkpoints; mod device_generated_commands_compute; mod low_latency2;