Skip to content

Commit

Permalink
feat: Add 32-bit floating-point atomics
Browse files Browse the repository at this point in the history
* Current supported platforms: Metal
* Platforms to support in the future: Vulkan

Related issues or PRs:

* gfx-rs#1020
  • Loading branch information
AsherJingkongChen committed Sep 7, 2024
1 parent 9b36a3e commit bc058fd
Show file tree
Hide file tree
Showing 16 changed files with 318 additions and 4 deletions.
11 changes: 11 additions & 0 deletions naga/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1934,12 +1934,15 @@ pub enum Statement {
/// If [`SHADER_INT64_ATOMIC_MIN_MAX`] or [`SHADER_INT64_ATOMIC_ALL_OPS`] are
/// enabled, this may also be [`I64`] or [`U64`].
///
/// If [`SHADER_FLT32_ATOMIC`] is enabled, this may be [`F32`].
///
/// [`Pointer`]: TypeInner::Pointer
/// [`Atomic`]: TypeInner::Atomic
/// [`I32`]: Scalar::I32
/// [`U32`]: Scalar::U32
/// [`SHADER_INT64_ATOMIC_MIN_MAX`]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_MIN_MAX
/// [`SHADER_INT64_ATOMIC_ALL_OPS`]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_ALL_OPS
/// [`SHADER_FLT32_ATOMIC`]: crate::valid::Capabilities::SHADER_FLT32_ATOMIC
/// [`I64`]: Scalar::I64
/// [`U64`]: Scalar::U64
pointer: Handle<Expression>,
Expand All @@ -1957,9 +1960,16 @@ pub enum Statement {
/// - If neither of those capabilities are present, then 64-bit scalar
/// atomics are not allowed.
///
/// If [`pointer`] refers to a 32-bit floating-point atomic value, then:
///
/// - The [`SHADER_FLT32_ATOMIC`] capability allows allows
/// [`AtomicFunction::Add`], [`AtomicFunction::Subtract`] and
/// [`AtomicFunction::Exchange`] here.
///
/// [`pointer`]: Statement::Atomic::pointer
/// [`SHADER_INT64_ATOMIC_MIN_MAX`]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_MIN_MAX
/// [`SHADER_INT64_ATOMIC_ALL_OPS`]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_ALL_OPS
/// [`SHADER_FLT32_ATOMIC`]: crate::valid::Capabilities::SHADER_FLT32_ATOMIC
fun: AtomicFunction,

/// Value to use in the function.
Expand All @@ -1986,6 +1996,7 @@ pub enum Statement {
/// [`Exchange { compare: None }`]: AtomicFunction::Exchange
/// [`SHADER_INT64_ATOMIC_MIN_MAX`]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_MIN_MAX
/// [`SHADER_INT64_ATOMIC_ALL_OPS`]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_ALL_OPS
/// [`SHADER_FLT32_ATOMIC`]: crate::valid::Capabilities::SHADER_FLT32_ATOMIC
result: Option<Handle<Expression>>,
},
/// Load uniformly from a uniform pointer in the workgroup address space.
Expand Down
45 changes: 44 additions & 1 deletion naga/src/valid/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@ pub enum CallError {
pub enum AtomicError {
#[error("Pointer {0:?} to atomic is invalid.")]
InvalidPointer(Handle<crate::Expression>),
#[error("Address space {0:?} does not support 64bit atomics.")]
#[error("Address space {0:?} is not supported.")]
InvalidAddressSpace(crate::AddressSpace),
#[error("Function {0:?} is not supported.")]
InvalidFunction(crate::AtomicFunction),
#[error("Operand {0:?} has invalid type.")]
InvalidOperand(Handle<crate::Expression>),
#[error("Result expression {0:?} is not an `AtomicResult` expression")]
Expand Down Expand Up @@ -446,6 +448,47 @@ impl super::Validator {
}
}

// Check for the special restrictions on 32-bit floating-point atomic operations.
//
// We don't need to consider other widths here: this function has already checked
// that `pointer`'s type is an `Atomic`, and `validate_type` has already checked
// that that `Atomic` type has a permitted scalar width.
if let crate::ScalarKind::Float = pointer_scalar.kind {
// `Capabilities::SHADER_FLT32_ATOMIC` enables 32-bit floating-point
// atomic operations including `Add`, `Subtract`, and `Exchange`
// in storage address space.
if !matches!(
*fun,
crate::AtomicFunction::Add
| crate::AtomicFunction::Subtract
| crate::AtomicFunction::Exchange { compare: _ }
) {
log::error!("Float32 atomic operation {:?} is not supported", fun);
return Err(AtomicError::InvalidFunction(*fun)
.with_span_handle(value, context.expressions)
.into_other());
}
if !self
.capabilities
.contains(super::Capabilities::SHADER_FLT32_ATOMIC)
{
log::error!("Float32 atomic operations are not supported");
return Err(AtomicError::MissingCapability(
super::Capabilities::SHADER_FLT32_ATOMIC,
)
.with_span_handle(value, context.expressions)
.into_other());
}
if !matches!(pointer_space, crate::AddressSpace::Storage { .. }) {
log::error!(
"Float32 atomic operations are only supported in storage address space"
);
return Err(AtomicError::InvalidAddressSpace(pointer_space)
.with_span_handle(value, context.expressions)
.into_other());
}
}

// The result expression must be appropriate to the operation.
match result {
Some(result) => {
Expand Down
11 changes: 11 additions & 0 deletions naga/src/valid/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,15 @@ bitflags::bitflags! {
const SHADER_INT64_ATOMIC_MIN_MAX = 0x80000;
/// Support for all atomic operations on 64-bit integers.
const SHADER_INT64_ATOMIC_ALL_OPS = 0x100000;
/// Support for [`AtomicFunction::Add`], [`AtomicFunction::Sub`]
/// and [`AtomicFunction::Exchange`] on 32-bit floating-point numbers
/// in the [`Storage`] address space.
///
/// [`AtomicFunction::Add`]: crate::AtomicFunction::Add
/// [`AtomicFunction::Sub`]: crate::AtomicFunction::Sub
/// [`AtomicFunction::Exchange`]: crate::AtomicFunction::Exchange
/// [`Storage`]: crate::AddressSpace::Storage
const SHADER_FLT32_ATOMIC = 0x200000;
}
}

Expand Down Expand Up @@ -601,6 +610,8 @@ impl Validator {
.into_boxed_slice(),
};

// TODO: Error

for (handle, ty) in module.types.iter() {
let ty_info = self
.validate_type(handle, module.to_ctx())
Expand Down
15 changes: 14 additions & 1 deletion naga/src/valid/type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,6 @@ impl super::Validator {
Ti::Atomic(crate::Scalar { kind, width }) => {
match kind {
crate::ScalarKind::Bool
| crate::ScalarKind::Float
| crate::ScalarKind::AbstractInt
| crate::ScalarKind::AbstractFloat => {
return Err(TypeError::InvalidAtomicWidth(kind, width))
Expand All @@ -381,6 +380,20 @@ impl super::Validator {
return Err(TypeError::InvalidAtomicWidth(kind, width));
}
}
crate::ScalarKind::Float => {
if width == 4 {
if !self
.capabilities
.intersects(Capabilities::SHADER_FLT32_ATOMIC)
{
return Err(TypeError::MissingCapability(
Capabilities::SHADER_FLT32_ATOMIC,
));
}
} else {
return Err(TypeError::InvalidAtomicWidth(kind, width));
}
}
};
TypeInfo::new(
TypeFlags::DATA | TypeFlags::SIZED | TypeFlags::HOST_SHAREABLE,
Expand Down
11 changes: 11 additions & 0 deletions naga/tests/in/atomicOps-flt32.param.ron
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
(
god_mode: true,
msl: (
lang_version: (3, 0),
per_entry_point_map: {},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: true,
zero_initialize_workgroup_memory: false,
),
)
54 changes: 54 additions & 0 deletions naga/tests/in/atomicOps-flt32.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
struct Struct {
atomic_scalar: atomic<f32>,
atomic_arr: array<atomic<f32>, 2>,
}

@group(0) @binding(0)
var<storage, read_write> storage_atomic_scalar: atomic<f32>;
@group(0) @binding(1)
var<storage, read_write> storage_atomic_arr: array<atomic<f32>, 2>;
@group(0) @binding(2)
var<storage, read_write> storage_struct: Struct;

@compute
@workgroup_size(2)
fn cs_main(@builtin(local_invocation_id) id: vec3<u32>) {
atomicStore(&storage_atomic_scalar, 1.0);
atomicStore(&storage_atomic_arr[1], 1.0);
atomicStore(&storage_struct.atomic_scalar, 1.0);
atomicStore(&storage_struct.atomic_arr[1], 1.0);

workgroupBarrier();

let l0 = atomicLoad(&storage_atomic_scalar);
let l1 = atomicLoad(&storage_atomic_arr[1]);
let l2 = atomicLoad(&storage_struct.atomic_scalar);
let l3 = atomicLoad(&storage_struct.atomic_arr[1]);

workgroupBarrier();

atomicAdd(&storage_atomic_scalar, 1.0);
atomicAdd(&storage_atomic_arr[1], 1.0);
atomicAdd(&storage_struct.atomic_scalar, 1.0);
atomicAdd(&storage_struct.atomic_arr[1], 1.0);

workgroupBarrier();

atomicSub(&storage_atomic_scalar, 1.0);
atomicSub(&storage_atomic_arr[1], 1.0);
atomicSub(&storage_struct.atomic_scalar, 1.0);
atomicSub(&storage_struct.atomic_arr[1], 1.0);

workgroupBarrier();

atomicExchange(&storage_atomic_scalar, 1.0);
atomicExchange(&storage_atomic_arr[1], 1.0);
atomicExchange(&storage_struct.atomic_scalar, 1.0);
atomicExchange(&storage_struct.atomic_arr[1], 1.0);

// // TODO: https://github.com/gpuweb/gpuweb/issues/2021
// atomicCompareExchangeWeak(&storage_atomic_scalar, 1.0);
// atomicCompareExchangeWeak(&storage_atomic_arr[1], 1.0);
// atomicCompareExchangeWeak(&storage_struct.atomic_scalar, 1.0);
// atomicCompareExchangeWeak(&storage_struct.atomic_arr[1], 1.0);
}
48 changes: 48 additions & 0 deletions naga/tests/out/msl/atomicOps-flt32.msl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// language: metal3.0
#include <metal_stdlib>
#include <simd/simd.h>

using metal::uint;

struct type_1 {
metal::atomic_float inner[2];
};
struct Struct {
metal::atomic_float atomic_scalar;
type_1 atomic_arr;
};

struct cs_mainInput {
};
kernel void cs_main(
metal::uint3 id [[thread_position_in_threadgroup]]
, device metal::atomic_float& storage_atomic_scalar [[user(fake0)]]
, device type_1& storage_atomic_arr [[user(fake0)]]
, device Struct& storage_struct [[user(fake0)]]
) {
metal::atomic_store_explicit(&storage_atomic_scalar, 1.0, metal::memory_order_relaxed);
metal::atomic_store_explicit(&storage_atomic_arr.inner[1], 1.0, metal::memory_order_relaxed);
metal::atomic_store_explicit(&storage_struct.atomic_scalar, 1.0, metal::memory_order_relaxed);
metal::atomic_store_explicit(&storage_struct.atomic_arr.inner[1], 1.0, metal::memory_order_relaxed);
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
float l0_ = metal::atomic_load_explicit(&storage_atomic_scalar, metal::memory_order_relaxed);
float l1_ = metal::atomic_load_explicit(&storage_atomic_arr.inner[1], metal::memory_order_relaxed);
float l2_ = metal::atomic_load_explicit(&storage_struct.atomic_scalar, metal::memory_order_relaxed);
float l3_ = metal::atomic_load_explicit(&storage_struct.atomic_arr.inner[1], metal::memory_order_relaxed);
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
float _e27 = metal::atomic_fetch_add_explicit(&storage_atomic_scalar, 1.0, metal::memory_order_relaxed);
float _e31 = metal::atomic_fetch_add_explicit(&storage_atomic_arr.inner[1], 1.0, metal::memory_order_relaxed);
float _e35 = metal::atomic_fetch_add_explicit(&storage_struct.atomic_scalar, 1.0, metal::memory_order_relaxed);
float _e40 = metal::atomic_fetch_add_explicit(&storage_struct.atomic_arr.inner[1], 1.0, metal::memory_order_relaxed);
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
float _e43 = metal::atomic_fetch_sub_explicit(&storage_atomic_scalar, 1.0, metal::memory_order_relaxed);
float _e47 = metal::atomic_fetch_sub_explicit(&storage_atomic_arr.inner[1], 1.0, metal::memory_order_relaxed);
float _e51 = metal::atomic_fetch_sub_explicit(&storage_struct.atomic_scalar, 1.0, metal::memory_order_relaxed);
float _e56 = metal::atomic_fetch_sub_explicit(&storage_struct.atomic_arr.inner[1], 1.0, metal::memory_order_relaxed);
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
float _e59 = metal::atomic_exchange_explicit(&storage_atomic_scalar, 1.0, metal::memory_order_relaxed);
float _e63 = metal::atomic_exchange_explicit(&storage_atomic_arr.inner[1], 1.0, metal::memory_order_relaxed);
float _e67 = metal::atomic_exchange_explicit(&storage_struct.atomic_scalar, 1.0, metal::memory_order_relaxed);
float _e72 = metal::atomic_exchange_explicit(&storage_struct.atomic_arr.inner[1], 1.0, metal::memory_order_relaxed);
return;
}
40 changes: 40 additions & 0 deletions naga/tests/out/wgsl/atomicOps-flt32.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
struct Struct {
atomic_scalar: atomic<f32>,
atomic_arr: array<atomic<f32>, 2>,
}

@group(0) @binding(0)
var<storage, read_write> storage_atomic_scalar: atomic<f32>;
@group(0) @binding(1)
var<storage, read_write> storage_atomic_arr: array<atomic<f32>, 2>;
@group(0) @binding(2)
var<storage, read_write> storage_struct: Struct;

@compute @workgroup_size(2, 1, 1)
fn cs_main(@builtin(local_invocation_id) id: vec3<u32>) {
atomicStore((&storage_atomic_scalar), 1f);
atomicStore((&storage_atomic_arr[1]), 1f);
atomicStore((&storage_struct.atomic_scalar), 1f);
atomicStore((&storage_struct.atomic_arr[1]), 1f);
workgroupBarrier();
let l0_ = atomicLoad((&storage_atomic_scalar));
let l1_ = atomicLoad((&storage_atomic_arr[1]));
let l2_ = atomicLoad((&storage_struct.atomic_scalar));
let l3_ = atomicLoad((&storage_struct.atomic_arr[1]));
workgroupBarrier();
let _e27 = atomicAdd((&storage_atomic_scalar), 1f);
let _e31 = atomicAdd((&storage_atomic_arr[1]), 1f);
let _e35 = atomicAdd((&storage_struct.atomic_scalar), 1f);
let _e40 = atomicAdd((&storage_struct.atomic_arr[1]), 1f);
workgroupBarrier();
let _e43 = atomicSub((&storage_atomic_scalar), 1f);
let _e47 = atomicSub((&storage_atomic_arr[1]), 1f);
let _e51 = atomicSub((&storage_struct.atomic_scalar), 1f);
let _e56 = atomicSub((&storage_struct.atomic_arr[1]), 1f);
workgroupBarrier();
let _e59 = atomicExchange((&storage_atomic_scalar), 1f);
let _e63 = atomicExchange((&storage_atomic_arr[1]), 1f);
let _e67 = atomicExchange((&storage_struct.atomic_scalar), 1f);
let _e72 = atomicExchange((&storage_struct.atomic_arr[1]), 1f);
return;
}
1 change: 1 addition & 0 deletions naga/tests/snapshots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,7 @@ fn convert_wgsl() {
"atomicOps-int64-min-max",
Targets::SPIRV | Targets::METAL | Targets::HLSL | Targets::WGSL,
),
("atomicOps-flt32", Targets::METAL | Targets::WGSL),
(
"atomicCompareExchange-int64",
Targets::SPIRV | Targets::WGSL,
Expand Down
40 changes: 40 additions & 0 deletions tests/tests/shader/numeric_builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,46 @@ fn create_int64_atomic_all_ops_test() -> Vec<ShaderTest> {
tests
}

fn create_flt32_atomic_test() -> Vec<ShaderTest> {
let mut tests = Vec::new();

let test = ShaderTest::new(
"atomicAdd".into(),
"value: f32".into(),
"atomicStore(&output, 0.0); atomicAdd(&output, 1.0); atomicAdd(&output, 1.0);".into(),
&[0.0_f32],
&[2.0_f32],
)
.output_type("atomic<f32>".into());

tests.push(test);

let test = ShaderTest::new(
"atomicSub".into(),
"value: f32".into(),
"atomicStore(&output, 0.0); atomicSub(&output, -1.0); atomicSub(&output, 0.5);".into(),
&[0.0_f32],
&[0.5_f32],
)
.output_type("atomic<f32>".into());

tests.push(test);

tests
}

#[gpu_test]
static FLT32_ATOMIC: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.features(wgt::Features::SHADER_FLT32_ATOMIC)
.downlevel_flags(DownlevelFlags::COMPUTE_SHADERS)
.limits(Limits::downlevel_defaults()),
)
.run_async(|ctx| {
shader_input_output_test(ctx, InputStorageType::Storage, create_flt32_atomic_test())
});

#[gpu_test]
static INT64_ATOMIC_ALL_OPS: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
Expand Down
8 changes: 7 additions & 1 deletion wgpu-core/src/command/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,16 @@ mod compat {
}
}

#[derive(Clone, Debug, Error)]
#[error("Unknown reason")]
struct Unknown();

Err(Error::Incompatible {
expected_bgl: expected_bgl.error_ident(),
assigned_bgl: assigned_bgl.error_ident(),
inner: MultiError::new(errors.drain(..)).unwrap(),
inner: MultiError::new(errors.drain(..)).unwrap_or_else(|| {
MultiError::new(core::iter::once(Unknown())).unwrap()
}),
})
}
} else {
Expand Down
4 changes: 4 additions & 0 deletions wgpu-core/src/device/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,10 @@ pub fn create_validator(
Caps::SHADER_INT64_ATOMIC_ALL_OPS,
features.contains(wgt::Features::SHADER_INT64_ATOMIC_ALL_OPS),
);
caps.set(
Caps::SHADER_FLT32_ATOMIC,
features.contains(wgt::Features::SHADER_FLT32_ATOMIC),
);
caps.set(
Caps::MULTISAMPLED_SHADING,
downlevel.contains(wgt::DownlevelFlags::MULTISAMPLED_SHADING),
Expand Down
Loading

0 comments on commit bc058fd

Please sign in to comment.