Skip to content

Commit

Permalink
Naga Generate Bindless Samplers
Browse files Browse the repository at this point in the history
  • Loading branch information
cwfitzgerald committed Dec 18, 2024
1 parent 0fe2034 commit 713fc6d
Show file tree
Hide file tree
Showing 11 changed files with 232 additions and 63 deletions.
61 changes: 61 additions & 0 deletions naga/src/back/hlsl/help.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1173,6 +1173,67 @@ impl<W: Write> super::Writer<'_, W> {
Ok(())
}

pub(super) fn write_sampler_arrays(&mut self) -> BackendResult {
if self.wrapped.sampler_array {
return Ok(());
}

writeln!(
self.out,
"SamplerState {}[2048]: register(s{}, space{});",
super::writer::SAMPLER_BUFFER_VAR,
self.options.sampler_array.sampler_array.register,
self.options.sampler_array.sampler_array.space
)?;
writeln!(
self.out,
"SamplerComparisonState {}[2048]: register(s{}, space{});",
super::writer::COMPARISON_SAMPLER_BUFFER_VAR,
self.options.sampler_array.comparison_sampler_array.register,
self.options.sampler_array.comparison_sampler_array.space
)?;

self.wrapped.sampler_array = true;

Ok(())
}

pub(super) fn write_wrapped_sampler_buffer(
&mut self,
key: super::SamplerBufferKey,
) -> BackendResult {
let entry = self.wrapped.sampler_buffers.entry(key);
let std::collections::hash_map::Entry::Vacant(entry) = entry else {
return Ok(());
};

let sampler_array_name = self.namer.call("sampler_array");

let bind_target = match self.options.sampler_index_arrays.get(&key) {
Some(&bind_target) => bind_target,
None if self.options.fake_missing_bindings => super::BindTarget {
space: u8::MAX,
register: key.group,
binding_array_size: None,
},
None => {
unreachable!("Sampler buffer not bound to a register");
}
};

writeln!(
self.out,
"StructuredBuffer<uint> {sampler_array_name} : register(t{}, space{});",
bind_target.register, bind_target.space
)?;

entry.insert(sampler_array_name);

self.write_sampler_arrays()?;

Ok(())
}

pub(super) fn write_texture_coordinates(
&mut self,
kind: &str,
Expand Down
2 changes: 2 additions & 0 deletions naga/src/back/hlsl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,8 @@ pub const RESERVED: &[&str] = &[
super::writer::FREXP_FUNCTION,
super::writer::EXTRACT_BITS_FUNCTION,
super::writer::INSERT_BITS_FUNCTION,
super::writer::SAMPLER_BUFFER_VAR,
super::writer::COMPARISON_SAMPLER_BUFFER_VAR,
];

// DXC scalar types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/ASTContextHLSL.cpp#L48-L254
Expand Down
43 changes: 42 additions & 1 deletion naga/src/back/hlsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ use thiserror::Error;

use crate::{back, proc};

#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct BindTarget {
Expand Down Expand Up @@ -178,6 +178,39 @@ impl crate::ImageDimension {
}
}

#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct SamplerBufferKey {
group: u32,
}

#[derive(Clone, Debug, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(feature = "deserialize", serde(default))]
pub struct SamplerBufferBindTargets {
sampler_array: BindTarget,
comparison_sampler_array: BindTarget,
}

impl Default for SamplerBufferBindTargets {
fn default() -> Self {
Self {
sampler_array: BindTarget {
space: 0,
register: 0,
binding_array_size: None,
},
comparison_sampler_array: BindTarget {
space: 1,
register: 0,
binding_array_size: None,
},
}
}
}

/// Shorthand result used internally by the backend
type BackendResult = Result<(), Error>;

