Skip to content

Commit

Permalink
[msl-out] Replace per_stage_map with per_entry_point_map (#2237)
Browse files Browse the repository at this point in the history
The existing `per_stage_map` field of MSL backend options specifies
resource binding maps that apply to all entry points of each stage type.
It is useful to have the ability to provide a separate binding index map
for each entry point, especially when the same shader module defines
multiple entry points of the same stage kind.

This patch replaces `per_stage_map` with a new `per_entry_point_map`
option where resources are keyed by the entry-point function name.
  • Loading branch information
armansito authored Feb 22, 2023
1 parent 9742f16 commit 00be08e
Show file tree
Hide file tree
Showing 16 changed files with 227 additions and 82 deletions.
70 changes: 31 additions & 39 deletions src/back/msl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@ holding the result.
*/

use crate::{arena::Handle, proc::index, valid::ModuleInfo};
use std::{
fmt::{Error as FmtError, Write},
ops,
};
use std::fmt::{Error as FmtError, Write};

mod keywords;
pub mod sampler;
Expand Down Expand Up @@ -69,7 +66,7 @@ pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindTar
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
pub struct PerStageResources {
pub struct EntryPointResources {
pub resources: BindingMap,

pub push_constant_buffer: Option<Slot>,
Expand All @@ -80,26 +77,7 @@ pub struct PerStageResources {
pub sizes_buffer: Option<Slot>,
}

#[derive(Clone, Debug, Default, Hash, Eq, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
pub struct PerStageMap {
pub vs: PerStageResources,
pub fs: PerStageResources,
pub cs: PerStageResources,
}

impl ops::Index<crate::ShaderStage> for PerStageMap {
type Output = PerStageResources;
fn index(&self, stage: crate::ShaderStage) -> &PerStageResources {
match stage {
crate::ShaderStage::Vertex => &self.vs,
crate::ShaderStage::Fragment => &self.fs,
crate::ShaderStage::Compute => &self.cs,
}
}
}
pub type EntryPointResourceMap = std::collections::BTreeMap<String, EntryPointResources>;

enum ResolvedBinding {
BuiltIn(crate::BuiltIn),
Expand Down Expand Up @@ -198,8 +176,8 @@ enum LocationMode {
pub struct Options {
/// (Major, Minor) target version of the Metal Shading Language.
pub lang_version: (u8, u8),
/// Map of per-stage resources to slots.
pub per_stage_map: PerStageMap,
/// Map of entry-point resources, indexed by entry point function name, to slots.
pub per_entry_point_map: EntryPointResourceMap,
/// Samplers to be inlined into the code.
pub inline_samplers: Vec<sampler::InlineSampler>,
/// Make it possible to link different stages via SPIRV-Cross.
Expand All @@ -217,7 +195,7 @@ impl Default for Options {
fn default() -> Self {
Options {
lang_version: (2, 0),
per_stage_map: PerStageMap::default(),
per_entry_point_map: EntryPointResourceMap::default(),
inline_samplers: Vec::new(),
spirv_cross_compatibility: false,
fake_missing_bindings: true,
Expand Down Expand Up @@ -296,12 +274,26 @@ impl Options {
}
}

fn get_entry_point_resources(&self, ep: &crate::EntryPoint) -> Option<&EntryPointResources> {
self.per_entry_point_map.get(&ep.name)
}

fn get_resource_binding_target(
&self,
ep: &crate::EntryPoint,
res_binding: &crate::ResourceBinding,
) -> Option<&BindTarget> {
self.get_entry_point_resources(ep)
.and_then(|res| res.resources.get(res_binding))
}

fn resolve_resource_binding(
&self,
stage: crate::ShaderStage,
ep: &crate::EntryPoint,
res_binding: &crate::ResourceBinding,
) -> Result<ResolvedBinding, EntryPointError> {
match self.per_stage_map[stage].resources.get(res_binding) {
let target = self.get_resource_binding_target(ep, res_binding);
match target {
Some(target) => Ok(ResolvedBinding::Resource(target.clone())),
None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
prefix: "fake",
Expand All @@ -312,15 +304,13 @@ impl Options {
}
}

const fn resolve_push_constants(
fn resolve_push_constants(
&self,
stage: crate::ShaderStage,
ep: &crate::EntryPoint,
) -> Result<ResolvedBinding, EntryPointError> {
let slot = match stage {
crate::ShaderStage::Vertex => self.per_stage_map.vs.push_constant_buffer,
crate::ShaderStage::Fragment => self.per_stage_map.fs.push_constant_buffer,
crate::ShaderStage::Compute => self.per_stage_map.cs.push_constant_buffer,
};
let slot = self
.get_entry_point_resources(ep)
.and_then(|res| res.push_constant_buffer);
match slot {
Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
buffer: Some(slot),
Expand All @@ -340,9 +330,11 @@ impl Options {

fn resolve_sizes_buffer(
&self,
stage: crate::ShaderStage,
ep: &crate::EntryPoint,
) -> Result<ResolvedBinding, EntryPointError> {
let slot = self.per_stage_map[stage].sizes_buffer;
let slot = self
.get_entry_point_resources(ep)
.and_then(|res| res.sizes_buffer);
match slot {
Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
buffer: Some(slot),
Expand Down
17 changes: 8 additions & 9 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3406,7 +3406,8 @@ impl<W: Write> Writer<W> {
break;
}
};
let good = match options.per_stage_map[ep.stage].resources.get(br) {
let target = options.get_resource_binding_target(ep, br);
let good = match target {
Some(target) => {
let binding_ty = match module.types[var.ty].inner {
crate::TypeInner::BindingArray { base, .. } => {
Expand All @@ -3431,7 +3432,7 @@ impl<W: Write> Writer<W> {
}
}
crate::AddressSpace::PushConstant => {
if let Err(e) = options.resolve_push_constants(ep.stage) {
if let Err(e) = options.resolve_push_constants(ep) {
ep_error = Some(e);
break;
}
Expand All @@ -3442,7 +3443,7 @@ impl<W: Write> Writer<W> {
}
}
if supports_array_length {
if let Err(err) = options.resolve_sizes_buffer(ep.stage) {
if let Err(err) = options.resolve_sizes_buffer(ep) {
ep_error = Some(err);
}
}
Expand Down Expand Up @@ -3711,15 +3712,13 @@ impl<W: Write> Writer<W> {
}
// the resolves have already been checked for `!fake_missing_bindings` case
let resolved = match var.space {
crate::AddressSpace::PushConstant => {
options.resolve_push_constants(ep.stage).ok()
}
crate::AddressSpace::PushConstant => options.resolve_push_constants(ep).ok(),
crate::AddressSpace::WorkGroup => None,
crate::AddressSpace::Storage { .. } if options.lang_version < (2, 0) => {
return Err(Error::UnsupportedAddressSpace(var.space))
}
_ => options
.resolve_resource_binding(ep.stage, var.binding.as_ref().unwrap())
.resolve_resource_binding(ep, var.binding.as_ref().unwrap())
.ok(),
};
if let Some(ref resolved) = resolved {
Expand Down Expand Up @@ -3764,7 +3763,7 @@ impl<W: Write> Writer<W> {
// passed as a final struct-typed argument.
if supports_array_length {
// this is checked earlier
let resolved = options.resolve_sizes_buffer(ep.stage).unwrap();
let resolved = options.resolve_sizes_buffer(ep).unwrap();
let separator = if module.global_variables.is_empty() {
' '
} else {
Expand Down Expand Up @@ -3824,7 +3823,7 @@ impl<W: Write> Writer<W> {
};
} else if let Some(ref binding) = var.binding {
// write an inline sampler
let resolved = options.resolve_resource_binding(ep.stage, binding).unwrap();
let resolved = options.resolve_resource_binding(ep, binding).unwrap();
if let Some(sampler) = resolved.as_inline_sampler(options) {
let name = &self.names[&NameKey::GlobalVariable(handle)];
writeln!(
Expand Down
10 changes: 5 additions & 5 deletions tests/in/access.param.ron
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
),
msl: (
lang_version: (2, 0),
per_stage_map: (
vs: (
per_entry_point_map: {
"foo_vert": (
resources: {
(group: 0, binding: 0): (buffer: Some(0), mutable: false),
(group: 0, binding: 1): (buffer: Some(1), mutable: false),
Expand All @@ -16,20 +16,20 @@
},
sizes_buffer: Some(24),
),
fs: (
"foo_frag": (
resources: {
(group: 0, binding: 0): (buffer: Some(0), mutable: true),
(group: 0, binding: 2): (buffer: Some(2), mutable: true),
},
sizes_buffer: Some(24),
),
cs: (
"atomics": (
resources: {
(group: 0, binding: 0): (buffer: Some(0), mutable: true),
},
sizes_buffer: Some(24),
),
),
},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: false,
Expand Down
6 changes: 3 additions & 3 deletions tests/in/binding-arrays.param.ron
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
),
msl: (
lang_version: (2, 0),
per_stage_map: (
fs: (
per_entry_point_map: {
"main": (
resources: {
(group: 0, binding: 0): (texture: Some(0), binding_array_size: Some(10), mutable: false),
},
sizes_buffer: None,
)
),
},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: true,
Expand Down
6 changes: 3 additions & 3 deletions tests/in/bitcast.params.ron
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
(
msl: (
lang_version: (1, 2),
per_stage_map: (
cs: (
per_entry_point_map: {
"main": (
resources: {
},
sizes_buffer: Some(0),
)
),
},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: false,
Expand Down
6 changes: 3 additions & 3 deletions tests/in/bits.param.ron
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
(
msl: (
lang_version: (1, 2),
per_stage_map: (
cs: (
per_entry_point_map: {
"main": (
resources: {
},
sizes_buffer: Some(0),
)
),
},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: false,
Expand Down
6 changes: 3 additions & 3 deletions tests/in/boids.param.ron
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@
),
msl: (
lang_version: (2, 0),
per_stage_map: (
cs: (
per_entry_point_map: {
"main": (
resources: {
(group: 0, binding: 0): (buffer: Some(0), mutable: false),
(group: 0, binding: 1): (buffer: Some(1), mutable: true),
(group: 0, binding: 2): (buffer: Some(2), mutable: true),
},
sizes_buffer: Some(3),
)
),
},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: false,
Expand Down
6 changes: 3 additions & 3 deletions tests/in/extra.param.ron
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
),
msl: (
lang_version: (2, 2),
per_stage_map: (
fs: (
per_entry_point_map: {
"main": (
push_constant_buffer: Some(1),
),
),
},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: false,
Expand Down
2 changes: 1 addition & 1 deletion tests/in/interface.param.ron
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
),
msl: (
lang_version: (2, 1),
per_stage_map: (),
per_entry_point_map: {},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: false,
Expand Down
6 changes: 3 additions & 3 deletions tests/in/padding.param.ron
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
),
msl: (
lang_version: (2, 0),
per_stage_map: (
vs: (
per_entry_point_map: {
"vertex": (
resources: {
(group: 0, binding: 0): (buffer: Some(0), mutable: false),
(group: 0, binding: 1): (buffer: Some(1), mutable: false),
(group: 0, binding: 2): (buffer: Some(2), mutable: false),
},
),
),
},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: false,
Expand Down
Loading

0 comments on commit 00be08e

Please sign in to comment.