Skip to content

Commit

Permalink
WIP ray tracing backend
Browse files Browse the repository at this point in the history
  • Loading branch information
tangmi committed Jun 3, 2021
1 parent 6899794 commit 2b1fdf0
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 130 deletions.
20 changes: 11 additions & 9 deletions src/backend/vulkan/src/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1080,18 +1080,16 @@ impl com::CommandBuffer<Backend> for CommandBuffer {
desc: &'a hal::acceleration_structure::BuildDesc<'a, Backend>,
ranges: &'a [hal::acceleration_structure::BuildRangeDesc],
) {
let geometries = conv::map_geometries(&self.device, desc.geometry.geometries.iter());
self.device
.extension_fns
.acceleration_structure
.as_ref()
.expect("Feature ACCELERATION_STRUCTURE must be enabled to call build_acceleration_structure").unwrap_extension()
.cmd_build_acceleration_structures(
self.raw,
&[conv::map_geometry_info(&self.device, desc)],
&[mem::transmute::<
&[hal::acceleration_structure::BuildRangeDesc],
&[vk::AccelerationStructureBuildRangeInfoKHR],
>(ranges)],
&[conv::map_geometry_info_without_geometries(&self.device, desc).geometries(&geometries).build()],
&[conv::map_build_ranges_infos(ranges)],
);
}

