Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce Expression::ZeroValue. #2332

Merged
merged 2 commits into from
May 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/back/dot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,8 @@ fn write_function_expressions(
for (handle, expression) in fun.expressions.iter() {
use crate::Expression as E;
let (label, color_id) = match *expression {
E::Constant(_) => ("Constant".into(), 2),
E::ZeroValue(_) => ("ZeroValue".into(), 2),
E::Access { base, index } => {
edges.insert("base", base);
edges.insert("index", index);
Expand All @@ -406,7 +408,6 @@ fn write_function_expressions(
edges.insert("base", base);
(format!("AccessIndex[{index}]").into(), 1)
}
E::Constant(_) => ("Constant".into(), 2),
E::Splat { size, value } => {
edges.insert("value", value);
(format!("Splat{size:?}").into(), 3)
Expand Down
3 changes: 3 additions & 0 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2267,6 +2267,9 @@ impl<'a, W: Write> Writer<'a, W> {
}
// Constants are delegated to `write_constant`
Expression::Constant(constant) => self.write_constant(constant)?,
Expression::ZeroValue(ty) => {
self.write_zero_init_value(ty)?;
}
// `Splat` needs to actually write down a vector, it's not always inferred in GLSL.
Expression::Splat { size: _, value } => {
let resolved = ctx.info[expr].ty.inner_with(&self.module.types);
Expand Down
1 change: 1 addition & 0 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2057,6 +2057,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {

match *expression {
Expression::Constant(constant) => self.write_constant(module, constant)?,
Expression::ZeroValue(ty) => self.write_default_init(module, ty)?,
Expression::Compose { ty, ref components } => {
match module.types[ty].inner {
TypeInner::Struct { .. } | TypeInner::Array { .. } => {
Expand Down
11 changes: 11 additions & 0 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1333,6 +1333,17 @@ impl<W: Write> Writer<W> {
};
write!(self.out, "{coco}")?;
}
crate::Expression::ZeroValue(ty) => {
let ty_name = TypeContext {
handle: ty,
module: context.module,
names: &self.names,
access: crate::StorageAccess::empty(),
binding: None,
first_time: false,
};
write!(self.out, "{ty_name} {{}}")?;
}
crate::Expression::Splat { size, value } => {
let scalar_kind = match *context.resolve_type(value) {
crate::TypeInner::Scalar { kind, .. } => kind,
Expand Down
1 change: 1 addition & 0 deletions src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ impl<'w> BlockContext<'w> {
self.writer.global_variables[handle.index()].access_id
}
crate::Expression::Constant(handle) => self.writer.constant_ids[handle.index()],
crate::Expression::ZeroValue(_) => self.writer.write_constant_null(result_type_id),
crate::Expression::Splat { size, value } => {
let value_id = self.cached[value];
let components = [value_id; 4];
Expand Down
4 changes: 4 additions & 0 deletions src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1104,6 +1104,10 @@ impl<W: Write> Writer<W> {
// subscripting.
match *expression {
Expression::Constant(constant) => self.write_constant(module, constant)?,
Expression::ZeroValue(ty) => {
self.write_type(module, ty)?;
write!(self.out, "()")?;
}
Expression::Compose { ty, ref components } => {
self.write_type(module, ty)?;
write!(self.out, "(")?;
Expand Down
101 changes: 101 additions & 0 deletions src/front/glsl/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ impl<'a> ConstantSolver<'a> {
let span = self.expressions.get_span(expr);
match self.expressions[expr] {
Expression::Constant(constant) => Ok(constant),
Expression::ZeroValue(ty) => self.register_zero_constant(ty, span),
Expression::AccessIndex { base, index } => self.access(base, index as usize),
Expression::Access { base, index } => {
let index = self.solve(index)?;
Expand Down Expand Up @@ -547,6 +548,106 @@ impl<'a> ConstantSolver<'a> {
Ok(self.register_constant(inner, span))
}

fn register_zero_constant(
&mut self,
ty: Handle<Type>,
span: crate::Span,
) -> Result<Handle<Constant>, ConstantSolvingError> {
let inner = match self.types[ty].inner {
TypeInner::Scalar { kind, width } => {
let value = match kind {
ScalarKind::Sint => ScalarValue::Sint(0),
ScalarKind::Uint => ScalarValue::Uint(0),
ScalarKind::Float => ScalarValue::Float(1.0),
ScalarKind::Bool => ScalarValue::Bool(false),
};
ConstantInner::Scalar { width, value }
}
TypeInner::Vector { size, kind, width } => {
let element_type = self.types.insert(
Type {
name: None,
inner: TypeInner::Scalar { kind, width },
},
span,
);
let element = self.register_zero_constant(element_type, span)?;
let components = std::iter::repeat(element)
.take(size as u8 as usize)
.collect();
ConstantInner::Composite { ty, components }
}
TypeInner::Matrix {
columns,
rows,
width,
} => {
let column_type = self.types.insert(
Type {
name: None,
inner: TypeInner::Vector {
size: rows,
kind: ScalarKind::Float,
width,
},
},
span,
);
let column = self.register_zero_constant(column_type, span)?;
let components = std::iter::repeat(column)
.take(columns as u8 as usize)
.collect();
ConstantInner::Composite { ty, components }
}
TypeInner::Array { base, size, .. } => {
let length = match size {
crate::ArraySize::Constant(handle) => match self.constants[handle].inner {
ConstantInner::Scalar {
value: ScalarValue::Uint(length),
..
} => length as usize,
ConstantInner::Scalar {
value: ScalarValue::Sint(length),
..
} => {
if length < 0 {
return Err(ConstantSolvingError::InvalidArrayLengthArg);
}
length as usize
}
_ => return Err(ConstantSolvingError::InvalidArrayLengthArg),
},
crate::ArraySize::Dynamic => {
return Err(ConstantSolvingError::ArrayLengthDynamic)
}
};
let element = self.register_zero_constant(base, span)?;
let components = std::iter::repeat(element)
.take(length as u8 as usize)
.collect();
ConstantInner::Composite { ty, components }
}
TypeInner::Struct { ref members, .. } => {
// Make a copy of the member types, for the borrow checker.
let types: Vec<Handle<Type>> = members.iter().map(|member| member.ty).collect();
let mut components = vec![];
for member_ty in types {
let value = self.register_zero_constant(member_ty, span)?;
components.push(value);
}
ConstantInner::Composite { ty, components }
}
ref inner => {
return Err(ConstantSolvingError::NotImplemented(format!(
"zero-value construction for types: {:?}",
inner
)));
}
};

Ok(self.register_constant(inner, span))
}

fn register_constant(&mut self, inner: ConstantInner, span: crate::Span) -> Handle<Constant> {
self.constants.fetch_or_append(
Constant {
Expand Down
9 changes: 6 additions & 3 deletions src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1324,14 +1324,17 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
Op::NoLine => inst.expect(1)?,
Op::Undef => {
inst.expect(3)?;
let (type_id, id, handle) =
self.parse_null_constant(inst, ctx.type_arena, ctx.const_arena)?;
let type_id = self.next()?;
let id = self.next()?;
let type_lookup = self.lookup_type.lookup(type_id)?;
let ty = type_lookup.handle;

self.lookup_expression.insert(
id,
LookupExpression {
handle: ctx
.expressions
.append(crate::Expression::Constant(handle), span),
.append(crate::Expression::ZeroValue(ty), span),
type_id,
block_id,
},
Expand Down
7 changes: 1 addition & 6 deletions src/front/wgsl/lower/construction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
_ => return Err(Error::TypeNotInferrable(ty_span)),
};

return match ctx.create_zero_value_constant(ty) {
Some(constant) => {
Ok(ctx.interrupt_emitter(crate::Expression::Constant(constant), span))
}
None => Err(Error::TypeNotConstructible(ty_span)),
};
return Ok(ctx.interrupt_emitter(crate::Expression::ZeroValue(ty), span));
}

// Scalar constructor & conversion (scalar -> scalar)
Expand Down
77 changes: 0 additions & 77 deletions src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,83 +342,6 @@ impl<'a> ExpressionContext<'a, '_, '_> {
}
}

/// Creates a zero value constant of type `ty`
///
/// Returns `None` if the given `ty` is not a constructible type
fn create_zero_value_constant(
&mut self,
ty: Handle<crate::Type>,
) -> Option<Handle<crate::Constant>> {
let inner = match self.module.types[ty].inner {
crate::TypeInner::Scalar { kind, width } => {
let value = match kind {
crate::ScalarKind::Sint => crate::ScalarValue::Sint(0),
crate::ScalarKind::Uint => crate::ScalarValue::Uint(0),
crate::ScalarKind::Float => crate::ScalarValue::Float(0.),
crate::ScalarKind::Bool => crate::ScalarValue::Bool(false),
};
crate::ConstantInner::Scalar { width, value }
}
crate::TypeInner::Vector { size, kind, width } => {
let scalar_ty = self.ensure_type_exists(crate::TypeInner::Scalar { width, kind });
let component = self.create_zero_value_constant(scalar_ty)?;
crate::ConstantInner::Composite {
ty,
components: (0..size as u8).map(|_| component).collect(),
}
}
crate::TypeInner::Matrix {
columns,
rows,
width,
} => {
let vec_ty = self.ensure_type_exists(crate::TypeInner::Vector {
width,
kind: crate::ScalarKind::Float,
size: rows,
});
let component = self.create_zero_value_constant(vec_ty)?;
crate::ConstantInner::Composite {
ty,
components: (0..columns as u8).map(|_| component).collect(),
}
}
crate::TypeInner::Array {
base,
size: crate::ArraySize::Constant(size),
..
} => {
let size = self.module.constants[size].to_array_length()?;
let component = self.create_zero_value_constant(base)?;
crate::ConstantInner::Composite {
ty,
components: (0..size).map(|_| component).collect(),
}
}
crate::TypeInner::Struct { ref members, .. } => {
let members = members.clone();
crate::ConstantInner::Composite {
ty,
components: members
.iter()
.map(|member| self.create_zero_value_constant(member.ty))
.collect::<Option<_>>()?,
}
}
_ => return None,
};

let constant = self.module.constants.fetch_or_append(
crate::Constant {
name: None,
specialization: None,
inner,
},
Span::UNDEFINED,
);
Some(constant)
}

fn format_typeinner(&self, inner: &crate::TypeInner) -> String {
inner.to_wgsl(&self.module.types, &self.module.constants)
}
Expand Down
10 changes: 7 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ of `Statement`s and other `Expression`s.

Naga's rules for when `Expression`s are evaluated are as follows:

- [`Constant`](Expression::Constant) expressions are considered to be
- [`Constant`] and [`ZeroValue`] expressions are considered to be
implicitly evaluated before execution begins.

- [`FunctionArgument`] and [`LocalVariable`] expressions are considered
Expand Down Expand Up @@ -174,6 +174,7 @@ tree.
[`RayQueryProceedResult`]: Expression::RayQueryProceedResult
[`CallResult`]: Expression::CallResult
[`Constant`]: Expression::Constant
[`ZeroValue`]: Expression::ZeroValue
[`Derivative`]: Expression::Derivative
[`FunctionArgument`]: Expression::FunctionArgument
[`GlobalVariable`]: Expression::GlobalVariable
Expand Down Expand Up @@ -1189,6 +1190,11 @@ bitflags::bitflags! {
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum Expression {
/// Constant value.
Constant(Handle<Constant>),
/// Zero value of a type.
ZeroValue(Handle<Type>),

/// Array access with a computed index.
///
/// ## Typing rules
Expand Down Expand Up @@ -1247,8 +1253,6 @@ pub enum Expression {
base: Handle<Expression>,
index: u32,
},
/// Constant value.
Constant(Handle<Constant>),
/// Splat scalar into a vector.
Splat {
size: VectorSize,
Expand Down
1 change: 1 addition & 0 deletions src/proc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ impl crate::Expression {
pub const fn needs_pre_emit(&self) -> bool {
match *self {
Self::Constant(_)
| Self::ZeroValue(_)
| Self::FunctionArgument(_)
| Self::GlobalVariable(_)
| Self::LocalVariable(_) => true,
Expand Down
1 change: 1 addition & 0 deletions src/proc/typifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ impl<'a> ResolveContext<'a> {
}
crate::ConstantInner::Composite { ty, components: _ } => TypeResolution::Handle(ty),
},
crate::Expression::ZeroValue(ty) => TypeResolution::Handle(ty),
crate::Expression::Splat { size, value } => match *past(value)?.inner_with(types) {
Ti::Scalar { kind, width } => {
TypeResolution::Value(Ti::Vector { size, kind, width })
Expand Down
1 change: 1 addition & 0 deletions src/valid/analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ impl FunctionInfo {
},
// always uniform
E::Constant(_) => Uniformity::new(),
E::ZeroValue(_) => Uniformity::new(),
E::Splat { size: _, value } => Uniformity {
non_uniform_result: self.add_ref(value),
requirements: UniformityRequirements::empty(),
Expand Down
1 change: 1 addition & 0 deletions src/valid/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ impl super::Validator {
ShaderStages::all()
}
E::Constant(_handle) => ShaderStages::all(),
E::ZeroValue(_type) => ShaderStages::all(),
E::Splat { size: _, value } => match resolver[value] {
Ti::Scalar { .. } => ShaderStages::all(),
ref other => {
Expand Down
3 changes: 3 additions & 0 deletions src/valid/handles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,9 @@ impl super::Validator {
crate::Expression::Constant(constant) => {
validate_constant(constant)?;
}
crate::Expression::ZeroValue(ty) => {
validate_type(ty)?;
}
crate::Expression::Splat { value, .. } => {
handle.check_dep(value)?;
}
Expand Down
2 changes: 1 addition & 1 deletion tests/out/glsl/access.foo_frag.Fragment.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ void main() {
_group_0_binding_0_fs._matrix = mat4x3(vec3(0.0), vec3(1.0), vec3(2.0), vec3(3.0));
_group_0_binding_0_fs.arr = uvec2[2](uvec2(0u), uvec2(1u));
_group_0_binding_0_fs.data[1].value = 1;
_group_0_binding_2_fs = ivec2(0, 0);
_group_0_binding_2_fs = ivec2(0);
_fs2p_location0 = vec4(0.0);
return;
}
Expand Down
Loading