Skip to content

Commit

Permalink
Vulkan Sampler Cache
Browse files Browse the repository at this point in the history
  • Loading branch information
cwfitzgerald committed Jan 6, 2025
1 parent d291571 commit 8209198
Show file tree
Hide file tree
Showing 7 changed files with 221 additions and 12 deletions.
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ noise = { version = "0.8", git = "https://github.com/Razaekel/noise-rs.git", rev
nv-flip = "0.1"
obj = "0.10"
once_cell = "1.20.2"
# Firefox has 3.4.0 vendored, so we allow that version in our dependencies
ordered-float = ">=3,<=4.6"
parking_lot = "0.12.1"
pico-args = { version = "0.5.0", features = [
"eq-separator",
Expand Down
2 changes: 2 additions & 0 deletions wgpu-hal/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ vulkan = [
"dep:libloading",
"dep:smallvec",
"dep:android_system_properties",
"dep:ordered-float",
]
gles = [
"naga/glsl-out",
Expand Down Expand Up @@ -125,6 +126,7 @@ profiling = { workspace = true, default-features = false }
raw-window-handle.workspace = true
thiserror.workspace = true
once_cell.workspace = true
ordered-float = { workspace = true, optional = true }

# backends common
arrayvec.workspace = true
Expand Down
7 changes: 7 additions & 0 deletions wgpu-hal/src/vulkan/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1559,6 +1559,10 @@ impl super::Instance {
.is_some_and(|ext| ext.shader_zero_initialize_workgroup_memory == vk::TRUE),
image_format_list: phd_capabilities.device_api_version >= vk::API_VERSION_1_2
|| phd_capabilities.supports_extension(khr::image_format_list::NAME),
maximum_samplers: phd_capabilities
.properties
.limits
.max_sampler_allocation_count,
};
let capabilities = crate::Capabilities {
limits: phd_capabilities.to_wgpu_limits(),
Expand Down Expand Up @@ -1907,6 +1911,9 @@ impl super::Adapter {
workarounds: self.workarounds,
render_passes: Mutex::new(Default::default()),
framebuffers: Mutex::new(Default::default()),
sampler_cache: Mutex::new(super::sampler::SamplerCache::new(
self.private_caps.maximum_samplers,
)),
memory_allocations_counter: Default::default(),
});

Expand Down
30 changes: 18 additions & 12 deletions wgpu-hal/src/vulkan/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1304,7 +1304,7 @@ impl crate::Device for super::Device {
&self,
desc: &crate::SamplerDescriptor,
) -> Result<super::Sampler, crate::DeviceError> {
let mut vk_info = vk::SamplerCreateInfo::default()
let mut create_info = vk::SamplerCreateInfo::default()
.flags(vk::SamplerCreateFlags::empty())
.mag_filter(conv::map_filter_mode(desc.mag_filter))
.min_filter(conv::map_filter_mode(desc.min_filter))
Expand All @@ -1316,40 +1316,46 @@ impl crate::Device for super::Device {
.max_lod(desc.lod_clamp.end);

if let Some(fun) = desc.compare {
vk_info = vk_info
create_info = create_info
.compare_enable(true)
.compare_op(conv::map_comparison(fun));
}

if desc.anisotropy_clamp != 1 {
// We only enable anisotropy if it is supported, and wgpu-hal interface guarantees
// the clamp is in the range [1, 16] which is always supported if anisotropy is.
vk_info = vk_info
create_info = create_info
.anisotropy_enable(true)
.max_anisotropy(desc.anisotropy_clamp as f32);
}

if let Some(color) = desc.border_color {
vk_info = vk_info.border_color(conv::map_border_color(color));
create_info = create_info.border_color(conv::map_border_color(color));
}

let raw = unsafe {
self.shared
.raw
.create_sampler(&vk_info, None)
.map_err(super::map_host_device_oom_and_ioca_err)?
};
let raw = self
.shared
.sampler_cache
.lock()
.create_sampler(&self.shared.raw, create_info)?;

// Note: Cached samplers will just continually overwrite the label
//
// https://github.com/gfx-rs/wgpu/issues/6867
if let Some(label) = desc.label {
unsafe { self.shared.set_object_name(raw, label) };
}

self.counters.samplers.add(1);

Ok(super::Sampler { raw })
Ok(super::Sampler { raw, create_info })
}
unsafe fn destroy_sampler(&self, sampler: super::Sampler) {
unsafe { self.shared.raw.destroy_sampler(sampler.raw, None) };
self.shared.sampler_cache.lock().destroy_sampler(
&self.shared.raw,
sampler.create_info,
sampler.raw,
);

self.counters.samplers.sub(1);
}
Expand Down
4 changes: 4 additions & 0 deletions wgpu-hal/src/vulkan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ mod command;
mod conv;
mod device;
mod instance;
mod sampler;

use std::{
borrow::Borrow,
Expand Down Expand Up @@ -532,6 +533,7 @@ struct PrivateCapabilities {
robust_image_access2: bool,
zero_initialize_workgroup_memory: bool,
image_format_list: bool,
maximum_samplers: u32,
}

bitflags::bitflags!(
Expand Down Expand Up @@ -641,6 +643,7 @@ struct DeviceShared {
features: wgt::Features,
render_passes: Mutex<rustc_hash::FxHashMap<RenderPassKey, vk::RenderPass>>,
framebuffers: Mutex<rustc_hash::FxHashMap<FramebufferKey, vk::Framebuffer>>,
sampler_cache: Mutex<sampler::SamplerCache>,
memory_allocations_counter: InternalCounter,
}

Expand Down Expand Up @@ -828,6 +831,7 @@ impl TextureView {
#[derive(Debug)]
pub struct Sampler {
/// Raw Vulkan sampler handle returned by the device's sampler cache.
raw: vk::Sampler,
/// Create info this sampler was requested with. It is handed back to the sampler
/// cache in `destroy_sampler` so the matching cache entry can be looked up.
create_info: vk::SamplerCreateInfo<'static>,
}

// Marker impl: the body is empty, no methods are provided here.
impl crate::DynSampler for Sampler {}
Expand Down
178 changes: 178 additions & 0 deletions wgpu-hal/src/vulkan/sampler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
//! Sampler cache for Vulkan backend.
//!
//! Nearly identical to the DX12 sampler cache, without descriptor heap management.
use std::collections::{hash_map::Entry, HashMap};

use ash::vk;
use ordered_float::OrderedFloat;

/// If the allowed sampler count is at or above this value, the sampler cache is disabled.
///
/// `SamplerCache::new` compares with `>=`, so a device reporting exactly this many
/// samplers (2^20) already runs in passthrough mode.
const ENABLE_SAMPLER_CACHE_CUTOFF: u32 = 1 << 20;

/// [`vk::SamplerCreateInfo`] is not hashable, so we wrap it in a newtype that is.
///
/// Equality and hashing are implemented by hand on this wrapper so that it can be
/// used as the key of the sampler cache's hash map.
///
/// We use [`OrderedFloat`] to allow for floating point values to be compared and
/// hashed in a defined way.
#[derive(Copy, Clone)]
struct HashableSamplerCreateInfo(vk::SamplerCreateInfo<'static>);

impl PartialEq for HashableSamplerCreateInfo {
    /// Field-by-field comparison of the two wrapped create infos.
    ///
    /// Only the fields named below take part in the comparison. Floating point
    /// fields are compared through [`OrderedFloat`]; every other field uses its
    /// own `==`.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.0, &other.0);

        // Flags and texel filtering.
        let filtering_matches = lhs.flags == rhs.flags
            && lhs.mag_filter == rhs.mag_filter
            && lhs.min_filter == rhs.min_filter
            && lhs.mipmap_mode == rhs.mipmap_mode;

        // Coordinate addressing and what happens outside the [0, 1] range.
        let addressing_matches = lhs.address_mode_u == rhs.address_mode_u
            && lhs.address_mode_v == rhs.address_mode_v
            && lhs.address_mode_w == rhs.address_mode_w
            && lhs.border_color == rhs.border_color
            && lhs.unnormalized_coordinates == rhs.unnormalized_coordinates;

        // Level-of-detail selection.
        let lod_matches = OrderedFloat(lhs.mip_lod_bias) == OrderedFloat(rhs.mip_lod_bias)
            && OrderedFloat(lhs.min_lod) == OrderedFloat(rhs.min_lod)
            && OrderedFloat(lhs.max_lod) == OrderedFloat(rhs.max_lod);

        // Anisotropic filtering.
        let anisotropy_matches = lhs.anisotropy_enable == rhs.anisotropy_enable
            && OrderedFloat(lhs.max_anisotropy) == OrderedFloat(rhs.max_anisotropy);

        // Depth comparison.
        let compare_matches =
            lhs.compare_enable == rhs.compare_enable && lhs.compare_op == rhs.compare_op;

        filtering_matches
            && addressing_matches
            && lod_matches
            && anisotropy_matches
            && compare_matches
    }
}

// Marker impl with no methods. Together with `Hash` it is what lets this wrapper be
// used as a `HashMap` key in the sampler cache.
impl Eq for HashableSamplerCreateInfo {}

impl std::hash::Hash for HashableSamplerCreateInfo {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.0.flags.hash(state);
self.0.mag_filter.hash(state);
self.0.min_filter.hash(state);
self.0.mipmap_mode.hash(state);
self.0.address_mode_u.hash(state);
self.0.address_mode_v.hash(state);
self.0.address_mode_w.hash(state);
OrderedFloat(self.0.mip_lod_bias).hash(state);
self.0.anisotropy_enable.hash(state);
OrderedFloat(self.0.max_anisotropy).hash(state);
self.0.compare_enable.hash(state);
self.0.compare_op.hash(state);
OrderedFloat(self.0.min_lod).hash(state);
OrderedFloat(self.0.max_lod).hash(state);
self.0.border_color.hash(state);
self.0.unnormalized_coordinates.hash(state);
}
}

/// Entry in the sampler cache.
struct CacheEntry {
/// Vulkan sampler handle returned to every caller that requests this description.
sampler: vk::Sampler,
/// Number of `create_sampler` calls that returned `sampler` and have not yet been
/// matched by a `destroy_sampler` call. Starts at 1; the sampler is destroyed when
/// it drops back to 0.
ref_count: u32,
}

/// Global sampler cache.
///
/// As some devices have a low limit (4000) on the number of unique samplers that can be created,
/// we need to cache samplers to avoid running out if people eagerly create duplicate samplers.
pub(crate) struct SamplerCache {
/// Mapping from the sampler description to sampler and reference count.
///
/// Never touched while `passthrough` is set.
samplers: HashMap<HashableSamplerCreateInfo, CacheEntry>,
/// Maximum number of unique samplers that can be created.
///
/// Creating a new unique sampler fails once `samplers` already holds this many entries.
total_capacity: u32,
/// If true, the sampler cache is disabled and all samplers are created on demand.
///
/// Set by `new` when `total_capacity` is at least [`ENABLE_SAMPLER_CACHE_CUTOFF`].
passthrough: bool,
}

impl SamplerCache {
pub fn new(total_capacity: u32) -> Self {
let passthrough = total_capacity >= ENABLE_SAMPLER_CACHE_CUTOFF;
Self {
samplers: HashMap::new(),
total_capacity,
passthrough,
}
}

/// Create a sampler, or return an existing one if it already exists.
///
/// If the sampler already exists, the reference count is incremented.
///
/// If the sampler does not exist, a new sampler is created and inserted into the cache.
///
/// If the cache is full, an error is returned.
pub fn create_sampler(
&mut self,
device: &ash::Device,
create_info: vk::SamplerCreateInfo<'static>,
) -> Result<vk::Sampler, crate::DeviceError> {
if self.passthrough {
return unsafe { device.create_sampler(&create_info, None) }
.map_err(super::map_host_device_oom_and_ioca_err);
};

// Get the number of used samplers. Needs to be done before to appease the borrow checker.
let used_samplers = self.samplers.len();

match self.samplers.entry(HashableSamplerCreateInfo(create_info)) {
Entry::Occupied(occupied_entry) => {
// We have found a match, so increment the refcount and return the index.
let value = occupied_entry.into_mut();
value.ref_count += 1;
Ok(value.sampler)
}
Entry::Vacant(vacant_entry) => {
// We need to create a new sampler.

// We need to check if we can create more samplers.
if used_samplers >= self.total_capacity as usize {
log::error!("There is no more room in the global sampler heap for more unique samplers. Your device supports a maximum of {} unique samplers.", self.samplers.len());
return Err(crate::DeviceError::OutOfMemory);
}

// Create the sampler.
let sampler = unsafe { device.create_sampler(&create_info, None) }
.map_err(super::map_host_device_oom_and_ioca_err)?;

// Insert the new sampler into the mapping.
vacant_entry.insert(CacheEntry {
sampler,
ref_count: 1,
});

Ok(sampler)
}
}
}

/// Decrease the reference count of a sampler and destroy it if the reference count reaches 0.
///
/// The provided sampler is checked against the sampler in the cache to ensure there is no clerical error.
pub fn destroy_sampler(
&mut self,
device: &ash::Device,
create_info: vk::SamplerCreateInfo<'static>,
provided_sampler: vk::Sampler,
) {
if self.passthrough {
unsafe { device.destroy_sampler(provided_sampler, None) };
return;
};

let Entry::Occupied(mut hash_map_entry) =
self.samplers.entry(HashableSamplerCreateInfo(create_info))
else {
log::error!("Trying to destroy a sampler that does not exist.");
return;
};
let cache_entry = hash_map_entry.get_mut();

assert_eq!(
cache_entry.sampler, provided_sampler,
"Provided sampler does not match the sampler in the cache."
);

cache_entry.ref_count -= 1;

if cache_entry.ref_count == 0 {
unsafe { device.destroy_sampler(cache_entry.sampler, None) };
hash_map_entry.remove();
}
}
}

0 comments on commit 8209198

Please sign in to comment.