Skip to content

Commit

Permalink
Check that raw buffers and raw bind groups are valid (#4895)
Browse files Browse the repository at this point in the history
  • Loading branch information
nical authored Dec 20, 2023
1 parent 0e35006 commit 0524c88
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 44 deletions.
10 changes: 8 additions & 2 deletions wgpu-core/src/binding_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::{
init_tracker::{BufferInitTrackerAction, TextureInitTrackerAction},
resource::{Resource, ResourceInfo, ResourceType},
resource_log,
snatch::SnatchGuard,
track::{BindGroupStates, UsageConflict},
validation::{MissingBufferUsageError, MissingTextureUsageError},
FastHashMap, Label,
Expand Down Expand Up @@ -835,8 +836,13 @@ impl<A: HalApi> Drop for BindGroup<A> {
}

impl<A: HalApi> BindGroup<A> {
pub(crate) fn raw(&self) -> &A::BindGroup {
self.raw.as_ref().unwrap()
pub(crate) fn raw(&self, guard: &SnatchGuard) -> Option<&A::BindGroup> {
for buffer in &self.used_buffer_ranges {
// Clippy insist on writing it this way. The idea is to return None
// if any of the raw buffer is not valid anymore.
let _ = buffer.buffer.raw(guard)?;
}
self.raw.as_ref()
}
pub(crate) fn validate_dynamic_bindings(
&self,
Expand Down
34 changes: 29 additions & 5 deletions wgpu-core/src/command/bundle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,8 @@ pub enum CreateRenderBundleError {
pub enum ExecutionError {
#[error("Buffer {0:?} is destroyed")]
DestroyedBuffer(id::BufferId),
#[error("BindGroup {0:?} is invalid")]
InvalidBindGroup(id::BindGroupId),
#[error("Using {0} in a render bundle is not implemented")]
Unimplemented(&'static str),
}
Expand All @@ -722,6 +724,9 @@ impl PrettyError for ExecutionError {
Self::DestroyedBuffer(id) => {
fmt.buffer_label(&id);
}
Self::InvalidBindGroup(id) => {
fmt.bind_group_label(&id);
}
Self::Unimplemented(_reason) => {}
};
}
Expand Down Expand Up @@ -796,11 +801,14 @@ impl<A: HalApi> RenderBundle<A> {
} => {
let bind_groups = trackers.bind_groups.read();
let bind_group = bind_groups.get(bind_group_id).unwrap();
let raw_bg = bind_group
.raw(&snatch_guard)
.ok_or(ExecutionError::InvalidBindGroup(bind_group_id))?;
unsafe {
raw.set_bind_group(
pipeline_layout.as_ref().unwrap().raw(),
index,
bind_group.raw(),
raw_bg,
&offsets[..num_dynamic_offsets as usize],
)
};
Expand All @@ -820,7 +828,11 @@ impl<A: HalApi> RenderBundle<A> {
size,
} => {
let buffers = trackers.buffers.read();
let buffer = buffers.get(buffer_id).unwrap().raw(&snatch_guard);
let buffer: &A::Buffer = buffers
.get(buffer_id)
.ok_or(ExecutionError::DestroyedBuffer(buffer_id))?
.raw(&snatch_guard)
.ok_or(ExecutionError::DestroyedBuffer(buffer_id))?;
let bb = hal::BufferBinding {
buffer,
offset,
Expand All @@ -835,7 +847,11 @@ impl<A: HalApi> RenderBundle<A> {
size,
} => {
let buffers = trackers.buffers.read();
let buffer = buffers.get(buffer_id).unwrap().raw(&snatch_guard);
let buffer = buffers
.get(buffer_id)
.ok_or(ExecutionError::DestroyedBuffer(buffer_id))?
.raw(&snatch_guard)
.ok_or(ExecutionError::DestroyedBuffer(buffer_id))?;
let bb = hal::BufferBinding {
buffer,
offset,
Expand Down Expand Up @@ -914,7 +930,11 @@ impl<A: HalApi> RenderBundle<A> {
indexed: false,
} => {
let buffers = trackers.buffers.read();
let buffer = buffers.get(buffer_id).unwrap().raw(&snatch_guard);
let buffer = buffers
.get(buffer_id)
.ok_or(ExecutionError::DestroyedBuffer(buffer_id))?
.raw(&snatch_guard)
.ok_or(ExecutionError::DestroyedBuffer(buffer_id))?;
unsafe { raw.draw_indirect(buffer, offset, 1) };
}
RenderCommand::MultiDrawIndirect {
Expand All @@ -924,7 +944,11 @@ impl<A: HalApi> RenderBundle<A> {
indexed: true,
} => {
let buffers = trackers.buffers.read();
let buffer = buffers.get(buffer_id).unwrap().raw(&snatch_guard);
let buffer = buffers
.get(buffer_id)
.ok_or(ExecutionError::DestroyedBuffer(buffer_id))?
.raw(&snatch_guard)
.ok_or(ExecutionError::DestroyedBuffer(buffer_id))?;
unsafe { raw.draw_indexed_indirect(buffer, offset, 1) };
}
RenderCommand::MultiDrawIndirect { .. }
Expand Down
19 changes: 11 additions & 8 deletions wgpu-core/src/command/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ pub enum DispatchError {
pub enum ComputePassErrorInner {
#[error(transparent)]
Encoder(#[from] CommandEncoderError),
#[error("Bind group {0:?} is invalid")]
InvalidBindGroup(id::BindGroupId),
#[error("Bind group at index {0:?} is invalid")]
InvalidBindGroup(usize),
#[error("Device {0:?} is invalid")]
InvalidDevice(DeviceId),
#[error("Bind group index {index} is greater than the device's requested `max_bind_group` limit {max}")]
Expand Down Expand Up @@ -232,9 +232,6 @@ impl PrettyError for ComputePassErrorInner {
fn fmt_pretty(&self, fmt: &mut ErrorFormatter) {
fmt.error(self);
match *self {
Self::InvalidBindGroup(id) => {
fmt.bind_group_label(&id);
}
Self::InvalidPipeline(id) => {
fmt.compute_pipeline_label(&id);
}
Expand Down Expand Up @@ -520,7 +517,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let bind_group = tracker
.bind_groups
.add_single(&*bind_group_guard, bind_group_id)
.ok_or(ComputePassErrorInner::InvalidBindGroup(bind_group_id))
.ok_or(ComputePassErrorInner::InvalidBindGroup(index as usize))
.map_pass_err(scope)?;
bind_group
.validate_dynamic_bindings(index, &temp_offsets, &cmd_buf.limits)
Expand Down Expand Up @@ -550,7 +547,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let pipeline_layout = pipeline_layout.as_ref().unwrap().raw();
for (i, e) in entries.iter().enumerate() {
if let Some(group) = e.group.as_ref() {
let raw_bg = group.raw();
let raw_bg = group
.raw(&snatch_guard)
.ok_or(ComputePassErrorInner::InvalidBindGroup(i))
.map_pass_err(scope)?;
unsafe {
raw.set_bind_group(
pipeline_layout,
Expand Down Expand Up @@ -594,7 +594,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
if !entries.is_empty() {
for (i, e) in entries.iter().enumerate() {
if let Some(group) = e.group.as_ref() {
let raw_bg = group.raw();
let raw_bg = group
.raw(&snatch_guard)
.ok_or(ComputePassErrorInner::InvalidBindGroup(i))
.map_pass_err(scope)?;
unsafe {
raw.set_bind_group(
pipeline.layout.raw(),
Expand Down
6 changes: 5 additions & 1 deletion wgpu-core/src/command/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -479,12 +479,16 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
MemoryInitKind::ImplicitlyInitialized,
));

let raw_dst_buffer = dst_buffer
.raw(&snatch_guard)
.ok_or(QueryError::InvalidBuffer(destination))?;

unsafe {
raw_encoder.transition_buffers(dst_barrier.into_iter());
raw_encoder.copy_query_results(
query_set.raw(),
start_query..end_query,
dst_buffer.raw(&snatch_guard),
raw_dst_buffer,
destination_offset,
wgt::BufferSize::new_unchecked(stride as u64),
);
Expand Down
15 changes: 13 additions & 2 deletions wgpu-core/src/command/render.rs
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,8 @@ pub enum RenderPassErrorInner {
SurfaceTextureDropped,
#[error("Not enough memory left for render pass")]
OutOfMemory,
#[error("The bind group at index {0:?} is invalid")]
InvalidBindGroup(usize),
#[error("Unable to clear non-present/read-only depth")]
InvalidDepthOps,
#[error("Unable to clear non-present/read-only stencil")]
Expand Down Expand Up @@ -1484,7 +1486,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let pipeline_layout = pipeline_layout.as_ref().unwrap().raw();
for (i, e) in entries.iter().enumerate() {
if let Some(group) = e.group.as_ref() {
let raw_bg = group.raw();
let raw_bg = group
.raw(&snatch_guard)
.ok_or(RenderPassErrorInner::InvalidBindGroup(i))
.map_pass_err(scope)?;
unsafe {
raw.set_bind_group(
pipeline_layout,
Expand Down Expand Up @@ -1562,7 +1567,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
if !entries.is_empty() {
for (i, e) in entries.iter().enumerate() {
if let Some(group) = e.group.as_ref() {
let raw_bg = group.raw();
let raw_bg = group
.raw(&snatch_guard)
.ok_or(RenderPassErrorInner::InvalidBindGroup(i))
.map_pass_err(scope)?;
unsafe {
raw.set_bind_group(
pipeline.layout.raw(),
Expand Down Expand Up @@ -2332,6 +2340,9 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
ExecutionError::DestroyedBuffer(id) => {
RenderCommandError::DestroyedBuffer(id)
}
ExecutionError::InvalidBindGroup(id) => {
RenderCommandError::InvalidBindGroup(id)
}
ExecutionError::Unimplemented(what) => {
RenderCommandError::Unimplemented(what)
}
Expand Down
15 changes: 8 additions & 7 deletions wgpu-core/src/device/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,8 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
};

let snatch_guard = device.snatchable_lock.read();
let mapping = match unsafe {
device
.raw()
.map_buffer(stage.raw(&snatch_guard), 0..stage.size)
} {
let stage_raw = stage.raw(&snatch_guard).unwrap();
let mapping = match unsafe { device.raw().map_buffer(stage_raw, 0..stage.size) } {
Ok(mapping) => mapping,
Err(e) => {
to_destroy.push(buffer);
Expand Down Expand Up @@ -401,7 +398,9 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
});
}

let raw_buf = buffer.raw(&snatch_guard);
let raw_buf = buffer
.raw(&snatch_guard)
.ok_or(BufferAccessError::Destroyed)?;
unsafe {
let mapping = device
.raw()
Expand Down Expand Up @@ -451,7 +450,9 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
check_buffer_usage(buffer.usage, wgt::BufferUsages::MAP_READ)?;
//assert!(buffer isn't used by the GPU);

let raw_buf = buffer.raw(&snatch_guard);
let raw_buf = buffer
.raw(&snatch_guard)
.ok_or(BufferAccessError::Destroyed)?;
unsafe {
let mapping = device
.raw()
Expand Down
14 changes: 6 additions & 8 deletions wgpu-core/src/device/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,17 +340,17 @@ fn map_buffer<A: HalApi>(
kind: HostMap,
) -> Result<ptr::NonNull<u8>, BufferAccessError> {
let snatch_guard = buffer.device.snatchable_lock.read();
let raw_buffer = buffer
.raw(&snatch_guard)
.ok_or(BufferAccessError::Destroyed)?;
let mapping = unsafe {
raw.map_buffer(buffer.raw(&snatch_guard), offset..offset + size)
raw.map_buffer(raw_buffer, offset..offset + size)
.map_err(DeviceError::from)?
};

*buffer.sync_mapped_writes.lock() = match kind {
HostMap::Read if !mapping.is_coherent => unsafe {
raw.invalidate_mapped_ranges(
buffer.raw(&snatch_guard),
iter::once(offset..offset + size),
);
raw.invalidate_mapped_ranges(raw_buffer, iter::once(offset..offset + size));
None
},
HostMap::Write if !mapping.is_coherent => Some(offset..offset + size),
Expand Down Expand Up @@ -390,9 +390,7 @@ fn map_buffer<A: HalApi>(
mapped[fill_range].fill(0);

if zero_init_needs_flush_now {
unsafe {
raw.flush_mapped_ranges(buffer.raw(&snatch_guard), iter::once(uninitialized))
};
unsafe { raw.flush_mapped_ranges(raw_buffer, iter::once(uninitialized)) };
}
}

Expand Down
20 changes: 9 additions & 11 deletions wgpu-core/src/resource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,8 @@ impl<A: HalApi> Drop for Buffer<A> {
}

impl<A: HalApi> Buffer<A> {
pub(crate) fn raw(&self, guard: &SnatchGuard) -> &A::Buffer {
self.raw.get(guard).unwrap()
pub(crate) fn raw(&self, guard: &SnatchGuard) -> Option<&A::Buffer> {
self.raw.get(guard)
}

pub(crate) fn buffer_unmap_inner(
Expand All @@ -434,6 +434,9 @@ impl<A: HalApi> Buffer<A> {

let device = &self.device;
let snatch_guard = device.snatchable_lock.read();
let raw_buf = self
.raw(&snatch_guard)
.ok_or(BufferAccessError::Destroyed)?;
let buffer_id = self.info.id();
log::debug!("Buffer {:?} map state -> Idle", buffer_id);
match mem::replace(&mut *self.map_state.lock(), resource::BufferMapState::Idle) {
Expand All @@ -458,17 +461,12 @@ impl<A: HalApi> Buffer<A> {
if needs_flush {
unsafe {
device.raw().flush_mapped_ranges(
stage_buffer.raw(&snatch_guard),
stage_buffer.raw(&snatch_guard).unwrap(),
iter::once(0..self.size),
);
}
}

let raw_buf = self
.raw
.get(&snatch_guard)
.ok_or(BufferAccessError::Destroyed)?;

self.info
.use_at(device.active_submission_index.load(Ordering::Relaxed) + 1);
let region = wgt::BufferSize::new(self.size).map(|size| hal::BufferCopy {
Expand All @@ -477,7 +475,7 @@ impl<A: HalApi> Buffer<A> {
size,
});
let transition_src = hal::BufferBarrier {
buffer: stage_buffer.raw(&snatch_guard),
buffer: stage_buffer.raw(&snatch_guard).unwrap(),
usage: hal::BufferUses::MAP_WRITE..hal::BufferUses::COPY_SRC,
};
let transition_dst = hal::BufferBarrier {
Expand All @@ -493,7 +491,7 @@ impl<A: HalApi> Buffer<A> {
);
if self.size > 0 {
encoder.copy_buffer_to_buffer(
stage_buffer.raw(&snatch_guard),
stage_buffer.raw(&snatch_guard).unwrap(),
raw_buf,
region.into_iter(),
);
Expand Down Expand Up @@ -528,7 +526,7 @@ impl<A: HalApi> Buffer<A> {
unsafe {
device
.raw()
.unmap_buffer(self.raw(&snatch_guard))
.unmap_buffer(raw_buf)
.map_err(DeviceError::from)?
};
}
Expand Down

0 comments on commit 0524c88

Please sign in to comment.