Expand Down Expand Up @@ -206,6 +239,10 @@ pub struct Options {
pub special_constants_binding: Option<BindTarget>,
/// Bind target of the push constant buffer
pub push_constants_target: Option<BindTarget>,
/// Bind target of the sampler array.
pub sampler_array: SamplerBufferBindTargets,
/// Group index -> bind target for each sampler buffer's bind_location
pub sampler_index_arrays: std::collections::BTreeMap<SamplerBufferKey, BindTarget>,
/// Should workgroup variables be zero initialized (by polyfilling)?
pub zero_initialize_workgroup_memory: bool,
/// Should we restrict indexing of vectors, matrices and arrays?
Expand All @@ -219,6 +256,8 @@ impl Default for Options {
binding_map: BindingMap::default(),
fake_missing_bindings: true,
special_constants_binding: None,
sampler_array: SamplerBufferBindTargets::default(),
sampler_index_arrays: std::collections::BTreeMap::default(),
push_constants_target: None,
zero_initialize_workgroup_memory: true,
restrict_indexing: true,
Expand Down Expand Up @@ -278,6 +317,8 @@ struct Wrapped {
struct_matrix_access: crate::FastHashSet<help::WrappedStructMatrixAccess>,
mat_cx2s: crate::FastHashSet<help::WrappedMatCx2>,
math: crate::FastHashSet<help::WrappedMath>,
sampler_array: bool,
sampler_buffers: crate::FastHashMap<SamplerBufferKey, String>,
}

impl Wrapped {
Expand Down
66 changes: 55 additions & 11 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits";
pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits";
pub(crate) const SAMPLER_BUFFER_VAR: &str = "nagaSamplerArray";
pub(crate) const COMPARISON_SAMPLER_BUFFER_VAR: &str = "nagaComparisonSamplerArray";

struct EpStructMember {
name: String,
Expand Down Expand Up @@ -810,6 +812,18 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}

let handle_ty = match *inner {
TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
_ => inner,
};

let is_sampler = matches!(*handle_ty, crate::TypeInner::Sampler { .. });

if is_sampler {
self.write_sampler(handle, global)?;
return Ok(());
}

// https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-variable-register
let register_ty = match global.space {
crate::AddressSpace::Function => unreachable!("Function address space"),
Expand Down Expand Up @@ -839,13 +853,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
register
}
crate::AddressSpace::Handle => {
let handle_ty = match *inner {
TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
_ => inner,
};

let register = match *handle_ty {
TypeInner::Sampler { .. } => "s",
// all storage textures are UAV, unconditionally
TypeInner::Image {
class: crate::ImageClass::Storage { .. },
Expand Down Expand Up @@ -952,6 +960,27 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Ok(())
}

fn write_sampler(
&mut self,
handle: Handle<crate::GlobalVariable>,
global: &crate::GlobalVariable,
) -> BackendResult {
let name = &self.names[&NameKey::GlobalVariable(handle)];

let binding = global.binding.as_ref().unwrap();

// this was already resolved earlier when we started evaluating an entry point.
let bt = self.options.resolve_resource_binding(binding).unwrap();

writeln!(self.out, "static const uint {name} = {};", bt.register)?;

self.write_wrapped_sampler_buffer(super::SamplerBufferKey {
group: binding.group,
})?;

Ok(())
}

/// Helper method used to write global constants
///
/// # Notes
Expand Down Expand Up @@ -2886,13 +2915,28 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, ".x")?;
}
}
Expression::GlobalVariable(handle) => match module.global_variables[handle].space {
crate::AddressSpace::Storage { .. } => {}
_ => {
Expression::GlobalVariable(handle) => {
let global_variable = &module.global_variables[handle];
let ty = &module.types[global_variable.ty].inner;

if let crate::TypeInner::Sampler { comparison } = *ty {
let variable = if comparison {
COMPARISON_SAMPLER_BUFFER_VAR
} else {
SAMPLER_BUFFER_VAR
};
let name = &self.names[&NameKey::GlobalVariable(handle)];
write!(self.out, "{name}")?;
write!(self.out, "{variable}[{name}]")?;
} else {
match global_variable.space {
crate::AddressSpace::Storage { .. } => {}
_ => {
let name = &self.names[&NameKey::GlobalVariable(handle)];
write!(self.out, "{name}")?;
}
}
}
},
}
Expression::LocalVariable(handle) => {
write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
}
Expand Down
3 changes: 3 additions & 0 deletions naga/tests/in/skybox.param.ron
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@
(group: 0, binding: 2): (space: 1, register: 0),
},
fake_missing_bindings: false,
sampler_index_arrays: {
(group: 0): (space: 2, register: 0),
},
special_constants_binding: Some((space: 0, register: 1)),
zero_initialize_workgroup_memory: true,
restrict_indexing: true
Expand Down
7 changes: 5 additions & 2 deletions naga/tests/out/hlsl/binding-arrays.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@ Texture2DArray<float4> texture_array_2darray[5] : register(t0, space2);
Texture2DMS<float4> texture_array_multisampled[5] : register(t0, space3);
Texture2D<float> texture_array_depth[5] : register(t0, space4);
RWTexture2D<float4> texture_array_storage[5] : register(u0, space5);
SamplerState samp[5] : register(s0, space6);
SamplerComparisonState samp_comp[5] : register(s0, space7);
static const uint samp = 0;
StructuredBuffer<uint> sampler_array : register(t0, space255);
SamplerState nagaSamplerArray[2048]: register(s0, space0);
SamplerComparisonState nagaComparisonSamplerArray[2048]: register(s0, space1);
static const uint samp_comp = 0;
cbuffer uni : register(b0, space8) { UniformIndex uni; }

struct FragmentInput_main {
Expand Down
Loading

0 comments on commit 713fc6d

Please sign in to comment.