Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wgsl-in] Handle all(bool) and any(bool) #2445

Merged
merged 1 commit into from
Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1696,7 +1696,26 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let argument = self.expression(args.next()?, ctx.reborrow())?;
args.finish()?;

crate::Expression::Relational { fun, argument }
// Check for no-op all(bool) and any(bool):
let argument_unmodified = matches!(
fun,
crate::RelationalFunction::All | crate::RelationalFunction::Any
) && {
ctx.grow_types(argument)?;
matches!(
ctx.resolved_inner(argument),
&crate::TypeInner::Scalar {
kind: crate::ScalarKind::Bool,
..
}
)
};

if argument_unmodified {
return Ok(Some(argument));
} else {
crate::Expression::Relational { fun, argument }
}
} else if let Some((axis, ctrl)) = conv::map_derivative(function.name) {
let mut args = ctx.prepare_args(arguments, 1, span);
let expr = self.expression(args.next()?, ctx.reborrow())?;
Expand Down
8 changes: 8 additions & 0 deletions tests/in/standard.wgsl
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
// Standard functions.

fn test_any_and_all_for_bool() -> bool {
let a = any(true);
return all(a);
}


@fragment
fn derivatives(@builtin(position) foo: vec4<f32>) -> @location(0) vec4<f32> {
var x = dpdxCoarse(foo);
Expand All @@ -14,5 +20,7 @@ fn derivatives(@builtin(position) foo: vec4<f32>) -> @location(0) vec4<f32> {
y = dpdy(foo);
z = fwidth(foo);

let a = test_any_and_all_for_bool();

return (x + y) * z;
}
13 changes: 9 additions & 4 deletions tests/out/glsl/standard.derivatives.Fragment.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ precision highp int;

layout(location = 0) out vec4 _fs2p_location0;

bool test_any_and_all_for_bool() {
return true;
}

void main() {
vec4 foo = gl_FragCoord;
vec4 x = vec4(0.0);
Expand All @@ -28,10 +32,11 @@ void main() {
y = _e11;
vec4 _e12 = fwidth(foo);
z = _e12;
vec4 _e13 = x;
vec4 _e14 = y;
vec4 _e16 = z;
_fs2p_location0 = ((_e13 + _e14) * _e16);
bool _e13 = test_any_and_all_for_bool();
vec4 _e14 = x;
vec4 _e15 = y;
vec4 _e17 = z;
_fs2p_location0 = ((_e14 + _e15) * _e17);
return;
}

14 changes: 10 additions & 4 deletions tests/out/hlsl/standard.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ struct FragmentInput_derivatives {
float4 foo_1 : SV_Position;
};

bool test_any_and_all_for_bool()
{
return true;
}

float4 derivatives(FragmentInput_derivatives fragmentinput_derivatives) : SV_Target0
{
float4 foo = fragmentinput_derivatives.foo_1;
Expand All @@ -27,8 +32,9 @@ float4 derivatives(FragmentInput_derivatives fragmentinput_derivatives) : SV_Tar
y = _expr11;
float4 _expr12 = fwidth(foo);
z = _expr12;
float4 _expr13 = x;
float4 _expr14 = y;
float4 _expr16 = z;
return ((_expr13 + _expr14) * _expr16);
const bool _e13 = test_any_and_all_for_bool();
float4 _expr14 = x;
float4 _expr15 = y;
float4 _expr17 = z;
return ((_expr14 + _expr15) * _expr17);
}
14 changes: 10 additions & 4 deletions tests/out/msl/standard.msl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
using metal::uint;


bool test_any_and_all_for_bool(
) {
return true;
}

struct derivativesInput {
};
struct derivativesOutput {
Expand Down Expand Up @@ -34,8 +39,9 @@ fragment derivativesOutput derivatives(
y = _e11;
metal::float4 _e12 = metal::fwidth(foo);
z = _e12;
metal::float4 _e13 = x;
metal::float4 _e14 = y;
metal::float4 _e16 = z;
return derivativesOutput { (_e13 + _e14) * _e16 };
bool _e13 = test_any_and_all_for_bool();
metal::float4 _e14 = x;
metal::float4 _e15 = y;
metal::float4 _e17 = z;
return derivativesOutput { (_e14 + _e15) * _e17 };
}
98 changes: 54 additions & 44 deletions tests/out/spv/standard.spvasm
Original file line number Diff line number Diff line change
@@ -1,56 +1,66 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 33
; Bound: 40
OpCapability Shader
OpCapability DerivativeControl
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %16 "derivatives" %11 %14
OpExecutionMode %16 OriginUpperLeft
OpDecorate %11 BuiltIn FragCoord
OpDecorate %14 Location 0
OpEntryPoint Fragment %22 "derivatives" %17 %20
OpExecutionMode %22 OriginUpperLeft
OpDecorate %17 BuiltIn FragCoord
OpDecorate %20 Location 0
%2 = OpTypeVoid
%4 = OpTypeFloat 32
%3 = OpTypeVector %4 4
%6 = OpTypePointer Function %3
%7 = OpConstantNull %3
%12 = OpTypePointer Input %3
%11 = OpVariable %12 Input
%15 = OpTypePointer Output %3
%14 = OpVariable %15 Output
%17 = OpTypeFunction %2
%16 = OpFunction %2 None %17
%3 = OpTypeBool
%5 = OpTypeFloat 32
%4 = OpTypeVector %5 4
%8 = OpTypeFunction %3
%9 = OpConstantTrue %3
%12 = OpTypePointer Function %4
%13 = OpConstantNull %4
%18 = OpTypePointer Input %4
%17 = OpVariable %18 Input
%21 = OpTypePointer Output %4
%20 = OpVariable %21 Output
%23 = OpTypeFunction %2
%7 = OpFunction %3 None %8
%6 = OpLabel
OpBranch %10
%10 = OpLabel
%5 = OpVariable %6 Function %7
%8 = OpVariable %6 Function %7
%9 = OpVariable %6 Function %7
%13 = OpLoad %3 %11
OpBranch %18
%18 = OpLabel
%19 = OpDPdxCoarse %3 %13
OpStore %5 %19
%20 = OpDPdyCoarse %3 %13
OpStore %8 %20
%21 = OpFwidthCoarse %3 %13
OpStore %9 %21
%22 = OpDPdxFine %3 %13
OpStore %5 %22
%23 = OpDPdyFine %3 %13
OpStore %8 %23
%24 = OpFwidthFine %3 %13
OpStore %9 %24
%25 = OpDPdx %3 %13
OpStore %5 %25
%26 = OpDPdy %3 %13
OpStore %8 %26
%27 = OpFwidth %3 %13
OpStore %9 %27
%28 = OpLoad %3 %5
%29 = OpLoad %3 %8
%30 = OpFAdd %3 %28 %29
%31 = OpLoad %3 %9
%32 = OpFMul %3 %30 %31
OpReturnValue %9
OpFunctionEnd
%22 = OpFunction %2 None %23
%16 = OpLabel
%11 = OpVariable %12 Function %13
%14 = OpVariable %12 Function %13
%15 = OpVariable %12 Function %13
%19 = OpLoad %4 %17
OpBranch %24
%24 = OpLabel
%25 = OpDPdxCoarse %4 %19
OpStore %11 %25
%26 = OpDPdyCoarse %4 %19
OpStore %14 %26
%27 = OpFwidthCoarse %4 %19
OpStore %15 %27
%28 = OpDPdxFine %4 %19
OpStore %11 %28
%29 = OpDPdyFine %4 %19
OpStore %14 %29
%30 = OpFwidthFine %4 %19
OpStore %15 %30
%31 = OpDPdx %4 %19
OpStore %11 %31
%32 = OpDPdy %4 %19
OpStore %14 %32
%33 = OpFwidth %4 %19
OpStore %15 %33
%34 = OpFunctionCall %3 %7
%35 = OpLoad %4 %11
%36 = OpLoad %4 %14
%37 = OpFAdd %4 %35 %36
%38 = OpLoad %4 %15
%39 = OpFMul %4 %37 %38
OpStore %20 %39
OpReturn
OpFunctionEnd
13 changes: 9 additions & 4 deletions tests/out/wgsl/standard.wgsl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
fn test_any_and_all_for_bool() -> bool {
return true;
}

@fragment
fn derivatives(@builtin(position) foo: vec4<f32>) -> @location(0) vec4<f32> {
var x: vec4<f32>;
Expand All @@ -22,8 +26,9 @@ fn derivatives(@builtin(position) foo: vec4<f32>) -> @location(0) vec4<f32> {
y = _e11;
let _e12 = fwidth(foo);
z = _e12;
let _e13 = x;
let _e14 = y;
let _e16 = z;
return ((_e13 + _e14) * _e16);
let _e13 = test_any_and_all_for_bool();
let _e14 = x;
let _e15 = y;
let _e17 = z;
return ((_e14 + _e15) * _e17);
}
Loading