Wait for submissions to complete on Queue drop #6413

Merged · 12 commits · Nov 14, 2024
34 changes: 2 additions & 32 deletions tests/tests/device.rs
@@ -630,7 +630,7 @@ static DEVICE_DROP_THEN_LOST: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(TestParameters::default().expect_fail(FailureCase::webgl2()))
.run_sync(|ctx| {
// This test checks that when the device is dropped (such as in a GC),
// the provided DeviceLostClosure is called with reason DeviceLostReason::Unknown.
// the provided DeviceLostClosure is called with reason DeviceLostReason::Dropped.
// Fails on webgl because webgl doesn't implement drop.
static WAS_CALLED: std::sync::atomic::AtomicBool = AtomicBool::new(false);

@@ -642,8 +642,7 @@ static DEVICE_DROP_THEN_LOST: GpuTestConfiguration = GpuTestConfiguration::new()
});
ctx.device.set_device_lost_callback(callback);

// Drop the device.
drop(ctx.device);
drop(ctx);

assert!(
WAS_CALLED.load(std::sync::atomic::Ordering::SeqCst),
@@ -676,35 +675,6 @@ static DEVICE_LOST_REPLACED_CALLBACK: GpuTestConfiguration
);
});

#[gpu_test]
static DROPPED_GLOBAL_THEN_DEVICE_LOST: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(TestParameters::default().skip(FailureCase::always()))
.run_sync(|ctx| {
// What we want to do is to drop the Global, forcing a code path that
// eventually calls Device.prepare_to_die, without having first dropped
// the device. This models what might happen in a user agent that kills
// wgpu without providing a more orderly shutdown. In such a case, the
// device lost callback should be invoked with the message "Device is
// dying."
static WAS_CALLED: AtomicBool = AtomicBool::new(false);

// Set a LoseDeviceCallback on the device.
let callback = Box::new(|reason, message| {
WAS_CALLED.store(true, std::sync::atomic::Ordering::SeqCst);
assert_eq!(reason, wgt::DeviceLostReason::Dropped);
assert_eq!(message, "Device is dying.");
});
ctx.device.set_device_lost_callback(callback);

// TODO: Drop the Global, somehow.

// Confirm that the callback was invoked.
assert!(
WAS_CALLED.load(std::sync::atomic::Ordering::SeqCst),
"Device lost callback should have been called."
);
});

#[gpu_test]
static DIFFERENT_BGL_ORDER_BW_SHADER_AND_API: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(TestParameters::default())
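For reviewers following along, here is a minimal standalone sketch of the pattern the updated test now exercises, written against the public `wgpu` API. The helper name is invented, and the need to drop the queue handle as well is an inference from the switch from `drop(ctx.device)` to `drop(ctx)`, not something the diff states explicitly:

```rust
use std::sync::atomic::{AtomicBool, Ordering};

static WAS_CALLED: AtomicBool = AtomicBool::new(false);

/// Hypothetical helper: register a device-lost callback, drop the handles,
/// and check that the callback reported `Dropped` (not `Unknown`).
fn check_device_lost_on_drop(device: wgpu::Device, queue: wgpu::Queue) {
    device.set_device_lost_callback(|reason, _message| {
        WAS_CALLED.store(true, Ordering::SeqCst);
        assert_eq!(reason, wgpu::DeviceLostReason::Dropped);
    });

    // With this PR the queue keeps the device alive, so both handles have to
    // go away before the callback can fire; hence the test drops the whole `ctx`.
    drop(queue);
    drop(device);

    assert!(WAS_CALLED.load(Ordering::SeqCst));
}
```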
53 changes: 31 additions & 22 deletions wgpu-core/src/command/ray_tracing.rs
@@ -1,5 +1,5 @@
use crate::{
device::queue::TempResource,
device::{queue::TempResource, Device},
global::Global,
hub::Hub,
id::CommandEncoderId,
@@ -20,10 +20,7 @@ use crate::{
use wgt::{math::align_to, BufferUsages, Features};

use super::CommandBufferMutable;
use crate::device::queue::PendingWrites;
use hal::BufferUses;
use std::mem::ManuallyDrop;
use std::ops::DerefMut;
use std::{
cmp::max,
num::NonZeroU64,
@@ -184,7 +181,7 @@ impl Global {
build_command_index,
&mut buf_storage,
hub,
device.pending_writes.lock().deref_mut(),
device,
)?;

let snatch_guard = device.snatchable_lock.read();
@@ -248,7 +245,9 @@ impl Global {
.get()
.map_err(|_| BuildAccelerationStructureError::InvalidTlasId)?;
cmd_buf_data.trackers.tlas_s.set_single(tlas.clone());
device.pending_writes.lock().insert_tlas(&tlas);
if let Some(queue) = device.get_queue() {
queue.pending_writes.lock().insert_tlas(&tlas);
}

cmd_buf_data.tlas_actions.push(TlasAction {
tlas: tlas.clone(),
@@ -349,10 +348,12 @@
}
}

device
.pending_writes
.lock()
.consume_temp(TempResource::ScratchBuffer(scratch_buffer));
if let Some(queue) = device.get_queue() {
queue
.pending_writes
.lock()
.consume_temp(TempResource::ScratchBuffer(scratch_buffer));
}

Ok(())
}
@@ -495,7 +496,7 @@ impl Global {
build_command_index,
&mut buf_storage,
hub,
device.pending_writes.lock().deref_mut(),
device,
)?;

let snatch_guard = device.snatchable_lock.read();
@@ -516,7 +517,9 @@ impl Global {
.get(package.tlas_id)
.get()
.map_err(|_| BuildAccelerationStructureError::InvalidTlasId)?;
device.pending_writes.lock().insert_tlas(&tlas);
if let Some(queue) = device.get_queue() {
queue.pending_writes.lock().insert_tlas(&tlas);
}
cmd_buf_data.trackers.tlas_s.set_single(tlas.clone());

tlas_lock_store.push((Some(package), tlas.clone()))
@@ -742,17 +745,21 @@ impl Global {
}

if let Some(staging_buffer) = staging_buffer {
device
.pending_writes
.lock()
.consume_temp(TempResource::StagingBuffer(staging_buffer));
if let Some(queue) = device.get_queue() {
queue
.pending_writes
.lock()
.consume_temp(TempResource::StagingBuffer(staging_buffer));
}
}
}

device
.pending_writes
.lock()
.consume_temp(TempResource::ScratchBuffer(scratch_buffer));
if let Some(queue) = device.get_queue() {
queue
.pending_writes
.lock()
.consume_temp(TempResource::ScratchBuffer(scratch_buffer));
}

Ok(())
}
@@ -839,7 +846,7 @@ fn iter_blas<'a>(
build_command_index: NonZeroU64,
buf_storage: &mut Vec<TriangleBufferStore<'a>>,
hub: &Hub,
pending_writes: &mut ManuallyDrop<PendingWrites>,
device: &Device,
) -> Result<(), BuildAccelerationStructureError> {
let mut temp_buffer = Vec::new();
for entry in blas_iter {
@@ -849,7 +856,9 @@
.get()
.map_err(|_| BuildAccelerationStructureError::InvalidBlasId)?;
cmd_buf_data.trackers.blas_s.set_single(blas.clone());
pending_writes.insert_blas(&blas);
if let Some(queue) = device.get_queue() {
queue.pending_writes.lock().insert_blas(&blas);
}

cmd_buf_data.blas_actions.push(BlasAction {
blas: blas.clone(),
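The recurring `if let Some(queue) = device.get_queue()` guard in this file reflects the PR's core change: pending writes now belong to the `Queue`, which may already be gone by the time these paths run. A rough sketch of how such an accessor can be backed (the field layout below is an assumption for illustration, not the actual wgpu-core definition):

```rust
use std::sync::{Arc, OnceLock, Weak};

// Placeholder type standing in for wgpu-core's queue.
struct Queue {
    // In the real code, `pending_writes` lives here now instead of on Device.
}

struct Device {
    // A weak back-reference, so that dropping the last strong Queue handle
    // actually lets the queue (and its pending submissions) wind down.
    queue: OnceLock<Weak<Queue>>,
}

impl Device {
    /// Returns `Some` only while the device's queue is still alive, which is
    /// why every `pending_writes` access in this diff is wrapped in
    /// `if let Some(queue) = device.get_queue()`.
    fn get_queue(&self) -> Option<Arc<Queue>> {
        self.queue.get()?.upgrade()
    }
}
```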
37 changes: 13 additions & 24 deletions wgpu-core/src/device/global.rs
@@ -219,9 +219,11 @@ impl Global {
device.check_is_valid()?;
buffer.check_usage(wgt::BufferUsages::MAP_WRITE)?;

let last_submission = device
.lock_life()
.get_buffer_latest_submission_index(&buffer);
let last_submission = device.get_queue().and_then(|queue| {
queue
.lock_life()
.get_buffer_latest_submission_index(&buffer)
});

if let Some(last_submission) = last_submission {
device.wait_for_submit(last_submission)?;
@@ -2078,20 +2080,7 @@ impl Global {
profiling::scope!("Device::drop");
api_log!("Device::drop {device_id:?}");

let device = self.hub.devices.remove(device_id);
let device_lost_closure = device.lock_life().device_lost_closure.take();
if let Some(closure) = device_lost_closure {
closure.call(DeviceLostReason::Dropped, String::from("Device dropped."));
}

// The things `Device::prepare_to_die` takes care are mostly
// unnecessary here. We know our queue is empty, so we don't
// need to wait for submissions or triage them. We know we were
// just polled, so `life_tracker.free_resources` is empty.
debug_assert!(device.lock_life().queue_empty());
device.pending_writes.lock().deactivate();

drop(device);
self.hub.devices.remove(device_id);
}

// This closure will be called exactly once during "lose the device",
@@ -2103,14 +2092,14 @@
) {
let device = self.hub.devices.get(device_id);

let mut life_tracker = device.lock_life();
if let Some(existing_closure) = life_tracker.device_lost_closure.take() {
// It's important to not hold the lock while calling the closure.
drop(life_tracker);
existing_closure.call(DeviceLostReason::ReplacedCallback, "".to_string());
life_tracker = device.lock_life();
let old_device_lost_closure = device
.device_lost_closure
.lock()
.replace(device_lost_closure);

if let Some(old_device_lost_closure) = old_device_lost_closure {
old_device_lost_closure.call(DeviceLostReason::ReplacedCallback, "".to_string());
}
life_tracker.device_lost_closure = Some(device_lost_closure);
}

pub fn device_destroy(&self, device_id: DeviceId) {
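The rewritten `device_set_device_lost_closure` keeps the closure in a `Mutex<Option<_>>` on the device and still invokes the displaced closure without the lock held, because the guard is only a temporary inside the `replace(...)` statement. A simplified, self-contained illustration of that locking pattern (all names here are invented for the sketch):

```rust
use std::sync::Mutex;

type LostClosure = Box<dyn FnOnce(&str) + Send>;

struct DeviceLostSlot {
    closure: Mutex<Option<LostClosure>>,
}

impl DeviceLostSlot {
    fn set(&self, new_closure: LostClosure) {
        // The MutexGuard is a temporary that ends with this statement, so the
        // displaced closure below runs with the lock already released, the
        // same property the old code got by explicitly dropping `life_tracker`.
        let old = self.closure.lock().unwrap().replace(new_closure);

        if let Some(old) = old {
            old("replaced callback");
        }
    }
}
```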
70 changes: 6 additions & 64 deletions wgpu-core/src/device/life.rs
@@ -1,9 +1,9 @@
use crate::{
device::{
queue::{EncoderInFlight, SubmittedWorkDoneClosure, TempResource},
DeviceError, DeviceLostClosure,
DeviceError,
},
resource::{self, Buffer, Texture, Trackable},
resource::{Buffer, Texture, Trackable},
snatch::SnatchGuard,
SubmissionIndex,
};
@@ -196,11 +196,6 @@ pub(crate) struct LifetimeTracker {
/// must happen _after_ all mapped buffer callbacks are mapped, so we defer them
/// here until the next time the device is maintained.
work_done_closures: SmallVec<[SubmittedWorkDoneClosure; 1]>,

/// Closure to be called on "lose the device". This is invoked directly by
/// device.lose or by the UserCallbacks returned from maintain when the device
/// has been destroyed and its queues are empty.
pub device_lost_closure: Option<DeviceLostClosure>,
}

impl LifetimeTracker {
@@ -209,7 +204,6 @@ impl LifetimeTracker {
active: Vec::new(),
ready_to_map: Vec::new(),
work_done_closures: SmallVec::new(),
device_lost_closure: None,
}
}

@@ -394,7 +388,6 @@
#[must_use]
pub(crate) fn handle_mapping(
&mut self,
raw: &dyn hal::DynDevice,
snatch_guard: &SnatchGuard,
) -> Vec<super::BufferMapPendingClosure> {
if self.ready_to_map.is_empty() {
@@ -404,61 +397,10 @@
Vec::with_capacity(self.ready_to_map.len());

for buffer in self.ready_to_map.drain(..) {
// This _cannot_ be inlined into the match. If it is, the lock will be held
// open through the whole match, resulting in a deadlock when we try to re-lock
// the buffer back to active.
let mapping = std::mem::replace(
&mut *buffer.map_state.lock(),
resource::BufferMapState::Idle,
);
let pending_mapping = match mapping {
resource::BufferMapState::Waiting(pending_mapping) => pending_mapping,
// Mapping cancelled
resource::BufferMapState::Idle => continue,
// Mapping queued at least twice by map -> unmap -> map
// and was already successfully mapped below
resource::BufferMapState::Active { .. } => {
*buffer.map_state.lock() = mapping;
continue;
}
_ => panic!("No pending mapping."),
};
let status = if pending_mapping.range.start != pending_mapping.range.end {
let host = pending_mapping.op.host;
let size = pending_mapping.range.end - pending_mapping.range.start;
match super::map_buffer(
raw,
&buffer,
pending_mapping.range.start,
size,
host,
snatch_guard,
) {
Ok(mapping) => {
*buffer.map_state.lock() = resource::BufferMapState::Active {
mapping,
range: pending_mapping.range.clone(),
host,
};
Ok(())
}
Err(e) => {
log::error!("Mapping failed: {e}");
Err(e)
}
}
} else {
*buffer.map_state.lock() = resource::BufferMapState::Active {
mapping: hal::BufferMapping {
ptr: std::ptr::NonNull::dangling(),
is_coherent: true,
},
range: pending_mapping.range,
host: pending_mapping.op.host,
};
Ok(())
};
pending_callbacks.push((pending_mapping.op, status));
match buffer.map(snatch_guard) {
Some(cb) => pending_callbacks.push(cb),
None => continue,
}
}
pending_callbacks
}
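The body that used to live in `handle_mapping` has moved into `Buffer::map`; the important property, noted in the deleted comment, is that the map state is taken out of the lock before matching on it, so the lock is never held across the point where it has to be re-acquired. A stripped-down illustration of that take-then-match pattern (the enum and payload are placeholders, not the real wgpu-core types):

```rust
use std::sync::Mutex;

// Placeholder state machine; the real one carries pending-mapping data.
enum MapState {
    Idle,
    Waiting(String),
    Active,
}

struct Buffer {
    map_state: Mutex<MapState>,
}

impl Buffer {
    /// Take the state out while briefly holding the lock, then decide what to
    /// do with the owned value; re-locking to write the new state back cannot
    /// deadlock because the first guard is already gone.
    fn map(&self) -> Option<String> {
        let state = std::mem::replace(&mut *self.map_state.lock().unwrap(), MapState::Idle);
        match state {
            MapState::Waiting(pending) => {
                *self.map_state.lock().unwrap() = MapState::Active;
                Some(pending)
            }
            // Mapping was cancelled; nothing to report.
            MapState::Idle => None,
            // Already mapped earlier; restore the state untouched.
            MapState::Active => {
                *self.map_state.lock().unwrap() = MapState::Active;
                None
            }
        }
    }
}
```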
11 changes: 6 additions & 5 deletions wgpu-core/src/device/mod.rs
@@ -298,24 +298,25 @@ impl DeviceLostClosure {
}
}

fn map_buffer(
raw: &dyn hal::DynDevice,
pub(crate) fn map_buffer(
buffer: &Buffer,
offset: BufferAddress,
size: BufferAddress,
kind: HostMap,
snatch_guard: &SnatchGuard,
) -> Result<hal::BufferMapping, BufferAccessError> {
let raw_device = buffer.device.raw();
let raw_buffer = buffer.try_raw(snatch_guard)?;
let mapping = unsafe {
raw.map_buffer(raw_buffer, offset..offset + size)
raw_device
.map_buffer(raw_buffer, offset..offset + size)
.map_err(|e| buffer.device.handle_hal_error(e))?
};

if !mapping.is_coherent && kind == HostMap::Read {
#[allow(clippy::single_range_in_vec_init)]
unsafe {
raw.invalidate_mapped_ranges(raw_buffer, &[offset..offset + size]);
raw_device.invalidate_mapped_ranges(raw_buffer, &[offset..offset + size]);
}
}

@@ -370,7 +371,7 @@ fn map_buffer(
&& kind == HostMap::Read
&& buffer.usage.contains(wgt::BufferUsages::MAP_WRITE)
{
unsafe { raw.flush_mapped_ranges(raw_buffer, &[uninitialized]) };
unsafe { raw_device.flush_mapped_ranges(raw_buffer, &[uninitialized]) };
}
}
}