Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate errors when openning/closing a command encoder #4999

Merged
merged 1 commit into from
Jan 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions wgpu-core/src/command/clear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::device::trace::Command as TraceCommand;
use crate::{
api_log,
command::CommandBuffer,
device::DeviceError,
get_lowest_common_denom,
global::Global,
hal_api::HalApi,
Expand Down Expand Up @@ -66,6 +67,8 @@ whereas subesource range specified start {subresource_base_array_layer} and coun
subresource_base_array_layer: u32,
subresource_array_layer_count: Option<u32>,
},
#[error(transparent)]
Device(#[from] DeviceError),
}

impl<G: GlobalIdentityHandlerFactory> Global<G> {
Expand Down Expand Up @@ -149,7 +152,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {

// actual hal barrier & operation
let dst_barrier = dst_pending.map(|pending| pending.into_hal(&dst_buffer, &snatch_guard));
let cmd_buf_raw = cmd_buf_data.encoder.open();
let cmd_buf_raw = cmd_buf_data.encoder.open()?;
unsafe {
cmd_buf_raw.transition_buffers(dst_barrier.into_iter());
cmd_buf_raw.clear_buffer(dst_raw, offset..end);
Expand Down Expand Up @@ -228,7 +231,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
if !device.is_valid() {
return Err(ClearError::InvalidDevice(cmd_buf.device.as_info().id()));
}
let (encoder, tracker) = cmd_buf_data.open_encoder_and_tracker();
let (encoder, tracker) = cmd_buf_data.open_encoder_and_tracker()?;

clear_texture(
&dst_texture,
Expand Down
21 changes: 12 additions & 9 deletions wgpu-core/src/command/compute.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::device::DeviceError;
use crate::resource::Resource;
use crate::snatch::SnatchGuard;
use crate::{
Expand Down Expand Up @@ -186,6 +187,8 @@ pub enum DispatchError {
/// Error encountered when performing a compute pass.
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
#[error(transparent)]
Device(#[from] DeviceError),
#[error(transparent)]
Encoder(#[from] CommandEncoderError),
#[error("Bind group at index {0:?} is invalid")]
Expand Down Expand Up @@ -366,17 +369,17 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
timestamp_writes: Option<&ComputePassTimestampWrites>,
) -> Result<(), ComputePassError> {
profiling::scope!("CommandEncoder::run_compute_pass");
let init_scope = PassErrorScope::Pass(encoder_id);
let pass_scope = PassErrorScope::Pass(encoder_id);

let hub = A::hub(self);

let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(init_scope)?;
let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(pass_scope)?;
let device = &cmd_buf.device;
if !device.is_valid() {
return Err(ComputePassErrorInner::InvalidDevice(
cmd_buf.device.as_info().id(),
))
.map_pass_err(init_scope);
.map_pass_err(pass_scope);
}

let mut cmd_buf_data = cmd_buf.data.lock();
Expand All @@ -399,10 +402,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
// We automatically keep extending command buffers over time, and because
// we want to insert a command buffer _before_ what we're about to record,
// we need to make sure to close the previous one.
encoder.close();
encoder.close().map_pass_err(pass_scope)?;
// will be reset to true if recording is done without errors
*status = CommandEncoderStatus::Error;
let raw = encoder.open();
let raw = encoder.open().map_pass_err(pass_scope)?;

let bind_group_guard = hub.bind_groups.read();
let pipeline_guard = hub.compute_pipelines.read();
Expand All @@ -426,7 +429,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.query_sets
.add_single(&*query_set_guard, tw.query_set)
.ok_or(ComputePassErrorInner::InvalidQuerySet(tw.query_set))
.map_pass_err(init_scope)?;
.map_pass_err(pass_scope)?;

// Unlike in render passes we can't delay resetting the query sets since
// there is no auxillary pass.
Expand Down Expand Up @@ -862,12 +865,12 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
*status = CommandEncoderStatus::Recording;

// Stop the current command buffer.
encoder.close();
encoder.close().map_pass_err(pass_scope)?;

// Create a new command buffer, which we will insert _before_ the body of the compute pass.
//
// Use that buffer to insert barriers and clear discarded images.
let transit = encoder.open();
let transit = encoder.open().map_pass_err(pass_scope)?;
fixup_discarded_surfaces(
pending_discard_init_fixups.into_iter(),
transit,
Expand All @@ -881,7 +884,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
&snatch_guard,
);
// Close the command buffer, and swap it with the previous.
encoder.close_and_swap();
encoder.close_and_swap().map_pass_err(pass_scope)?;

Ok(())
}
Expand Down
59 changes: 37 additions & 22 deletions wgpu-core/src/command/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub use self::{

use self::memory_init::CommandBufferTextureMemoryActions;

use crate::device::Device;
use crate::device::{Device, DeviceError};
use crate::error::{ErrorFormatter, PrettyError};
use crate::hub::Hub;
use crate::id::CommandBufferId;
Expand Down Expand Up @@ -58,20 +58,24 @@ pub(crate) struct CommandEncoder<A: HalApi> {
//TODO: handle errors better
impl<A: HalApi> CommandEncoder<A> {
/// Closes the live encoder
fn close_and_swap(&mut self) {
fn close_and_swap(&mut self) -> Result<(), DeviceError> {
if self.is_open {
self.is_open = false;
let new = unsafe { self.raw.end_encoding().unwrap() };
let new = unsafe { self.raw.end_encoding()? };
self.list.insert(self.list.len() - 1, new);
}

Ok(())
}

fn close(&mut self) {
fn close(&mut self) -> Result<(), DeviceError> {
if self.is_open {
self.is_open = false;
let cmd_buf = unsafe { self.raw.end_encoding().unwrap() };
let cmd_buf = unsafe { self.raw.end_encoding()? };
self.list.push(cmd_buf);
}

Ok(())
}

fn discard(&mut self) {
Expand All @@ -81,18 +85,21 @@ impl<A: HalApi> CommandEncoder<A> {
}
}

fn open(&mut self) -> &mut A::CommandEncoder {
fn open(&mut self) -> Result<&mut A::CommandEncoder, DeviceError> {
if !self.is_open {
self.is_open = true;
let label = self.label.as_deref();
unsafe { self.raw.begin_encoding(label).unwrap() };
unsafe { self.raw.begin_encoding(label)? };
}
&mut self.raw

Ok(&mut self.raw)
}

fn open_pass(&mut self, label: Option<&str>) {
fn open_pass(&mut self, label: Option<&str>) -> Result<(), DeviceError> {
self.is_open = true;
unsafe { self.raw.begin_encoding(label).unwrap() };
unsafe { self.raw.begin_encoding(label)? };

Ok(())
}
}

Expand All @@ -119,10 +126,13 @@ pub struct CommandBufferMutable<A: HalApi> {
}

impl<A: HalApi> CommandBufferMutable<A> {
pub(crate) fn open_encoder_and_tracker(&mut self) -> (&mut A::CommandEncoder, &mut Tracker<A>) {
let encoder = self.encoder.open();
pub(crate) fn open_encoder_and_tracker(
&mut self,
) -> Result<(&mut A::CommandEncoder, &mut Tracker<A>), DeviceError> {
let encoder = self.encoder.open()?;
let tracker = &mut self.trackers;
(encoder, tracker)

Ok((encoder, tracker))
}
}

Expand Down Expand Up @@ -401,6 +411,8 @@ pub enum CommandEncoderError {
Invalid,
#[error("Command encoder must be active")]
NotRecording,
#[error(transparent)]
Device(#[from] DeviceError),
}

impl<G: GlobalIdentityHandlerFactory> Global<G> {
Expand All @@ -419,12 +431,15 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let cmd_buf_data = cmd_buf_data.as_mut().unwrap();
match cmd_buf_data.status {
CommandEncoderStatus::Recording => {
cmd_buf_data.encoder.close();
cmd_buf_data.status = CommandEncoderStatus::Finished;
//Note: if we want to stop tracking the swapchain texture view,
// this is the place to do it.
log::trace!("Command buffer {:?}", encoder_id);
None
if let Err(e) = cmd_buf_data.encoder.close() {
Some(e.into())
} else {
cmd_buf_data.status = CommandEncoderStatus::Finished;
//Note: if we want to stop tracking the swapchain texture view,
// this is the place to do it.
log::trace!("Command buffer {:?}", encoder_id);
None
}
}
CommandEncoderStatus::Finished => Some(CommandEncoderError::NotRecording),
CommandEncoderStatus::Error => {
Expand Down Expand Up @@ -457,7 +472,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
list.push(TraceCommand::PushDebugGroup(label.to_string()));
}

let cmd_buf_raw = cmd_buf_data.encoder.open();
let cmd_buf_raw = cmd_buf_data.encoder.open()?;
if !self
.instance
.flags
Expand Down Expand Up @@ -494,7 +509,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.flags
.contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
{
let cmd_buf_raw = cmd_buf_data.encoder.open();
let cmd_buf_raw = cmd_buf_data.encoder.open()?;
unsafe {
cmd_buf_raw.insert_debug_marker(label);
}
Expand All @@ -520,7 +535,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
list.push(TraceCommand::PopDebugGroup);
}

let cmd_buf_raw = cmd_buf_data.encoder.open();
let cmd_buf_raw = cmd_buf_data.encoder.open()?;
if !self
.instance
.flags
Expand Down
7 changes: 5 additions & 2 deletions wgpu-core/src/command/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use hal::CommandEncoder as _;
use crate::device::trace::Command as TraceCommand;
use crate::{
command::{CommandBuffer, CommandEncoderError},
device::DeviceError,
global::Global,
hal_api::HalApi,
id::{self, Id, TypedId},
Expand Down Expand Up @@ -104,6 +105,8 @@ impl From<wgt::QueryType> for SimplifiedQueryType {
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum QueryError {
#[error(transparent)]
Device(#[from] DeviceError),
#[error(transparent)]
Encoder(#[from] CommandEncoderError),
#[error("Error encountered while trying to use queries")]
Expand Down Expand Up @@ -367,7 +370,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let encoder = &mut cmd_buf_data.encoder;
let tracker = &mut cmd_buf_data.trackers;

let raw_encoder = encoder.open();
let raw_encoder = encoder.open()?;

let query_set_guard = hub.query_sets.read();
let query_set = tracker
Expand Down Expand Up @@ -409,7 +412,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let encoder = &mut cmd_buf_data.encoder;
let tracker = &mut cmd_buf_data.trackers;
let buffer_memory_init_actions = &mut cmd_buf_data.buffer_memory_init_actions;
let raw_encoder = encoder.open();
let raw_encoder = encoder.open()?;

if destination_offset % wgt::QUERY_RESOLVE_BUFFER_ALIGNMENT != 0 {
return Err(QueryError::Resolve(ResolveError::BufferOffsetAlignment));
Expand Down
20 changes: 10 additions & 10 deletions wgpu-core/src/command/render.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1312,11 +1312,11 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.contains(wgt::InstanceFlags::DISCARD_HAL_LABELS);
let label = hal_label(base.label, self.instance.flags);

let init_scope = PassErrorScope::Pass(encoder_id);
let pass_scope = PassErrorScope::Pass(encoder_id);

let hub = A::hub(self);

let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(init_scope)?;
let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(pass_scope)?;
let device = &cmd_buf.device;
let snatch_guard = device.snatchable_lock.read();

Expand All @@ -1336,7 +1336,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
}

if !device.is_valid() {
return Err(DeviceError::Lost).map_pass_err(init_scope);
return Err(DeviceError::Lost).map_pass_err(pass_scope);
}

let encoder = &mut cmd_buf_data.encoder;
Expand All @@ -1349,10 +1349,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
// We automatically keep extending command buffers over time, and because
// we want to insert a command buffer _before_ what we're about to record,
// we need to make sure to close the previous one.
encoder.close();
encoder.close().map_pass_err(pass_scope)?;
// We will reset this to `Recording` if we succeed, acts as a fail-safe.
*status = CommandEncoderStatus::Error;
encoder.open_pass(label);
encoder.open_pass(label).map_pass_err(pass_scope)?;

let bundle_guard = hub.render_bundles.read();
let bind_group_guard = hub.bind_groups.read();
Expand Down Expand Up @@ -1383,7 +1383,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
&*texture_guard,
&*query_set_guard,
)
.map_pass_err(init_scope)?;
.map_pass_err(pass_scope)?;

tracker.set_size(
Some(&*buffer_guard),
Expand Down Expand Up @@ -2364,9 +2364,9 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {

log::trace!("Merging renderpass into cmd_buf {:?}", encoder_id);
let (trackers, pending_discard_init_fixups) =
info.finish(raw).map_pass_err(init_scope)?;
info.finish(raw).map_pass_err(pass_scope)?;

encoder.close();
encoder.close().map_pass_err(pass_scope)?;
(trackers, pending_discard_init_fixups)
};

Expand All @@ -2381,7 +2381,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let tracker = &mut cmd_buf_data.trackers;

{
let transit = encoder.open();
let transit = encoder.open().map_pass_err(pass_scope)?;

fixup_discarded_surfaces(
pending_discard_init_fixups.into_iter(),
Expand Down Expand Up @@ -2409,7 +2409,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
}

*status = CommandEncoderStatus::Recording;
encoder.close_and_swap();
encoder.close_and_swap().map_pass_err(pass_scope)?;

Ok(())
}
Expand Down
Loading