Skip to content

Commit

Permalink
Don't Forget a Layer of Indirection
Browse files Browse the repository at this point in the history
  • Loading branch information
cwfitzgerald committed Dec 18, 2024
1 parent 713fc6d commit 1564ef7
Show file tree
Hide file tree
Showing 14 changed files with 139 additions and 118 deletions.
17 changes: 9 additions & 8 deletions naga/src/back/hlsl/help.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1181,14 +1181,14 @@ impl<W: Write> super::Writer<'_, W> {
writeln!(
self.out,
"SamplerState {}[2048]: register(s{}, space{});",
super::writer::SAMPLER_BUFFER_VAR,
super::writer::SAMPLER_HEAP_VAR,
self.options.sampler_array.sampler_array.register,
self.options.sampler_array.sampler_array.space
)?;
writeln!(
self.out,
"SamplerComparisonState {}[2048]: register(s{}, space{});",
super::writer::COMPARISON_SAMPLER_BUFFER_VAR,
super::writer::COMPARISON_SAMPLER_HEAP_VAR,
self.options.sampler_array.comparison_sampler_array.register,
self.options.sampler_array.comparison_sampler_array.space
)?;
Expand All @@ -1202,12 +1202,15 @@ impl<W: Write> super::Writer<'_, W> {
&mut self,
key: super::SamplerBufferKey,
) -> BackendResult {
let entry = self.wrapped.sampler_buffers.entry(key);
let std::collections::hash_map::Entry::Vacant(entry) = entry else {
if self.wrapped.sampler_buffers.contains_key(&key) {
return Ok(());
};

let sampler_array_name = self.namer.call("sampler_array");
self.write_sampler_arrays()?;

let sampler_array_name = self
.namer
.call(&format!("nagaGroup{}SamplerIndexArray", key.group));

let bind_target = match self.options.sampler_index_arrays.get(&key) {
Some(&bind_target) => bind_target,
Expand All @@ -1227,9 +1230,7 @@ impl<W: Write> super::Writer<'_, W> {
bind_target.register, bind_target.space
)?;

entry.insert(sampler_array_name);

self.write_sampler_arrays()?;
self.wrapped.sampler_buffers.insert(key, sampler_array_name);

Ok(())
}
Expand Down
4 changes: 2 additions & 2 deletions naga/src/back/hlsl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -819,8 +819,8 @@ pub const RESERVED: &[&str] = &[
super::writer::FREXP_FUNCTION,
super::writer::EXTRACT_BITS_FUNCTION,
super::writer::INSERT_BITS_FUNCTION,
super::writer::SAMPLER_BUFFER_VAR,
super::writer::COMPARISON_SAMPLER_BUFFER_VAR,
super::writer::SAMPLER_HEAP_VAR,
super::writer::COMPARISON_SAMPLER_HEAP_VAR,
];

// DXC scalar types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/ASTContextHLSL.cpp#L48-L254
Expand Down
4 changes: 2 additions & 2 deletions naga/src/back/hlsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,13 @@ impl Options {
res_binding: &crate::ResourceBinding,
) -> Result<BindTarget, EntryPointError> {
match self.binding_map.get(res_binding) {
Some(target) => Ok(target.clone()),
Some(target) => Ok(*target),
None if self.fake_missing_bindings => Ok(BindTarget {
space: res_binding.group as u8,
register: res_binding.binding,
binding_array_size: None,
}),
None => Err(EntryPointError::MissingBinding(res_binding.clone())),
None => Err(EntryPointError::MissingBinding(*res_binding)),
}
}
}
Expand Down
88 changes: 55 additions & 33 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const EXTRACT_BITS_FUNCTION: &str = "naga_extractBits";
pub(crate) const INSERT_BITS_FUNCTION: &str = "naga_insertBits";
pub(crate) const SAMPLER_BUFFER_VAR: &str = "nagaSamplerArray";
pub(crate) const COMPARISON_SAMPLER_BUFFER_VAR: &str = "nagaComparisonSamplerArray";
pub(crate) const SAMPLER_HEAP_VAR: &str = "nagaSamplerHeap";
pub(crate) const COMPARISON_SAMPLER_HEAP_VAR: &str = "nagaComparisonSamplerHeap";

