Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Workgroup uniform load #2201

Merged
merged 15 commits into from
May 26, 2023
6 changes: 6 additions & 0 deletions src/back/dot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,11 @@ impl StatementGraph {
}
"Atomic"
}
S::WorkGroupUniformLoad { pointer, result } => {
self.emits.push((id, result));
self.dependencies.push((id, pointer, "pointer"));
"WorkGroupUniformLoad"
}
S::RayQuery { query, ref fun } => {
self.dependencies.push((id, query, "query"));
match *fun {
Expand Down Expand Up @@ -568,6 +573,7 @@ fn write_function_expressions(
}
E::CallResult(_function) => ("CallResult".into(), 4),
E::AtomicResult { .. } => ("AtomicResult".into(), 4),
E::WorkGroupUniformLoadResult { .. } => ("WorkGroupUniformLoadResult".into(), 4),
E::ArrayLength(expr) => {
edges.insert("", expr);
("ArrayLength".into(), 7)
Expand Down
27 changes: 22 additions & 5 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1835,7 +1835,7 @@ impl<'a, W: Write> Writer<'a, W> {

if let Some(name) = expr_name {
write!(self.out, "{level}")?;
self.write_named_expr(handle, name, ctx)?;
self.write_named_expr(handle, name, handle, ctx)?;
}
}
}
Expand Down Expand Up @@ -2125,6 +2125,19 @@ impl<'a, W: Write> Writer<'a, W> {
self.write_expr(value, ctx)?;
writeln!(self.out, ";")?
}
Statement::WorkGroupUniformLoad { pointer, result } => {
// GLSL doesn't have pointers, which means that this backend needs to ensure that
// the actual "loading" is happening between the two barriers.
// This is done in `Emit` by never emitting a variable name for pointer variables
self.write_barrier(crate::Barrier::WORK_GROUP, level)?;

let result_name = format!("{}{}", back::BAKE_PREFIX, result.index());
write!(self.out, "{level}")?;
// Expressions cannot have side effects, so just writing the expression here is fine.
self.write_named_expr(pointer, result_name, result, ctx)?;

self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
}
// Stores a value into an image.
Statement::ImageStore {
image,
Expand Down Expand Up @@ -3282,7 +3295,8 @@ impl<'a, W: Write> Writer<'a, W> {
// These expressions never show up in `Emit`.
Expression::CallResult(_)
| Expression::AtomicResult { .. }
| Expression::RayQueryProceedResult => unreachable!(),
| Expression::RayQueryProceedResult
| Expression::WorkGroupUniformLoadResult { .. } => unreachable!(),
// `ArrayLength` is written as `expr.length()` and we convert it to a uint
Expression::ArrayLength(expr) => {
write!(self.out, "uint(")?;
Expand Down Expand Up @@ -3718,9 +3732,12 @@ impl<'a, W: Write> Writer<'a, W> {
&mut self,
handle: Handle<crate::Expression>,
name: String,
// The expression which is being named.
// Generally, this is the same as handle, except in WorkGroupUniformLoad
named: Handle<crate::Expression>,
ctx: &back::FunctionCtx,
) -> BackendResult {
match ctx.info[handle].ty {
match ctx.info[named].ty {
proc::TypeResolution::Handle(ty_handle) => match self.module.types[ty_handle].inner {
TypeInner::Struct { .. } => {
let ty_name = &self.names[&NameKey::Type(ty_handle)];
Expand All @@ -3735,7 +3752,7 @@ impl<'a, W: Write> Writer<'a, W> {
}
}

let base_ty_res = &ctx.info[handle].ty;
let base_ty_res = &ctx.info[named].ty;
let resolved = base_ty_res.inner_with(&self.module.types);

write!(self.out, " {name}")?;
Expand All @@ -3745,7 +3762,7 @@ impl<'a, W: Write> Writer<'a, W> {
write!(self.out, " = ")?;
self.write_expr(handle, ctx)?;
writeln!(self.out, ";")?;
self.named_expressions.insert(handle, name);
self.named_expressions.insert(named, name);

Ok(())
}
Expand Down
20 changes: 16 additions & 4 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1342,7 +1342,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {

if let Some(name) = expr_name {
write!(self.out, "{level}")?;
self.write_named_expr(module, handle, name, func_ctx)?;
self.write_named_expr(module, handle, name, handle, func_ctx)?;
}
}
}
Expand Down Expand Up @@ -1903,6 +1903,14 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out, ", {res_name});")?;
self.named_expressions.insert(result, res_name);
}
Statement::WorkGroupUniformLoad { pointer, result } => {
self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
write!(self.out, "{level}")?;
let name = format!("_expr{}", result.index());
self.write_named_expr(module, pointer, name, result, func_ctx)?;

self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
}
Statement::Switch {
selector,
ref cases,
Expand Down Expand Up @@ -2927,6 +2935,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// Nothing to do here, since call expression already cached
Expression::CallResult(_)
| Expression::AtomicResult { .. }
| Expression::WorkGroupUniformLoadResult { .. }
| Expression::RayQueryProceedResult => {}
}

Expand Down Expand Up @@ -3017,9 +3026,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
module: &Module,
handle: Handle<crate::Expression>,
name: String,
// The expression which is being named.
// Generally, this is the same as handle, except in WorkGroupUniformLoad
named: Handle<crate::Expression>,
ctx: &back::FunctionCtx,
) -> BackendResult {
match ctx.info[handle].ty {
match ctx.info[named].ty {
proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
TypeInner::Struct { .. } => {
let ty_name = &self.names[&NameKey::Type(ty_handle)];
Expand All @@ -3034,7 +3046,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}

let base_ty_res = &ctx.info[handle].ty;
let base_ty_res = &ctx.info[named].ty;
let resolved = base_ty_res.inner_with(&module.types);

write!(self.out, " {name}")?;
Expand All @@ -3045,7 +3057,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, " = ")?;
self.write_expr(module, handle, ctx)?;
writeln!(self.out, ";")?;
self.named_expressions.insert(handle, name);
self.named_expressions.insert(named, name);

Ok(())
}
Expand Down
13 changes: 13 additions & 0 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1850,6 +1850,7 @@ impl<W: Write> Writer<W> {
// has to be a named expression
crate::Expression::CallResult(_)
| crate::Expression::AtomicResult { .. }
| crate::Expression::WorkGroupUniformLoadResult { .. }
| crate::Expression::RayQueryProceedResult => {
unreachable!()
}
Expand Down Expand Up @@ -2815,6 +2816,18 @@ impl<W: Write> Writer<W> {
// done
writeln!(self.out, ";")?;
}
crate::Statement::WorkGroupUniformLoad { pointer, result } => {
self.write_barrier(crate::Barrier::WORK_GROUP, level)?;

write!(self.out, "{level}")?;
let name = self.namer.call("");
self.start_baking_expression(result, &context.expression, &name)?;
self.put_load(pointer, &context.expression, true)?;
self.named_expressions.insert(result, name);

writeln!(self.out, ";")?;
self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
}
crate::Statement::RayQuery { query, ref fun } => {
match *fun {
crate::RayQueryFunction::Initialize {
Expand Down
41 changes: 41 additions & 0 deletions src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1088,6 +1088,7 @@ impl<'w> BlockContext<'w> {
crate::Expression::FunctionArgument(index) => self.function.parameter_id(index),
crate::Expression::CallResult(_)
| crate::Expression::AtomicResult { .. }
| crate::Expression::WorkGroupUniformLoadResult { .. }
| crate::Expression::RayQueryProceedResult => self.cached[expr_handle],
crate::Expression::As {
expr,
Expand Down Expand Up @@ -2220,6 +2221,46 @@ impl<'w> BlockContext<'w> {

block.body.push(instruction);
}
crate::Statement::WorkGroupUniformLoad { pointer, result } => {
self.writer
.write_barrier(crate::Barrier::WORK_GROUP, &mut block);
let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
// Embed the body of
match self.write_expression_pointer(pointer, &mut block, None)? {
ExpressionPointer::Ready { pointer_id } => {
let id = self.gen_id();
block.body.push(Instruction::load(
result_type_id,
id,
pointer_id,
None,
));
self.cached[result] = id;
}
ExpressionPointer::Conditional { condition, access } => {
self.cached[result] = self.write_conditional_indexed_load(
result_type_id,
condition,
&mut block,
move |id_gen, block| {
// The in-bounds path. Perform the access and the load.
let pointer_id = access.result_id.unwrap();
let value_id = id_gen.next();
block.body.push(access);
block.body.push(Instruction::load(
result_type_id,
value_id,
pointer_id,
None,
));
value_id
},
)
}
}
self.writer
.write_barrier(crate::Barrier::WORK_GROUP, &mut block);
}
crate::Statement::RayQuery { query, ref fun } => {
self.write_ray_query_function(query, fun, &mut block);
}
Expand Down
13 changes: 12 additions & 1 deletion src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,16 @@ impl<W: Write> Writer<W> {
self.write_expr(module, value, func_ctx)?;
writeln!(self.out, ");")?
}
Statement::WorkGroupUniformLoad { pointer, result } => {
write!(self.out, "{level}")?;
// TODO: Obey named expressions here.
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
self.start_named_expr(module, result, func_ctx, &res_name)?;
self.named_expressions.insert(result, res_name);
write!(self.out, "workgroupUniformLoad(")?;
self.write_expr(module, pointer, func_ctx)?;
writeln!(self.out, ");")?;
}
Statement::ImageStore {
image,
coordinate,
Expand Down Expand Up @@ -1618,7 +1628,8 @@ impl<W: Write> Writer<W> {
// Nothing to do here, since call expression already cached
Expression::CallResult(_)
| Expression::AtomicResult { .. }
| Expression::RayQueryProceedResult => {}
| Expression::RayQueryProceedResult
| Expression::WorkGroupUniformLoadResult { .. } => {}
}

Ok(())
Expand Down
1 change: 1 addition & 0 deletions src/front/glsl/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ impl<'a> ConstantSolver<'a> {
Expression::Derivative { .. } => Err(ConstantSolvingError::Derivative),
Expression::Relational { .. } => Err(ConstantSolvingError::Relational),
Expression::CallResult { .. } => Err(ConstantSolvingError::Call),
Expression::WorkGroupUniformLoadResult { .. } => unreachable!(),
Expression::AtomicResult { .. } => Err(ConstantSolvingError::Atomic),
Expression::FunctionArgument(_) => Err(ConstantSolvingError::FunctionArg),
Expression::GlobalVariable(_) => Err(ConstantSolvingError::GlobalVariable),
Expand Down
1 change: 1 addition & 0 deletions src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3703,6 +3703,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
}
}
}
S::WorkGroupUniformLoad { .. } => unreachable!(),
}
i += 1;
}
Expand Down
6 changes: 6 additions & 0 deletions src/front/wgsl/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ pub enum Error<'a> {
found: u32,
},
FunctionReturnsVoid(Span),
InvalidWorkGroupUniformLoad(Span),
Other,
}

Expand Down Expand Up @@ -680,6 +681,11 @@ impl<'a> Error<'a> {
"perhaps you meant to call the function in a separate statement?".into(),
],
},
Error::InvalidWorkGroupUniformLoad(span) => ParseError {
message: "incorrect type passed to workgroupUniformLoad".into(),
labels: vec![(span, "".into())],
notes: vec!["passed type must be a workgroup pointer".into()],
},
Error::Other => ParseError {
message: "other error".to_string(),
labels: vec![],
Expand Down
30 changes: 30 additions & 0 deletions src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1765,6 +1765,36 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span);
return Ok(None);
}
"workgroupUniformLoad" => {
let mut args = ctx.prepare_args(arguments, 1, span);
let expr = args.next()?;
args.finish()?;

let pointer = self.expression(expr, ctx.reborrow())?;
ctx.grow_types(pointer)?;
let result_ty = match *ctx.resolved_inner(pointer) {
crate::TypeInner::Pointer {
base,
space: crate::AddressSpace::WorkGroup,
} => base,
ref other => {
log::error!("Type {other:?} passed to workgroupUniformLoad");
let span = ctx.ast_expressions.get_span(expr);
return Err(Error::InvalidWorkGroupUniformLoad(span));
}
};
let result = ctx.interrupt_emitter(
crate::Expression::WorkGroupUniformLoadResult { ty: result_ty },
span,
);
ctx.block.push(
crate::Statement::WorkGroupUniformLoad { pointer, result },
span,
);

ctx.grow_types(pointer)?;
return Ok(Some(result));
}
"textureStore" => {
let mut args = ctx.prepare_args(arguments, 3, span);

Expand Down
22 changes: 22 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1444,6 +1444,13 @@ pub enum Expression {
CallResult(Handle<Function>),
/// Result of an atomic operation.
AtomicResult { ty: Handle<Type>, comparison: bool },
/// Result of a [`WorkGroupUniformLoad`] statement.
///
/// [`WorkGroupUniformLoad`]: Statement::WorkGroupUniformLoad
WorkGroupUniformLoadResult {
/// The type of the result
ty: Handle<Type>,
},
/// Get the length of an array.
/// The expression must resolve to a pointer to an array with a dynamic size.
///
Expand Down Expand Up @@ -1701,6 +1708,21 @@ pub enum Statement {
/// [`AtomicResult`]: crate::Expression::AtomicResult
result: Handle<Expression>,
},
/// Load uniformly from a uniform pointer in the workgroup address space.
///
/// Corresponds to the [`workgroupUniformLoad`](https://www.w3.org/TR/WGSL/#workgroupUniformLoad-builtin)
/// built-in function of wgsl, and has the same barrier semantics
WorkGroupUniformLoad {
/// This must be of type [`Pointer`] in the [`WorkGroup`] address space
///
/// [`Pointer`]: TypeInner::Pointer
/// [`WorkGroup`]: AddressSpace::WorkGroup
pointer: Handle<Expression>,
/// The [`WorkGroupUniformLoadResult`] expression representing this load's result.
///
/// [`WorkGroupUniformLoadResult`]: Expression::WorkGroupUniformLoadResult
result: Handle<Expression>,
},
/// Calls a function.
///
/// If the `result` is `Some`, the corresponding expression has to be
Expand Down
1 change: 1 addition & 0 deletions src/proc/terminator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub fn ensure_block_returns(block: &mut crate::Block) {
| S::Call { .. }
| S::RayQuery { .. }
| S::Atomic { .. }
| S::WorkGroupUniformLoad { .. }
| S::Barrier(_)),
)
| None => block.push(S::Return { value: None }, Default::default()),
Expand Down
1 change: 1 addition & 0 deletions src/proc/typifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,7 @@ impl<'a> ResolveContext<'a> {
| crate::BinaryOperator::ShiftRight => past(left)?.clone(),
},
crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty),
crate::Expression::WorkGroupUniformLoadResult { ty } => TypeResolution::Handle(ty),
crate::Expression::Select { accept, .. } => past(accept)?.clone(),
crate::Expression::Derivative { expr, .. } => past(expr)?.clone(),
crate::Expression::Relational { fun, argument } => match fun {
Expand Down
Loading