Skip to content

Commit

Permalink
Fix ownership management of query sets on compute passes for write_ti…
Browse files Browse the repository at this point in the history
…mestamp, timestamp_writes (on desc) and pipeline statistic queries
  • Loading branch information
Wumpf committed May 29, 2024
1 parent 71e455d commit 60180ae
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 181 deletions.
182 changes: 115 additions & 67 deletions wgpu-core/src/command/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ use crate::{
compute_command::{ArcComputeCommand, ComputeCommand},
end_pipeline_statistics_query,
memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
BasePass, BindGroupStateChange, CommandBuffer, CommandEncoderError, CommandEncoderStatus,
MapPassErr, PassErrorScope, QueryUseError, StateChange,
validate_and_begin_pipeline_statistics_query, BasePass, BindGroupStateChange,
CommandBuffer, CommandEncoderError, CommandEncoderStatus, MapPassErr, PassErrorScope,
QueryUseError, StateChange,
},
device::{DeviceError, MissingDownlevelFlags, MissingFeatures},
error::{ErrorFormatter, PrettyError},
global::Global,
hal_api::HalApi,
hal_label,
id::{self},
hal_label, id,
init_tracker::MemoryInitKind,
resource::{self, Resource},
snatch::SnatchGuard,
Expand Down Expand Up @@ -48,7 +48,7 @@ pub struct ComputePass<A: HalApi> {
/// If it is none, this pass is invalid and any operation on it will return an error.
parent: Option<Arc<CommandBuffer<A>>>,

timestamp_writes: Option<ComputePassTimestampWrites>,
timestamp_writes: Option<ArcComputePassTimestampWrites<A>>,

// Resource binding dedupe state.
current_bind_groups: BindGroupStateChange,
Expand All @@ -57,11 +57,16 @@ pub struct ComputePass<A: HalApi> {

impl<A: HalApi> ComputePass<A> {
/// If the parent command buffer is invalid, the returned pass will be invalid.
fn new(parent: Option<Arc<CommandBuffer<A>>>, desc: &ComputePassDescriptor) -> Self {
fn new(parent: Option<Arc<CommandBuffer<A>>>, desc: ArcComputePassDescriptor<A>) -> Self {
let ArcComputePassDescriptor {
label,
timestamp_writes,
} = desc;

Self {
base: Some(BasePass::new(&desc.label)),
base: Some(BasePass::new(label)),
parent,
timestamp_writes: desc.timestamp_writes.cloned(),
timestamp_writes,

current_bind_groups: BindGroupStateChange::new(),
current_pipeline: StateChange::new(),
Expand Down Expand Up @@ -107,13 +112,29 @@ pub struct ComputePassTimestampWrites {
pub end_of_pass_write_index: Option<u32>,
}

/// Describes the writing of timestamp values in a compute pass with the query set resolved.
struct ArcComputePassTimestampWrites<A: HalApi> {
/// The query set to write the timestamps to.
pub query_set: Arc<resource::QuerySet<A>>,
/// The index of the query set at which a start timestamp of this pass is written, if any.
pub beginning_of_pass_write_index: Option<u32>,
/// The index of the query set at which an end timestamp of this pass is written, if any.
pub end_of_pass_write_index: Option<u32>,
}

#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a> {
pub label: Label<'a>,
/// Defines where and when timestamp values will be written for this pass.
pub timestamp_writes: Option<&'a ComputePassTimestampWrites>,
}

struct ArcComputePassDescriptor<'a, A: HalApi> {
pub label: &'a Label<'a>,
/// Defines where and when timestamp values will be written for this pass.
pub timestamp_writes: Option<ArcComputePassTimestampWrites<A>>,
}

#[derive(Clone, Debug, Error, Eq, PartialEq)]
#[non_exhaustive]
pub enum DispatchError {
Expand Down Expand Up @@ -310,13 +331,37 @@ impl Global {
pub fn command_encoder_create_compute_pass<A: HalApi>(
&self,
encoder_id: id::CommandEncoderId,
desc: &ComputePassDescriptor,
desc: &ComputePassDescriptor<'_>,
) -> (ComputePass<A>, Option<CommandEncoderError>) {
let hub = A::hub(self);

let mut arc_desc = ArcComputePassDescriptor {
label: &desc.label,
timestamp_writes: None, // Handle only once we resolved the encoder.
};

match CommandBuffer::lock_encoder(hub, encoder_id) {
Ok(cmd_buf) => (ComputePass::new(Some(cmd_buf), desc), None),
Err(err) => (ComputePass::new(None, desc), Some(err)),
Ok(cmd_buf) => {
arc_desc.timestamp_writes = if let Some(tw) = desc.timestamp_writes {
let Ok(query_set) = hub.query_sets.read().get_owned(tw.query_set) else {
return (
ComputePass::new(None, arc_desc),
Some(CommandEncoderError::InvalidTimestampWritesQuerySetId),
);
};

Some(ArcComputePassTimestampWrites {
query_set,
beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
end_of_pass_write_index: tw.end_of_pass_write_index,
})
} else {
None
};

(ComputePass::new(Some(cmd_buf), arc_desc), None)
}
Err(err) => (ComputePass::new(None, arc_desc), Some(err)),
}
}

Expand Down Expand Up @@ -349,7 +394,7 @@ impl Global {
.take()
.ok_or(ComputePassErrorInner::PassEnded)
.map_pass_err(scope)?;
self.compute_pass_end_impl(parent, base, pass.timestamp_writes.as_ref())
self.compute_pass_end_impl(parent, base, pass.timestamp_writes.take())
}

#[doc(hidden)]
Expand All @@ -360,11 +405,26 @@ impl Global {
timestamp_writes: Option<&ComputePassTimestampWrites>,
) -> Result<(), ComputePassError> {
let hub = A::hub(self);
let scope = PassErrorScope::PassEncoder(encoder_id);

let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id)
.map_pass_err(PassErrorScope::PassEncoder(encoder_id))?;
let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(scope)?;
let commands = ComputeCommand::resolve_compute_command_ids(A::hub(self), &base.commands)?;

let timestamp_writes = if let Some(tw) = timestamp_writes {
Some(ArcComputePassTimestampWrites {
query_set: hub
.query_sets
.read()
.get_owned(tw.query_set)
.map_err(|_| ComputePassErrorInner::InvalidQuerySet(tw.query_set))
.map_pass_err(scope)?,
beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
end_of_pass_write_index: tw.end_of_pass_write_index,
})
} else {
None
};

self.compute_pass_end_impl::<A>(
&cmd_buf,
BasePass {
Expand All @@ -382,13 +442,11 @@ impl Global {
&self,
cmd_buf: &CommandBuffer<A>,
base: BasePass<ArcComputeCommand<A>>,
timestamp_writes: Option<&ComputePassTimestampWrites>,
mut timestamp_writes: Option<ArcComputePassTimestampWrites<A>>,
) -> Result<(), ComputePassError> {
profiling::scope!("CommandEncoder::run_compute_pass");
let pass_scope = PassErrorScope::Pass(Some(cmd_buf.as_info().id()));

let hub = A::hub(self);

let device = &cmd_buf.device;
if !device.is_valid() {
return Err(ComputePassErrorInner::InvalidDevice(
Expand All @@ -410,7 +468,13 @@ impl Global {
string_data: base.string_data.to_vec(),
push_constant_data: base.push_constant_data.to_vec(),
},
timestamp_writes: timestamp_writes.cloned(),
timestamp_writes: timestamp_writes
.as_ref()
.map(|tw| ComputePassTimestampWrites {
query_set: tw.query_set.as_info().id(),
beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
end_of_pass_write_index: tw.end_of_pass_write_index,
}),
});
}

Expand All @@ -428,8 +492,6 @@ impl Global {
*status = CommandEncoderStatus::Error;
let raw = encoder.open().map_pass_err(pass_scope)?;

let query_set_guard = hub.query_sets.read();

let mut state = State {
binder: Binder::new(),
pipeline: None,
Expand All @@ -441,12 +503,19 @@ impl Global {
let mut string_offset = 0;
let mut active_query = None;

let timestamp_writes = if let Some(tw) = timestamp_writes {
let query_set: &resource::QuerySet<A> = tracker
.query_sets
.add_single(&*query_set_guard, tw.query_set)
.ok_or(ComputePassErrorInner::InvalidQuerySet(tw.query_set))
.map_pass_err(pass_scope)?;
let snatch_guard = device.snatchable_lock.read();

let indices = &device.tracker_indices;
tracker.buffers.set_size(indices.buffers.size());
tracker.textures.set_size(indices.textures.size());
tracker.bind_groups.set_size(indices.bind_groups.size());
tracker
.compute_pipelines
.set_size(indices.compute_pipelines.size());
tracker.query_sets.set_size(indices.query_sets.size());

let timestamp_writes = if let Some(tw) = timestamp_writes.take() {
let query_set = tracker.query_sets.insert_single(tw.query_set);

// Unlike in render passes we can't delay resetting the query sets since
// there is no auxiliary pass.
Expand Down Expand Up @@ -476,17 +545,6 @@ impl Global {
None
};

let snatch_guard = device.snatchable_lock.read();

let indices = &device.tracker_indices;
tracker.buffers.set_size(indices.buffers.size());
tracker.textures.set_size(indices.textures.size());
tracker.bind_groups.set_size(indices.bind_groups.size());
tracker
.compute_pipelines
.set_size(indices.compute_pipelines.size());
tracker.query_sets.set_size(indices.query_sets.size());

let discard_hal_labels = self
.instance
.flags
Expand Down Expand Up @@ -812,7 +870,6 @@ impl Global {
query_set,
query_index,
} => {
let query_set_id = query_set.as_info().id();
let scope = PassErrorScope::WriteTimestamp;

device
Expand All @@ -822,33 +879,29 @@ impl Global {
let query_set = tracker.query_sets.insert_single(query_set);

query_set
.validate_and_write_timestamp(raw, query_set_id, query_index, None)
.validate_and_write_timestamp(raw, query_index, None)
.map_pass_err(scope)?;
}
ArcComputeCommand::BeginPipelineStatisticsQuery {
query_set,
query_index,
} => {
let query_set_id = query_set.as_info().id();
let scope = PassErrorScope::BeginPipelineStatisticsQuery;

let query_set = tracker.query_sets.insert_single(query_set);

query_set
.validate_and_begin_pipeline_statistics_query(
raw,
query_set_id,
query_index,
None,
&mut active_query,
)
.map_pass_err(scope)?;
validate_and_begin_pipeline_statistics_query(
query_set.clone(),
raw,
query_index,
None,
&mut active_query,
)
.map_pass_err(scope)?;
}
ArcComputeCommand::EndPipelineStatisticsQuery => {
let scope = PassErrorScope::EndPipelineStatisticsQuery;

end_pipeline_statistics_query(raw, &*query_set_guard, &mut active_query)
.map_pass_err(scope)?;
end_pipeline_statistics_query(raw, &mut active_query).map_pass_err(scope)?;
}
}
}
Expand Down Expand Up @@ -919,10 +972,9 @@ impl Global {
let bind_group = hub
.bind_groups
.read()
.get(bind_group_id)
.get_owned(bind_group_id)
.map_err(|_| ComputePassErrorInner::InvalidBindGroup(index))
.map_pass_err(scope)?
.clone();
.map_pass_err(scope)?;

base.commands.push(ArcComputeCommand::SetBindGroup {
index,
Expand Down Expand Up @@ -952,10 +1004,9 @@ impl Global {
let pipeline = hub
.compute_pipelines
.read()
.get(pipeline_id)
.get_owned(pipeline_id)
.map_err(|_| ComputePassErrorInner::InvalidPipeline(pipeline_id))
.map_pass_err(scope)?
.clone();
.map_pass_err(scope)?;

base.commands.push(ArcComputeCommand::SetPipeline(pipeline));

Expand Down Expand Up @@ -1035,10 +1086,9 @@ impl Global {
let buffer = hub
.buffers
.read()
.get(buffer_id)
.get_owned(buffer_id)
.map_err(|_| ComputePassErrorInner::InvalidBuffer(buffer_id))
.map_pass_err(scope)?
.clone();
.map_pass_err(scope)?;

base.commands
.push(ArcComputeCommand::<A>::DispatchIndirect { buffer, offset });
Expand Down Expand Up @@ -1109,10 +1159,9 @@ impl Global {
let query_set = hub
.query_sets
.read()
.get(query_set_id)
.get_owned(query_set_id)
.map_err(|_| ComputePassErrorInner::InvalidQuerySet(query_set_id))
.map_pass_err(scope)?
.clone();
.map_pass_err(scope)?;

base.commands.push(ArcComputeCommand::WriteTimestamp {
query_set,
Expand All @@ -1135,10 +1184,9 @@ impl Global {
let query_set = hub
.query_sets
.read()
.get(query_set_id)
.get_owned(query_set_id)
.map_err(|_| ComputePassErrorInner::InvalidQuerySet(query_set_id))
.map_pass_err(scope)?
.clone();
.map_pass_err(scope)?;

base.commands
.push(ArcComputeCommand::BeginPipelineStatisticsQuery {
Expand Down
2 changes: 2 additions & 0 deletions wgpu-core/src/command/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,8 @@ pub enum CommandEncoderError {
Device(#[from] DeviceError),
#[error("Command encoder is locked by a previously created render/compute pass. Before recording any new commands, the pass must be ended.")]
Locked,
#[error("QuerySet provided for pass timestamp writes is invalid.")]
InvalidTimestampWritesQuerySetId,
}

impl Global {
Expand Down
Loading

0 comments on commit 60180ae

Please sign in to comment.