struct EpStructMember {
name: String,
Expand Down Expand Up @@ -141,11 +141,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
) {
use crate::Expression;
self.need_bake_expressions.clear();
for (fun_handle, expr) in func.expressions.iter() {
let expr_info = &info[fun_handle];
let min_ref_count = func.expressions[fun_handle].bake_ref_count();
for (exp_handle, expr) in func.expressions.iter() {
let expr_info = &info[exp_handle];
let min_ref_count = func.expressions[exp_handle].bake_ref_count();
if min_ref_count <= expr_info.ref_count {
self.need_bake_expressions.insert(fun_handle);
self.need_bake_expressions.insert(exp_handle);
}

if let Expression::Math { fun, arg, .. } = *expr {
Expand All @@ -170,7 +170,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.need_bake_expressions.insert(arg);
}
crate::MathFunction::CountLeadingZeros => {
let inner = info[fun_handle].ty.inner_with(&module.types);
let inner = info[exp_handle].ty.inner_with(&module.types);
if let Some(ScalarKind::Sint) = inner.scalar_kind() {
self.need_bake_expressions.insert(arg);
}
Expand All @@ -185,6 +185,14 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.need_bake_expressions.insert(expr);
}
}

if let Expression::GlobalVariable(_) = *expr {
let inner = info[exp_handle].ty.inner_with(&module.types);

if let TypeInner::Sampler { .. } = *inner {
self.need_bake_expressions.insert(exp_handle);
}
}
}
for statement in func.body.iter() {
match *statement {
Expand Down Expand Up @@ -817,10 +825,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
_ => inner,
};

let is_sampler = matches!(*handle_ty, crate::TypeInner::Sampler { .. });
let is_sampler = matches!(*handle_ty, TypeInner::Sampler { .. });

if is_sampler {
self.write_sampler(handle, global)?;
self.write_global_sampler(module, handle, global)?;
return Ok(());
}

Expand Down Expand Up @@ -960,23 +968,48 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Ok(())
}

fn write_sampler(
fn write_global_sampler(
&mut self,
module: &Module,
handle: Handle<crate::GlobalVariable>,
global: &crate::GlobalVariable,
) -> BackendResult {
let name = &self.names[&NameKey::GlobalVariable(handle)];
let binding = *global.binding.as_ref().unwrap();

let binding = global.binding.as_ref().unwrap();
let key = super::SamplerBufferKey {
group: binding.group,
};
self.write_wrapped_sampler_buffer(key)?;

// this was already resolved earlier when we started evaluating an entry point.
let bt = self.options.resolve_resource_binding(binding).unwrap();
let bt = self.options.resolve_resource_binding(&binding).unwrap();

writeln!(self.out, "static const uint {name} = {};", bt.register)?;
let comparison = match module.types[global.ty].inner {
TypeInner::Sampler { comparison } => comparison,
TypeInner::BindingArray { .. } => {
// We don't emit anything for binding arrays immediately,
// as we need to do the index lookup just-in-time.
return Ok(());
}
_ => unreachable!(),
};

self.write_wrapped_sampler_buffer(super::SamplerBufferKey {
group: binding.group,
})?;
write!(self.out, "static const ")?;
self.write_type(module, global.ty)?;

let heap_var = if comparison {
COMPARISON_SAMPLER_HEAP_VAR
} else {
SAMPLER_HEAP_VAR
};

let index_buffer_name = &self.wrapped.sampler_buffers[&key];
let name = &self.names[&NameKey::GlobalVariable(handle)];
writeln!(
self.out,
" {name} = {heap_var}[{index_buffer_name}[{register}]];",
register = bt.register
)?;

Ok(())
}
Expand Down Expand Up @@ -2917,23 +2950,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
Expression::GlobalVariable(handle) => {
let global_variable = &module.global_variables[handle];
let ty = &module.types[global_variable.ty].inner;

if let crate::TypeInner::Sampler { comparison } = *ty {
let variable = if comparison {
COMPARISON_SAMPLER_BUFFER_VAR
} else {
SAMPLER_BUFFER_VAR
};
let name = &self.names[&NameKey::GlobalVariable(handle)];
write!(self.out, "{variable}[{name}]")?;
} else {
match global_variable.space {
crate::AddressSpace::Storage { .. } => {}
_ => {
let name = &self.names[&NameKey::GlobalVariable(handle)];
write!(self.out, "{name}")?;
}
match global_variable.space {
crate::AddressSpace::Storage { .. } => {}
_ => {
let name = &self.names[&NameKey::GlobalVariable(handle)];
write!(self.out, "{name}")?;
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion naga/src/back/msl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ impl Options {
index: 0,
interpolation: None,
}),
None => Err(EntryPointError::MissingBindTarget(res_binding.clone())),
None => Err(EntryPointError::MissingBindTarget(*res_binding)),
}
}

Expand Down
2 changes: 1 addition & 1 deletion naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5149,7 +5149,7 @@ template <typename A>
};
if !good {
ep_error =
Some(super::EntryPointError::MissingBindTarget(br.clone()));
Some(super::EntryPointError::MissingBindTarget(*br));
break;
}
}
Expand Down
2 changes: 1 addition & 1 deletion naga/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -963,7 +963,7 @@ pub enum Binding {
}

/// Pipeline binding information for global resources.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[derive(Copy, Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
Expand Down
2 changes: 1 addition & 1 deletion naga/src/valid/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,7 @@ impl super::Validator {
}

if let Some(ref bind) = var.binding {
if !self.ep_resource_bindings.insert(bind.clone()) {
if !self.ep_resource_bindings.insert(*bind) {
if self.flags.contains(super::ValidationFlags::BINDINGS) {
return Err(EntryPointError::BindingCollision(var_handle)
.with_span_handle(var_handle, &module.global_variables));
Expand Down
8 changes: 3 additions & 5 deletions naga/tests/out/hlsl/binding-arrays.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@ Texture2DArray<float4> texture_array_2darray[5] : register(t0, space2);
Texture2DMS<float4> texture_array_multisampled[5] : register(t0, space3);
Texture2D<float> texture_array_depth[5] : register(t0, space4);
RWTexture2D<float4> texture_array_storage[5] : register(u0, space5);
static const uint samp = 0;
StructuredBuffer<uint> sampler_array : register(t0, space255);
SamplerState nagaSamplerArray[2048]: register(s0, space0);
SamplerComparisonState nagaComparisonSamplerArray[2048]: register(s0, space1);
static const uint samp_comp = 0;
SamplerState nagaSamplerHeap[2048]: register(s0, space0);
SamplerComparisonState nagaComparisonSamplerHeap[2048]: register(s0, space1);
StructuredBuffer<uint> nagaGroup0SamplerIndexArray : register(t0, space255);
cbuffer uni : register(b0, space8) { UniformIndex uni; }

struct FragmentInput_main {
Expand Down
Loading

0 comments on commit 1564ef7

Please sign in to comment.