Expand All @@ -1103,14 +1101,15 @@ impl com::CommandBuffer<Backend> for CommandBuffer {
stride: buffer::Stride,
max_primitive_counts: &'a [u32],
) {
let geometries = conv::map_geometries(&self.device, desc.geometry.geometries.iter());
self.device
.extension_fns
.acceleration_structure
.as_ref()
.expect("Feature ACCELERATION_STRUCTURE must be enabled to call build_acceleration_structure_indirect").unwrap_extension()
.cmd_build_acceleration_structures_indirect(
self.raw,
&[conv::map_geometry_info(&self.device, desc)],
&[conv::map_geometry_info_without_geometries(&self.device, desc).geometries(&geometries).build()],
&[self.device.get_buffer_device_address(buffer, offset)],
&[stride],
&[max_primitive_counts],
Expand Down Expand Up @@ -1228,7 +1227,8 @@ impl com::CommandBuffer<Backend> for CommandBuffer {
.extension_fns
.ray_tracing_pipeline
.as_ref()
.expect("Feature ACCELERATION_STRUCTURE must be enabled to call set_ray_tracing_pipeline_stack_size")
.expect("Feature RAY_TRACING_PIPELINE must be enabled to call set_ray_tracing_pipeline_stack_size")
.unwrap_extension()
.cmd_set_ray_tracing_pipeline_stack_size(self.raw, pipeline_stack_size);
}

Expand All @@ -1244,7 +1244,8 @@ impl com::CommandBuffer<Backend> for CommandBuffer {
.extension_fns
.ray_tracing_pipeline
.as_ref()
.expect("Feature ACCELERATION_STRUCTURE must be enabled to call trace_rays")
.expect("Feature RAY_TRACING_PIPELINE must be enabled to call trace_rays")
.unwrap_extension()
.cmd_trace_rays(
self.raw,
&conv::map_shader_binding_table(&self.device, raygen_shader_binding_table),
Expand All @@ -1271,7 +1272,8 @@ impl com::CommandBuffer<Backend> for CommandBuffer {
.extension_fns
.ray_tracing_pipeline
.as_ref()
.expect("Feature ACCELERATION_STRUCTURE must be enabled to call trace_rays_indirect")
.expect("Feature RAY_TRACING_PIPELINE must be enabled to call trace_rays_indirect")
.unwrap_extension()
.cmd_trace_rays_indirect(
self.raw,
&[conv::map_shader_binding_table(
Expand Down
89 changes: 55 additions & 34 deletions src/backend/vulkan/src/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -782,10 +782,11 @@ pub unsafe fn map_geometries<'a>(
.collect::<Vec<_>>()
}

pub unsafe fn map_geometry_info(
/// Convert all fields of `desc`, except `geometries`. The caller should call `map_geometries` and add it to the builder manually to ensure the lifetime of the resulting collection lives long enough.
pub unsafe fn map_geometry_info_without_geometries<'a>(
device: &crate::RawDevice,
desc: &hal::acceleration_structure::BuildDesc<crate::Backend>,
) -> vk::AccelerationStructureBuildGeometryInfoKHR {
desc: &'a hal::acceleration_structure::BuildDesc<crate::Backend>,
) -> vk::AccelerationStructureBuildGeometryInfoKHRBuilder<'a> {
vk::AccelerationStructureBuildGeometryInfoKHR::builder()
.ty(map_acceleration_structure_type(desc.geometry.ty))
.flags(map_acceleration_structure_flags(desc.geometry.flags))
Expand All @@ -796,14 +797,16 @@ pub unsafe fn map_geometry_info(
})
.src_acceleration_structure(desc.src.map(|a| a.0).unwrap_or_default())
.dst_acceleration_structure(desc.dst.0)
.geometries(
// TODO: this is unsafe since the lifetime of this vec could be shorter than its caller?
map_geometries(device, desc.geometry.geometries.iter()).as_slice(),
)
.scratch_data(vk::DeviceOrHostAddressKHR {
device_address: device.get_buffer_device_address(desc.scratch, desc.scratch_offset),
})
.build()
}

pub unsafe fn map_build_ranges_infos(
build_ranges: &[hal::acceleration_structure::BuildRangeDesc],
) -> &[vk::AccelerationStructureBuildRangeInfoKHR] {
// Safe because `BuildRangeDesc` and `AccelerationStructureBuildRangeInfoKHR` have the same layout.
mem::transmute(build_ranges)
}

pub fn map_group_shader(group_shader: pso::GroupShader) -> vk::ShaderGroupShaderKHR {
Expand Down Expand Up @@ -913,39 +916,57 @@ pub fn map_shader_stage(stage: pso::ShaderStageFlags) -> vk::ShaderStageFlags {
flags
}

pub fn map_group_type(group_type: pso::GroupType) -> vk::RayTracingShaderGroupTypeKHR {
match group_type {
pso::GroupType::General => vk::RayTracingShaderGroupTypeKHR::GENERAL,
pso::GroupType::TrianglesHitGroup => vk::RayTracingShaderGroupTypeKHR::TRIANGLES_HIT_GROUP,
pso::GroupType::ProceduralHitGroup => {
vk::RayTracingShaderGroupTypeKHR::PROCEDURAL_HIT_GROUP
}
}
}

pub fn map_shader_group_desc(
desc: &pso::ShaderGroupDesc,
) -> vk::RayTracingShaderGroupCreateInfoKHR {
vk::RayTracingShaderGroupCreateInfoKHR::builder()
.ty(map_group_type(desc.ty))
.general_shader(desc.general_shader)
.closest_hit_shader(desc.closest_hit_shader)
.any_hit_shader(desc.any_hit_shader)
.intersection_shader(desc.intersection_shader)
.build()
match desc {
pso::ShaderGroupDesc::General { general_shader } => {
vk::RayTracingShaderGroupCreateInfoKHR::builder()
.ty(vk::RayTracingShaderGroupTypeKHR::GENERAL)
.general_shader(*general_shader)
.closest_hit_shader(vk::SHADER_UNUSED_KHR)
.any_hit_shader(vk::SHADER_UNUSED_KHR)
.intersection_shader(vk::SHADER_UNUSED_KHR)
.build()
}

pso::ShaderGroupDesc::TrianglesHitGroup {
closest_hit_shader,
any_hit_shader,
} => vk::RayTracingShaderGroupCreateInfoKHR::builder()
.ty(vk::RayTracingShaderGroupTypeKHR::TRIANGLES_HIT_GROUP)
.general_shader(vk::SHADER_UNUSED_KHR)
.closest_hit_shader(closest_hit_shader.unwrap_or(vk::SHADER_UNUSED_KHR))
.any_hit_shader(any_hit_shader.unwrap_or(vk::SHADER_UNUSED_KHR))
.intersection_shader(vk::SHADER_UNUSED_KHR)
.build(),

pso::ShaderGroupDesc::ProceduralHitGroup {
closest_hit_shader,
any_hit_shader,
intersection_shader,
} => vk::RayTracingShaderGroupCreateInfoKHR::builder()
.ty(vk::RayTracingShaderGroupTypeKHR::PROCEDURAL_HIT_GROUP)
.general_shader(vk::SHADER_UNUSED_KHR)
.closest_hit_shader(closest_hit_shader.unwrap_or(vk::SHADER_UNUSED_KHR))
.any_hit_shader(any_hit_shader.unwrap_or(vk::SHADER_UNUSED_KHR))
.intersection_shader(*intersection_shader)
.build(),
}
}

pub unsafe fn map_shader_binding_table(
device: &crate::RawDevice,
table: Option<pso::ShaderBindingTable<crate::Backend>>,
) -> vk::StridedDeviceAddressRegionKHR {
if let Some(table) = table {
vk::StridedDeviceAddressRegionKHR::builder()
.device_address(device.get_buffer_device_address(table.buffer, table.offset))
.stride(table.stride as u64)
.size(table.size)
.build()
} else {
vk::StridedDeviceAddressRegionKHR::default()
}
table.map_or_else(
|| vk::StridedDeviceAddressRegionKHR::default(),
|table| {
vk::StridedDeviceAddressRegionKHR::builder()
.device_address(device.get_buffer_device_address(table.buffer, table.offset))
.stride(table.stride as u64)
.size(table.size)
.build()
},
)
}
99 changes: 49 additions & 50 deletions src/backend/vulkan/src/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -463,10 +463,11 @@ impl<'a> RayTracingPipelineInfoBuf<'a> {
this.shader_groups = desc
.stages
.iter()
.map(|(_stage, entry_point)| {
.map(|stage_desc| {
let mut buf = ComputePipelineInfoBuf::default();
buf.c_string = CString::new(entry_point.entry).unwrap();
buf.entries = entry_point
buf.c_string = CString::new(stage_desc.entry_point.entry).unwrap();
buf.entries = stage_desc
.entry_point
.specialization
.constants
.iter()
Expand All @@ -479,8 +480,8 @@ impl<'a> RayTracingPipelineInfoBuf<'a> {
buf.specialization = vk::SpecializationInfo {
map_entry_count: buf.entries.len() as _,
p_map_entries: buf.entries.as_ptr(),
data_size: entry_point.specialization.data.len() as _,
p_data: entry_point.specialization.data.as_ptr() as _,
data_size: stage_desc.entry_point.specialization.data.len() as _,
p_data: stage_desc.entry_point.specialization.data.as_ptr() as _,
};
buf
})
Expand Down Expand Up @@ -927,6 +928,26 @@ impl d::Device<B> for super::Device {
) -> Result<n::RayTracingPipeline, pso::CreationError> {
let buf = RayTracingPipelineInfoBuf::new(desc);

let stages = desc
.stages
.iter()
.zip(&buf.shader_groups)
.map(|(stage_desc, buf)| {
vk::PipelineShaderStageCreateInfo::builder()
.flags(vk::PipelineShaderStageCreateFlags::empty())
.stage(conv::map_shader_stage(stage_desc.stage))
.module(stage_desc.entry_point.module.raw)
.name(buf.c_string.as_c_str())
.specialization_info(&buf.specialization)
.build()
})
.collect::<Vec<_>>();
let groups = desc
.groups
.iter()
.map(conv::map_shader_group_desc)
.collect::<Vec<_>>();

let info = {
let (base_handle, base_index) = match desc.parent {
pso::BasePipeline::Pipeline(pipeline) => (pipeline.0, -1),
Expand All @@ -936,37 +957,15 @@ impl d::Device<B> for super::Device {

vk::RayTracingPipelineCreateInfoKHR::builder()
.flags(conv::map_pipeline_create_flags(desc.flags, &desc.parent))
.stages(
desc.stages
.iter()
.zip(&buf.shader_groups)
.map(|((stage, entry), buf)| {
vk::PipelineShaderStageCreateInfo::builder()
.flags(vk::PipelineShaderStageCreateFlags::empty())
.stage(conv::map_shader_stage(*stage))
.module(entry.module.raw)
.name(buf.c_string.as_c_str())
.specialization_info(&buf.specialization)
.build()
})
.collect::<Vec<_>>()
.as_slice(),
)
.groups(
desc.groups
.iter()
.map(conv::map_shader_group_desc)
.collect::<Vec<_>>()
.as_slice(),
)
.stages(&stages)
.groups(&groups)
.max_pipeline_ray_recursion_depth(desc.max_pipeline_ray_recursion_depth)
// .library_info()
// .library_interface()
// .dynamic_state()
.layout(desc.layout.raw)
.base_pipeline_handle(base_handle)
.base_pipeline_index(base_index)
.build()
};

// TODO create_ray_tracing_pipelines also returns VK_OPERATION_DEFERRED_KHR, VK_OPERATION_NOT_DEFERRED_KHR, VK_PIPELINE_COMPILE_REQUIRED_EXT on success, but ash does not support this.
Expand All @@ -975,11 +974,14 @@ impl d::Device<B> for super::Device {
.extension_fns
.ray_tracing_pipeline
.as_ref()
.expect("TODO msg")
.expect(
"Feature RAY_TRACING_PIPELINE must be enabled to call create_ray_tracing_pipeline",
)
.unwrap_extension()
.create_ray_tracing_pipelines(
vk::DeferredOperationKHR::null(),
cache.map_or(vk::PipelineCache::null(), |cache| cache.raw),
&[info],
&[info.build()],
None,
) {
Ok(pipelines) => {
Expand Down Expand Up @@ -2012,24 +2014,21 @@ impl d::Device<B> for super::Device {
pipeline: &'a n::RayTracingPipeline,
first_group: u32,
group_count: u32,
data: &mut [u8],
) -> Result<(), d::OutOfMemory> {
// TODO: data_size? ash returns a vec<>, but vulkan takes a pointer. either way needs data_size, which must be user-provided based on the physical device limits
// let result = self
// .shared
// .extension_fns
// .ray_tracing_pipeline
// .as_ref()
// .expect("TODO msg")
// .get_ray_tracing_shader_group_handles(pipeline.0, first_group, group_count);

// match result {
// Ok(_) => Ok(()),
// Err(vk::Result::ERROR_OUT_OF_HOST_MEMORY) => Err(d::OutOfMemory::Host),
// _ => unreachable!(),
// }
data_size: usize,
) -> Result<Vec<u8>, d::OutOfMemory> {
let result = self
.shared
.extension_fns
.ray_tracing_pipeline
.as_ref()
.expect("Feature RAY_TRACING_PIPELINE must be enabled to call get_ray_tracing_shader_group_handles").unwrap_extension()
.get_ray_tracing_shader_group_handles(pipeline.0, first_group, group_count, data_size);

todo!()
match result {
Ok(data) => Ok(data),
Err(vk::Result::ERROR_OUT_OF_HOST_MEMORY) => Err(d::OutOfMemory::Host),
_ => unreachable!(),
}
}

unsafe fn get_ray_tracing_shader_group_stack_size<'a>(
Expand All @@ -2042,9 +2041,9 @@ impl d::Device<B> for super::Device {
.extension_fns
.ray_tracing_pipeline
.as_ref()
.expect("TODO msg")
.expect("Feature RAY_TRACING_PIPELINE must be enabled to call get_ray_tracing_shader_group_stack_size")
.unwrap_extension()
.get_ray_tracing_shader_group_stack_size(
self.shared.raw.handle(),
pipeline.0,
group,
conv::map_group_shader(group_shader),
Expand Down
1 change: 0 additions & 1 deletion src/backend/vulkan/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ use hal::{
window::{OutOfDate, PresentError, Suboptimal, SurfaceLost},
Features,
};
use vk::PhysicalDeviceProperties2;

use std::{
borrow::Cow,
Expand Down
6 changes: 3 additions & 3 deletions src/hal/src/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -744,14 +744,14 @@ pub trait Device<B: Backend>: fmt::Debug + Any + Send + Sync {
// }

/// TODO docs
// `data` must be at least `shaderGroupHandleCaptureReplaySize * groupCount`
// `data_size` must be at least `shaderGroupHandleSize * groupCount`
unsafe fn get_ray_tracing_shader_group_handles<'a>(
&self,
_pipeline: &'a B::RayTracingPipeline,
_first_group: u32,
_group_count: u32,
_data: &mut [u8],
) -> Result<(), OutOfMemory> {
_data_size: usize,
) -> Result<Vec<u8>, OutOfMemory> {
unimplemented!()
}

Expand Down
2 changes: 1 addition & 1 deletion src/hal/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ bitflags! {

/// Supports acceleration structures.
///
/// Requires `RAY_TRACING_PIPELINES` or `RAY_QUERY` to also be enabled.
/// Requires `RAY_TRACING_PIPELINE` or `RAY_QUERY` to also be enabled.
const ACCELERATION_STRUCTURE = 0x0000_0008 << 96;
/// Supports a command to indirectly build an acceleration structure.
// TODO should this be part of `AccelerationStructureProperties`? The diff would be if app can depend on this feature vs. check for its availability.
Expand Down
Loading

0 comments on commit 2b1fdf0

Please sign in to comment.