Skip to content

Commit

Permalink
[wgpu-core] ray tracing: use error handling helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
teoxoy committed Nov 15, 2024
1 parent 1abf3fe commit 66518aa
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 181 deletions.
154 changes: 37 additions & 117 deletions wgpu-core/src/command/ray_tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,7 @@ impl Global {

let device = &cmd_buf.device;

if !device
.features
.contains(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)
{
return Err(BuildAccelerationStructureError::MissingFeature);
}
device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;

let build_command_index = NonZeroU64::new(
device
Expand Down Expand Up @@ -200,18 +195,13 @@ impl Global {
let mut tlas_buf_storage = Vec::new();

for entry in tlas_iter {
let instance_buffer = match hub.buffers.get(entry.instance_buffer_id).get() {
Ok(buffer) => buffer,
Err(_) => {
return Err(BuildAccelerationStructureError::InvalidBufferId);
}
};
let instance_buffer = hub.buffers.get(entry.instance_buffer_id).get()?;
let data = cmd_buf_data.trackers.buffers.set_single(
&instance_buffer,
BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
);
tlas_buf_storage.push(TlasBufferStore {
buffer: instance_buffer.clone(),
buffer: instance_buffer,
transition: data,
entry: entry.clone(),
});
Expand All @@ -222,14 +212,9 @@ impl Global {
let instance_buffer = {
let (instance_buffer, instance_pending) =
(&mut tlas_buf.buffer, &mut tlas_buf.transition);
let instance_raw = instance_buffer.raw.get(&snatch_guard).ok_or(
BuildAccelerationStructureError::InvalidBuffer(instance_buffer.error_ident()),
)?;
if !instance_buffer.usage.contains(BufferUsages::TLAS_INPUT) {
return Err(BuildAccelerationStructureError::MissingTlasInputUsageFlag(
instance_buffer.error_ident(),
));
}
let instance_raw = instance_buffer.try_raw(&snatch_guard)?;
instance_buffer.check_usage(BufferUsages::TLAS_INPUT)?;

if let Some(barrier) = instance_pending
.take()
.map(|pending| pending.into_hal(instance_buffer, &snatch_guard))
Expand All @@ -239,11 +224,7 @@ impl Global {
instance_raw
};

let tlas = hub
.tlas_s
.get(entry.tlas_id)
.get()
.map_err(|_| BuildAccelerationStructureError::InvalidTlasId)?;
let tlas = hub.tlas_s.get(entry.tlas_id).get()?;
cmd_buf_data.trackers.tlas_s.set_single(tlas.clone());
if let Some(queue) = device.get_queue() {
queue.pending_writes.lock().insert_tlas(&tlas);
Expand All @@ -267,7 +248,7 @@ impl Global {
tlas,
entries: hal::AccelerationStructureEntries::Instances(
hal::AccelerationStructureInstances {
buffer: Some(instance_buffer.as_ref()),
buffer: Some(instance_buffer),
offset: 0,
count: entry.instance_count,
},
Expand Down Expand Up @@ -307,9 +288,7 @@ impl Global {
mode: hal::AccelerationStructureBuildMode::Build,
flags: tlas.flags,
source_acceleration_structure: None,
destination_acceleration_structure: tlas.raw(&snatch_guard).ok_or(
BuildAccelerationStructureError::InvalidTlas(tlas.error_ident()),
)?,
destination_acceleration_structure: tlas.try_raw(&snatch_guard)?,
scratch_buffer: scratch_buffer.raw(),
scratch_buffer_offset: *scratch_buffer_offset,
})
Expand Down Expand Up @@ -374,12 +353,7 @@ impl Global {

let device = &cmd_buf.device;

if !device
.features
.contains(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)
{
return Err(BuildAccelerationStructureError::MissingFeature);
}
device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;

let build_command_index = NonZeroU64::new(
device
Expand Down Expand Up @@ -512,17 +486,13 @@ impl Global {
let mut tlas_lock_store = Vec::<(Option<TlasPackage>, Arc<Tlas>)>::new();

for package in tlas_iter {
let tlas = hub
.tlas_s
.get(package.tlas_id)
.get()
.map_err(|_| BuildAccelerationStructureError::InvalidTlasId)?;
let tlas = hub.tlas_s.get(package.tlas_id).get()?;
if let Some(queue) = device.get_queue() {
queue.pending_writes.lock().insert_tlas(&tlas);
}
cmd_buf_data.trackers.tlas_s.set_single(tlas.clone());

tlas_lock_store.push((Some(package), tlas.clone()))
tlas_lock_store.push((Some(package), tlas))
}

let mut scratch_buffer_tlas_size = 0;
Expand All @@ -549,12 +519,7 @@ impl Global {
tlas.error_ident(),
));
}
let blas = hub
.blas_s
.get(instance.blas_id)
.get()
.map_err(|_| BuildAccelerationStructureError::InvalidBlasIdForInstance)?
.clone();
let blas = hub.blas_s.get(instance.blas_id).get()?;

cmd_buf_data.trackers.blas_s.set_single(blas.clone());

Expand All @@ -569,7 +534,7 @@ impl Global {
dependencies.push(blas.clone());

cmd_buf_data.blas_actions.push(BlasAction {
blas: blas.clone(),
blas,
kind: crate::ray_tracing::BlasActionKind::Use,
});
}
Expand Down Expand Up @@ -642,13 +607,7 @@ impl Global {
mode: hal::AccelerationStructureBuildMode::Build,
flags: tlas.flags,
source_acceleration_structure: None,
destination_acceleration_structure: tlas
.raw
.get(&snatch_guard)
.ok_or(BuildAccelerationStructureError::InvalidTlas(
tlas.error_ident(),
))?
.as_ref(),
destination_acceleration_structure: tlas.try_raw(&snatch_guard)?,
scratch_buffer: scratch_buffer.raw(),
scratch_buffer_offset: *scratch_buffer_offset,
})
Expand Down Expand Up @@ -828,9 +787,7 @@ impl CommandBufferMutable {
action.tlas.error_ident(),
));
}
if blas.raw.get(snatch_guard).is_none() {
return Err(ValidateTlasActionsError::InvalidBlas(blas.error_ident()));
}
blas.try_raw(snatch_guard)?;
}
}
}
Expand All @@ -850,11 +807,7 @@ fn iter_blas<'a>(
) -> Result<(), BuildAccelerationStructureError> {
let mut temp_buffer = Vec::new();
for entry in blas_iter {
let blas = hub
.blas_s
.get(entry.blas_id)
.get()
.map_err(|_| BuildAccelerationStructureError::InvalidBlasId)?;
let blas = hub.blas_s.get(entry.blas_id).get()?;
cmd_buf_data.trackers.blas_s.set_single(blas.clone());
if let Some(queue) = device.get_queue() {
queue.pending_writes.lock().insert_blas(&blas);
Expand Down Expand Up @@ -937,19 +890,13 @@ fn iter_blas<'a>(
blas.error_ident(),
));
}
let vertex_buffer = match hub.buffers.get(mesh.vertex_buffer).get() {
Ok(buffer) => buffer,
Err(_) => return Err(BuildAccelerationStructureError::InvalidBufferId),
};
let vertex_buffer = hub.buffers.get(mesh.vertex_buffer).get()?;
let vertex_pending = cmd_buf_data.trackers.buffers.set_single(
&vertex_buffer,
BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
);
let index_data = if let Some(index_id) = mesh.index_buffer {
let index_buffer = match hub.buffers.get(index_id).get() {
Ok(buffer) => buffer,
Err(_) => return Err(BuildAccelerationStructureError::InvalidBufferId),
};
let index_buffer = hub.buffers.get(index_id).get()?;
if mesh.index_buffer_offset.is_none()
|| mesh.size.index_count.is_none()
|| mesh.size.index_count.is_none()
Expand All @@ -962,15 +909,12 @@ fn iter_blas<'a>(
&index_buffer,
BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
);
Some((index_buffer.clone(), data))
Some((index_buffer, data))
} else {
None
};
let transform_data = if let Some(transform_id) = mesh.transform_buffer {
let transform_buffer = match hub.buffers.get(transform_id).get() {
Ok(buffer) => buffer,
Err(_) => return Err(BuildAccelerationStructureError::InvalidBufferId),
};
let transform_buffer = hub.buffers.get(transform_id).get()?;
if mesh.transform_buffer_offset.is_none() {
return Err(BuildAccelerationStructureError::MissingAssociatedData(
transform_buffer.error_ident(),
Expand All @@ -985,7 +929,7 @@ fn iter_blas<'a>(
None
};
temp_buffer.push(TriangleBufferStore {
vertex_buffer: vertex_buffer.clone(),
vertex_buffer,
vertex_transition: vertex_pending,
index_buffer_transition: index_data,
transform_buffer_transition: transform_data,
Expand All @@ -995,7 +939,7 @@ fn iter_blas<'a>(
}

if let Some(last) = temp_buffer.last_mut() {
last.ending_blas = Some(blas.clone());
last.ending_blas = Some(blas);
buf_storage.append(&mut temp_buffer);
}
}
Expand All @@ -1020,14 +964,9 @@ fn iter_buffers<'a, 'b>(
let mesh = &buf.geometry;
let vertex_buffer = {
let vertex_buffer = buf.vertex_buffer.as_ref();
let vertex_raw = vertex_buffer.raw.get(snatch_guard).ok_or(
BuildAccelerationStructureError::InvalidBuffer(vertex_buffer.error_ident()),
)?;
if !vertex_buffer.usage.contains(BufferUsages::BLAS_INPUT) {
return Err(BuildAccelerationStructureError::MissingBlasInputUsageFlag(
vertex_buffer.error_ident(),
));
}
let vertex_raw = vertex_buffer.try_raw(snatch_guard)?;
vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

if let Some(barrier) = buf
.vertex_transition
.take()
Expand All @@ -1047,10 +986,7 @@ fn iter_buffers<'a, 'b>(
let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
cmd_buf_data.buffer_memory_init_actions.extend(
vertex_buffer.initialization_status.read().create_action(
&hub.buffers
.get(mesh.vertex_buffer)
.get()
.map_err(|_| BuildAccelerationStructureError::InvalidBufferId)?,
&hub.buffers.get(mesh.vertex_buffer).get()?,
vertex_buffer_offset
..(vertex_buffer_offset
+ mesh.size.vertex_count as u64 * mesh.vertex_stride),
Expand All @@ -1062,14 +998,9 @@ fn iter_buffers<'a, 'b>(
let index_buffer = if let Some((ref mut index_buffer, ref mut index_pending)) =
buf.index_buffer_transition
{
let index_raw = index_buffer.raw.get(snatch_guard).ok_or(
BuildAccelerationStructureError::InvalidBuffer(index_buffer.error_ident()),
)?;
if !index_buffer.usage.contains(BufferUsages::BLAS_INPUT) {
return Err(BuildAccelerationStructureError::MissingBlasInputUsageFlag(
index_buffer.error_ident(),
));
}
let index_raw = index_buffer.try_raw(snatch_guard)?;
index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

if let Some(barrier) = index_pending
.take()
.map(|pending| pending.into_hal(index_buffer, snatch_guard))
Expand Down Expand Up @@ -1125,14 +1056,9 @@ fn iter_buffers<'a, 'b>(
transform_buffer.error_ident(),
));
}
let transform_raw = transform_buffer.raw.get(snatch_guard).ok_or(
BuildAccelerationStructureError::InvalidBuffer(transform_buffer.error_ident()),
)?;
if !transform_buffer.usage.contains(BufferUsages::BLAS_INPUT) {
return Err(BuildAccelerationStructureError::MissingBlasInputUsageFlag(
transform_buffer.error_ident(),
));
}
let transform_raw = transform_buffer.try_raw(snatch_guard)?;
transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

if let Some(barrier) = transform_pending
.take()
.map(|pending| pending.into_hal(transform_buffer, snatch_guard))
Expand Down Expand Up @@ -1166,7 +1092,7 @@ fn iter_buffers<'a, 'b>(
};

let triangles = hal::AccelerationStructureTriangles {
vertex_buffer: Some(vertex_buffer.as_ref()),
vertex_buffer: Some(vertex_buffer),
vertex_format: mesh.size.vertex_format,
first_vertex: mesh.first_vertex,
vertex_count: mesh.size.vertex_count,
Expand All @@ -1175,13 +1101,13 @@ fn iter_buffers<'a, 'b>(
dyn hal::DynBuffer,
> {
format: mesh.size.index_format.unwrap(),
buffer: Some(index_buffer.as_ref()),
buffer: Some(index_buffer),
offset: mesh.index_buffer_offset.unwrap() as u32,
count: mesh.size.index_count.unwrap(),
}),
transform: transform_buffer.map(|transform_buffer| {
hal::AccelerationStructureTriangleTransform {
buffer: transform_buffer.as_ref(),
buffer: transform_buffer,
offset: mesh.transform_buffer_offset.unwrap() as u32,
}
}),
Expand Down Expand Up @@ -1231,13 +1157,7 @@ fn map_blas<'a>(
mode: hal::AccelerationStructureBuildMode::Build,
flags: blas.flags,
source_acceleration_structure: None,
destination_acceleration_structure: blas
.raw
.get(snatch_guard)
.ok_or(BuildAccelerationStructureError::InvalidBlas(
blas.error_ident(),
))?
.as_ref(),
destination_acceleration_structure: blas.try_raw(snatch_guard)?,
scratch_buffer,
scratch_buffer_offset: *scratch_buffer_offset,
})
Expand Down
18 changes: 4 additions & 14 deletions wgpu-core/src/device/ray_tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ impl Device {
blas_desc: &resource::BlasDescriptor,
sizes: wgt::BlasGeometrySizeDescriptors,
) -> Result<Arc<resource::Blas>, CreateBlasError> {
self.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;

let size_info = match &sizes {
wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => {
let mut entries =
Expand Down Expand Up @@ -109,6 +111,8 @@ impl Device {
self: &Arc<Self>,
desc: &resource::TlasDescriptor,
) -> Result<Arc<resource::Tlas>, CreateTlasError> {
self.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;

let size_info = unsafe {
self.raw().get_acceleration_structure_build_sizes(
&hal::GetAccelerationStructureBuildSizesDescriptor {
Expand Down Expand Up @@ -185,13 +189,6 @@ impl Global {
Err(err) => break 'error CreateBlasError::Device(err),
};

if !device
.features
.contains(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)
{
break 'error CreateBlasError::MissingFeature;
}

#[cfg(feature = "trace")]
if let Some(trace) = device.trace.lock().as_mut() {
trace.add(trace::Action::CreateBlas {
Expand Down Expand Up @@ -236,13 +233,6 @@ impl Global {
Err(e) => break 'error CreateTlasError::Device(e),
}

if !device
.features
.contains(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)
{
break 'error CreateTlasError::MissingFeature;
}

#[cfg(feature = "trace")]
if let Some(trace) = device.trace.lock().as_mut() {
trace.add(trace::Action::CreateTlas {
Expand Down
Loading

0 comments on commit 66518aa

Please sign in to comment.