diff --git a/docs/user-guide/a3-02-reference-capability-atoms.md b/docs/user-guide/a3-02-reference-capability-atoms.md index a70a9f88c8..092dab9f09 100644 --- a/docs/user-guide/a3-02-reference-capability-atoms.md +++ b/docs/user-guide/a3-02-reference-capability-atoms.md @@ -852,7 +852,7 @@ Compound Capabilities `shadermemorycontrol` > (gfx targets) Capabilities needed to use memory barriers -`waveprefix` +`wave_multi_prefix` > Capabilities needed to use HLSL tier wave operations `bufferreference` diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index ba5c95a0c5..1853a82b64 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -15517,7 +15517,7 @@ uint4 WaveMatch(matrix value) } /// @category wave -[require(cuda_hlsl, waveprefix)] +[require(cuda_hlsl, wave_multi_prefix)] uint WaveMultiPrefixCountBits(bool value, uint4 mask) { __target_switch @@ -15528,190 +15528,366 @@ uint WaveMultiPrefixCountBits(bool value, uint4 mask) } /// @category wave -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) +__generic +__glsl_extension(GL_NV_shader_subgroup_partitioned) __spirv_version(1.3) -[require(cuda_glsl_hlsl, waveprefix)] +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] T WaveMultiPrefixBitAnd(T expr, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixAnd(_getMultiPrefixMask(($1).x), $0)"; - case glsl: __intrinsic_asm "subgroupExclusiveAnd($0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixBitAnd"; + case glsl: __intrinsic_asm "subgroupPartitionedExclusiveAndNV"; + case spirv: + return spirv_asm + { + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpCapability GroupNonUniformPartitionedNV; + result:$$T = OpGroupNonUniformBitwiseAnd Subgroup PartitionedExclusiveScanNV $expr $mask + }; } } -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) +__generic +__glsl_extension(GL_NV_shader_subgroup_partitioned) __spirv_version(1.3) -__generic -[require(cuda_glsl_hlsl, waveprefix)] +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] vector WaveMultiPrefixBitAnd(vector expr, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixAndMultiple(_getMultiPrefixMask(($1).x), $0)"; - case glsl: __intrinsic_asm "subgroupExclusiveAnd($0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixBitAnd"; + case glsl: __intrinsic_asm "subgroupPartitionedExclusiveAndNV"; + case spirv: + return spirv_asm + { + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpCapability GroupNonUniformPartitionedNV; + result:$$vector = OpGroupNonUniformBitwiseAnd Subgroup PartitionedExclusiveScanNV $expr $mask + }; } } -__generic -[require(cuda_hlsl, waveprefix)] +__generic +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] matrix WaveMultiPrefixBitAnd(matrix expr, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixAndMultiple(_getMultiPrefixMask(($1).x), $0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixBitAnd"; + case glsl: + case spirv: + matrix result; + for (int i = 0; i < N; ++i) + result[i] = WaveMultiPrefixBitAnd(expr[i], mask); + return result; } } /// @category wave -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) +__generic +__glsl_extension(GL_NV_shader_subgroup_partitioned) __spirv_version(1.3) -[require(cuda_glsl_hlsl, waveprefix)] +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] T WaveMultiPrefixBitOr(T expr, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixOr(, _getMultiPrefixMask(($1).x), $0)"; - case glsl: __intrinsic_asm "subgroupExclusiveOr($0)"; case hlsl: 
__intrinsic_asm "WaveMultiPrefixBitOr"; + case glsl: __intrinsic_asm "subgroupPartitionedExclusiveOrNV"; + case spirv: + return spirv_asm + { + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpCapability GroupNonUniformPartitionedNV; + result:$$T = OpGroupNonUniformBitwiseOr Subgroup PartitionedExclusiveScanNV $expr $mask + }; } } -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) +__generic +__glsl_extension(GL_NV_shader_subgroup_partitioned) __spirv_version(1.3) -[require(cuda_glsl_hlsl, waveprefix)] +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] vector WaveMultiPrefixBitOr(vector expr, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixOrMultiple(_getMultiPrefixMask(($1).x), $0)"; - case glsl: __intrinsic_asm "subgroupExclusiveOr($0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixBitOr"; + case glsl: __intrinsic_asm "subgroupPartitionedExclusiveOrNV"; + case spirv: + return spirv_asm + { + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpCapability GroupNonUniformPartitionedNV; + result:$$vector = OpGroupNonUniformBitwiseOr Subgroup PartitionedExclusiveScanNV $expr $mask + }; } } -__generic -[require(cuda_hlsl, waveprefix)] +__generic +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] matrix WaveMultiPrefixBitOr(matrix expr, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixOrMultiple(_getMultiPrefixMask(($1).x), $0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixBitOr"; + case glsl: + case spirv: + matrix result; + for (int i = 0; i < N; ++i) + result[i] = WaveMultiPrefixBitOr(expr[i], mask); + return result; } } /// @category wave -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) +__generic +__glsl_extension(GL_NV_shader_subgroup_partitioned) __spirv_version(1.3) -[require(cuda_glsl_hlsl, waveprefix)] +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] T WaveMultiPrefixBitXor(T expr, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixXor(_getMultiPrefixMask(($1).x), $0)"; - case glsl: __intrinsic_asm "subgroupExclusiveXor($0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixBitXor"; + case glsl: __intrinsic_asm "subgroupPartitionedExclusiveXorNV"; + case spirv: + return spirv_asm + { + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpCapability GroupNonUniformPartitionedNV; + result:$$T = OpGroupNonUniformBitwiseXor Subgroup PartitionedExclusiveScanNV $expr $mask + }; } } -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) +__generic +__glsl_extension(GL_NV_shader_subgroup_partitioned) __spirv_version(1.3) -[require(cuda_glsl_hlsl, waveprefix)] +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] vector WaveMultiPrefixBitXor(vector expr, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixXorMultiple(_getMultiPrefixMask(($1).x), $0)"; - case glsl: __intrinsic_asm "subgroupExclusiveXor($0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixBitXor"; + case glsl: __intrinsic_asm "subgroupPartitionedExclusiveXorNV"; + case spirv: + return spirv_asm + { + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpCapability GroupNonUniformPartitionedNV; + result:$$vector = OpGroupNonUniformBitwiseXor Subgroup PartitionedExclusiveScanNV $expr $mask + }; } } -__generic -[require(cuda_hlsl, waveprefix)] +__generic +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] matrix WaveMultiPrefixBitXor(matrix expr, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixXorMultiple(_getMultiPrefixMask(($1).x), $0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixBitXor"; + 
case glsl: + case spirv: + matrix result; + for (int i = 0; i < N; ++i) + result[i] = WaveMultiPrefixBitXor(expr[i], mask); + return result; } } /// @category wave __generic -[require(cuda_hlsl, waveprefix)] +__glsl_extension(GL_NV_shader_subgroup_partitioned) +__spirv_version(1.3) +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] T WaveMultiPrefixProduct(T value, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixProduct(_getMultiPrefixMask(($1).x), $0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixProduct"; + case glsl: __intrinsic_asm "subgroupPartitionedExclusiveMulNV"; + case spirv: + { + spirv_asm + { + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpCapability GroupNonUniformPartitionedNV; + }; + + if (__isFloat()) + { + return spirv_asm + { + result:$$T = OpGroupNonUniformFMul Subgroup PartitionedExclusiveScanNV $value $mask + }; + } + else + { + return spirv_asm + { + result:$$T = OpGroupNonUniformIMul Subgroup PartitionedExclusiveScanNV $value $mask + }; + } + } } } __generic -[require(cuda_hlsl, waveprefix)] +__glsl_extension(GL_NV_shader_subgroup_partitioned) +__spirv_version(1.3) +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] vector WaveMultiPrefixProduct(vector value, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixProductMultiple(_getMultiPrefixMask(($1).x), $0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixProduct"; + case glsl: __intrinsic_asm "subgroupPartitionedExclusiveMulNV"; + case spirv: + { + spirv_asm + { + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpCapability GroupNonUniformPartitionedNV; + }; + + if (__isFloat()) + { + return spirv_asm + { + result:$$vector = OpGroupNonUniformFMul Subgroup PartitionedExclusiveScanNV $value $mask + }; + } + else + { + return spirv_asm + { + result:$$vector = OpGroupNonUniformIMul Subgroup PartitionedExclusiveScanNV $value $mask + }; + } + } } } __generic -[require(cuda_hlsl, waveprefix)] +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] matrix WaveMultiPrefixProduct(matrix value, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixProductMultiple(_getMultiPrefixMask(($1).x), $0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixProduct"; + case glsl: + case spirv: + matrix result; + for (int i = 0; i < N; ++i) + result[i] = WaveMultiPrefixProduct(value[i], mask); + return result; } } /// @category wave __generic -[require(cuda_hlsl, waveprefix)] +__glsl_extension(GL_NV_shader_subgroup_partitioned) +__spirv_version(1.3) +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] T WaveMultiPrefixSum(T value, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixSum(_getMultiPrefixMask(($1).x), $0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixSum"; + case glsl: __intrinsic_asm "subgroupPartitionedExclusiveAddNV"; + case spirv: + { + spirv_asm + { + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpCapability GroupNonUniformPartitionedNV; + }; + + if (__isFloat()) + { + return spirv_asm + { + result:$$T = OpGroupNonUniformFAdd Subgroup PartitionedExclusiveScanNV $value $mask + }; + } + else + { + return spirv_asm + { + result:$$T = OpGroupNonUniformIAdd Subgroup PartitionedExclusiveScanNV $value $mask + }; + } + } } } __generic -[require(cuda_hlsl, waveprefix)] +__glsl_extension(GL_NV_shader_subgroup_partitioned) +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] +__spirv_version(1.3) vector WaveMultiPrefixSum(vector value, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixSumMultiple(_getMultiPrefixMask(($1).x), $0 )"; 
case hlsl: __intrinsic_asm "WaveMultiPrefixSum"; + case glsl: __intrinsic_asm "subgroupPartitionedExclusiveAddNV"; + case spirv: + { + spirv_asm + { + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpCapability GroupNonUniformPartitionedNV; + }; + + if (__isFloat()) + { + return spirv_asm + { + result:$$vector = OpGroupNonUniformFAdd Subgroup PartitionedExclusiveScanNV $value $mask + }; + } + else + { + return spirv_asm + { + result:$$vector = OpGroupNonUniformIAdd Subgroup PartitionedExclusiveScanNV $value $mask + }; + } + } } } __generic -[require(cuda_hlsl, waveprefix)] +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] matrix WaveMultiPrefixSum(matrix value, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixSumMultiple(_getMultiPrefixMask(($1).x), $0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixSum"; + case glsl: + case spirv: + matrix result; + for (int i = 0; i < N; ++i) + result[i] = WaveMultiPrefixSum(value[i], mask); + return result; } } diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef index 4f6357779c..3bc54c080a 100644 --- a/source/slang/slang-capabilities.capdef +++ b/source/slang/slang-capabilities.capdef @@ -1024,7 +1024,9 @@ alias fragmentshaderbarycentric = GL_EXT_fragment_shader_barycentric | _sm_6_1; alias shadermemorycontrol = glsl | _spirv_1_0 | _sm_5_0; /// Capabilities needed to use HLSL tier wave operations /// [Compound] -alias waveprefix = _sm_6_5 | _cuda_sm_7_0 | GL_KHR_shader_subgroup_arithmetic; +alias wave_multi_prefix = _sm_6_5 + | _cuda_sm_7_0 + | GL_KHR_shader_subgroup_ballot + GL_KHR_shader_subgroup_arithmetic + GL_NV_shader_subgroup_partitioned; /// Capabilities needed to use GLSL buffer-reference's /// [Compound] alias bufferreference = GL_EXT_buffer_reference; diff --git a/tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang b/tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang new file mode 100644 index 0000000000..69240198e0 --- /dev/null +++ b/tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang @@ -0,0 +1,74 @@ +//TEST_CATEGORY(wave, compute) +//DISABLE_TEST:COMPARE_COMPUTE_EX:-cpu -compute -shaderobj +//DISABLE_TEST:COMPARE_COMPUTE_EX:-slang -compute -shaderobj + +//TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile sm_6_5 -shaderobj +//TEST:COMPARE_COMPUTE_EX:-vk -compute -shaderobj +//TEST:COMPARE_COMPUTE_EX:-cuda -compute -render-features cuda_sm_7_0 -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(8, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint index = int(dispatchThreadID.x); + + // Split into two groups. + uint4 mask = 0b00001111; + if (index >= 4) + { + mask = 0b11110000; + } + + // + // WaveMultiPrefixSum. + // Results in hex: [0 1 3 7], [0 10 30 70] + // + uint sumValue = WaveMultiPrefixSum(1 << index, mask); + const uint sumBaseIndex = 0; + outputBuffer[sumBaseIndex + index] = sumValue; + + // + // WaveMultiPrefixProduct. + // Results in hex: [1 1 2 8], [1 10 200 8000] + // + uint productValue = WaveMultiPrefixProduct(1 << index, mask); + const uint productBaseIndex = 8; + outputBuffer[productBaseIndex + index] = productValue; + + // + // WaveMultiPrefixBitAnd. + // This prefix operation starts with all bits set. 
+ // Results in hex: [FFFFFFFF 1 1 1], [FFFFFFFF F F F] + // + uint andBits = 0b1; + if (index >= 4) + { + andBits = 0b1111; + } + uint andValue = WaveMultiPrefixBitAnd(andBits, mask); + const uint andBaseIndex = 16; + outputBuffer[andBaseIndex + index] = andValue; + + // + // WaveMultiPrefixBitOr. + // Results in hex: [0 1 3 7], [0 10 30 70] + // + uint orValue = WaveMultiPrefixBitOr(1 << index, mask); + const uint orBaseIndex = 24; + outputBuffer[orBaseIndex + index] = orValue; + + // + // WaveMultiPrefixBitXor. + // Results in hex: [0 1 3 7], [0 F 0 F] + // + uint xorBits = (1 << index); + if (index >= 4) + { + xorBits = 0b1111; + } + uint xorValue = WaveMultiPrefixBitXor(xorBits, mask); + const uint xorBaseIndex = 32; + outputBuffer[xorBaseIndex + index] = xorValue; +} diff --git a/tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang.expected.txt b/tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang.expected.txt new file mode 100644 index 0000000000..c80baa5b18 --- /dev/null +++ b/tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang.expected.txt @@ -0,0 +1,40 @@ +0 +1 +3 +7 +0 +10 +30 +70 +1 +1 +2 +8 +1 +10 +200 +8000 +FFFFFFFF +1 +1 +1 +FFFFFFFF +F +F +F +0 +1 +3 +7 +0 +10 +30 +70 +0 +1 +3 +7 +0 +F +0 +F diff --git a/tests/hlsl-intrinsic/wave-multi-prefix.slang b/tests/hlsl-intrinsic/wave-multi-prefix.slang index 31dde2af43..99698e4979 100644 --- a/tests/hlsl-intrinsic/wave-multi-prefix.slang +++ b/tests/hlsl-intrinsic/wave-multi-prefix.slang @@ -1,27 +1,146 @@ -//TEST_CATEGORY(wave, compute) -//DISABLE_TEST:COMPARE_COMPUTE_EX:-cpu -compute -shaderobj -//DISABLE_TEST:COMPARE_COMPUTE_EX:-slang -compute -shaderobj -// We need SM6.5 for these tests -// Disable because version of dxc we are currently using doesn't support SM6.5 -//DISABLE_TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile sm_6_5 -shaderobj -// Disabled because we don't have GLSL intrinsics for these it seems -//DISABLE_TEST(vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -//TEST:COMPARE_COMPUTE_EX:-cuda -compute -render-features cuda_sm_7_0 -shaderobj - -//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer -RWStructuredBuffer outputBuffer; - -[numthreads(8, 1, 1)] -void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +//TEST:SIMPLE(filecheck=CHECK_SPIRV): -stage compute -entry computeMain -target spirv -DNO_INTEGER_MATRIX +//TEST:SIMPLE(filecheck=CHECK_GLSL): -stage compute -entry computeMain -target glsl -DNO_INTEGER_MATRIX +//TEST:SIMPLE(filecheck=CHECK_CUDA): -stage compute -entry computeMain -target cuda +//TEST:SIMPLE(filecheck=CHECK_HLSL): -stage compute -entry computeMain -target hlsl + +// +// Tests all variants and overloads of WaveMultiPrefix* arithmetic intrinsics. 
+// + +struct OutputData +{ + int scalarSum; + int scalarProduct; + int scalarBitAnd; + int scalarBitOr; + int scalarBitXor; + int vectorSum; + int vectorProduct; + int vectorBitAnd; + int vectorBitOr; + int vectorBitXor; + int matrixSum; + int matrixProduct; + int matrixBitAnd; + int matrixBitOr; + int matrixBitXor; + float floatScalarSum; + float floatScalarProduct; + float floatVectorSum; + float floatVectorProduct; + float floatMatrixSum; + float floatMatrixProduct; +}; + +RWStructuredBuffer outputBuffer; + +// CHECK_SPIRV: OpCapability GroupNonUniformPartitionedNV +// CHECK_SPIRV: OpExtension "SPV_NV_shader_subgroup_partitioned" +// CHECK_SPIRV: OpGroupNonUniformIAdd{{.*}}PartitionedExclusiveScanNV +// CHECK_SPIRV: OpGroupNonUniformIMul{{.*}}PartitionedExclusiveScanNV +// CHECK_SPIRV: OpGroupNonUniformBitwiseAnd{{.*}}PartitionedExclusiveScanNV +// CHECK_SPIRV: OpGroupNonUniformBitwiseOr{{.*}}PartitionedExclusiveScanNV +// CHECK_SPIRV: OpGroupNonUniformBitwiseXor{{.*}}PartitionedExclusiveScanNV +// CHECK_SPIRV: OpGroupNonUniformFAdd{{.*}}PartitionedExclusiveScanNV + +// CHECK_GLSL: GL_NV_shader_subgroup_partitioned +// CHECK_GLSL: subgroupPartitionedExclusiveAddNV +// CHECK_GLSL: subgroupPartitionedExclusiveMulNV +// CHECK_GLSL: subgroupPartitionedExclusiveAndNV +// CHECK_GLSL: subgroupPartitionedExclusiveOrNV +// CHECK_GLSL: subgroupPartitionedExclusiveXorNV + +// CHECK_CUDA: _wavePrefixSum +// CHECK_CUDA: _wavePrefixProduct +// CHECK_CUDA: _wavePrefixAnd +// CHECK_CUDA: _wavePrefixOr +// CHECK_CUDA: _wavePrefixXor +// CHECK_CUDA: _wavePrefixSumMultiple +// CHECK_CUDA: _wavePrefixProductMultiple +// CHECK_CUDA: _wavePrefixAndMultiple +// CHECK_CUDA: _wavePrefixOrMultiple +// CHECK_CUDA: _wavePrefixXorMultiple + +// CHECK_HLSL: WaveMultiPrefixSum +// CHECK_HLSL: WaveMultiPrefixProduct +// CHECK_HLSL: WaveMultiPrefixBitAnd +// CHECK_HLSL: WaveMultiPrefixBitOr +// CHECK_HLSL: WaveMultiPrefixBitXor + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dTid : SV_DispatchThreadID) { - int idx = int(dispatchThreadID.x); - - int value = 0; - - uint4 mask = WaveMatch(true); - - // Scalar - value += WaveMultiPrefixSum(1 << idx, mask); - - outputBuffer[idx] = value; -} \ No newline at end of file + int scalarVal = dTid.x; + uint4 mask = WaveMatch(scalarVal); + + int scalarSum = WaveMultiPrefixSum(scalarVal, mask); + int scalarProduct = WaveMultiPrefixProduct(scalarVal, mask); + int scalarBitAnd = WaveMultiPrefixBitAnd(scalarVal, mask); + int scalarBitOr = WaveMultiPrefixBitOr(scalarVal, mask); + int scalarBitXor = WaveMultiPrefixBitXor(scalarVal, mask); + + int3 vectorVal = int3(dTid.x, dTid.y, dTid.z); + int3 vectorSum = WaveMultiPrefixSum(vectorVal, mask); + int3 vectorProduct = WaveMultiPrefixProduct(vectorVal, mask); + int3 vectorBitAnd = WaveMultiPrefixBitAnd(vectorVal, mask); + int3 vectorBitOr = WaveMultiPrefixBitOr(vectorVal, mask); + int3 vectorBitXor = WaveMultiPrefixBitXor(vectorVal, mask); + + float floatScalarVal = float(dTid.x) + 0.5f; // Example floating-point scalar value + uint4 floatMask = WaveMatch(floatScalarVal); // Create a mask for matching lanes + + float floatScalarSum = WaveMultiPrefixSum(floatScalarVal, floatMask); + float floatScalarProduct = WaveMultiPrefixProduct(floatScalarVal, floatMask); + + float3 floatVectorVal = float3(dTid.x, dTid.y, dTid.z) + 0.5f; // Example floating-point vector value + float3 floatVectorSum = WaveMultiPrefixSum(floatVectorVal, floatMask); + float3 floatVectorProduct = WaveMultiPrefixProduct(floatVectorVal, floatMask); + + OutputData output; 
+ output.scalarSum = scalarSum; + output.scalarProduct = scalarProduct; + output.scalarBitAnd = scalarBitAnd; + output.scalarBitOr = scalarBitOr; + output.scalarBitXor = scalarBitXor; + output.vectorSum = vectorSum.x; + output.vectorProduct = vectorProduct.x; + output.vectorBitAnd = vectorBitAnd.x; + output.vectorBitOr = vectorBitOr.x; + output.vectorBitXor = vectorBitXor.x; + output.floatScalarSum = floatScalarSum; + output.floatScalarProduct = floatScalarProduct; + output.floatVectorSum = floatVectorSum.x; + output.floatVectorProduct = floatVectorProduct.x; + + float3x3 floatMatrixVal = float3x3( + float(dTid.x) + 0.5f, float(dTid.y) + 0.5f, float(dTid.z) + 0.5f, + float(dTid.z) + 0.5f, float(dTid.x) + 0.5f, float(dTid.y) + 0.5f, + float(dTid.y) + 0.5f, float(dTid.z) + 0.5f, float(dTid.x) + 0.5f + ); + float3x3 floatMatrixSum = WaveMultiPrefixSum(floatMatrixVal, floatMask); + float3x3 floatMatrixProduct = WaveMultiPrefixProduct(floatMatrixVal, floatMask); + output.floatMatrixSum = floatMatrixSum[0][0]; + output.floatMatrixProduct = floatMatrixProduct[0][0]; + +#if !defined(NO_INTEGER_MATRIX) + int3x3 matrixVal = int3x3( + dTid.x, dTid.y, dTid.z, + dTid.z, dTid.x, dTid.y, + dTid.y, dTid.z, dTid.x + ); + int3x3 matrixSum = WaveMultiPrefixSum(matrixVal, mask); + int3x3 matrixProduct = WaveMultiPrefixProduct(matrixVal, mask); + int3x3 matrixBitAnd = WaveMultiPrefixBitAnd(matrixVal, mask); + int3x3 matrixBitOr = WaveMultiPrefixBitOr(matrixVal, mask); + int3x3 matrixBitXor = WaveMultiPrefixBitXor(matrixVal, mask); + output.matrixSum = matrixSum[0][0]; + output.matrixProduct = matrixProduct[0][0]; + output.matrixBitAnd = matrixBitAnd[0][0]; + output.matrixBitOr = matrixBitOr[0][0]; + output.matrixBitXor = matrixBitXor[0][0]; +#endif + + outputBuffer[dTid.x] = output; +} + diff --git a/tests/hlsl-intrinsic/wave-multi-prefix.slang.expected.txt b/tests/hlsl-intrinsic/wave-multi-prefix.slang.expected.txt deleted file mode 100644 index 6ec6deeea0..0000000000 --- a/tests/hlsl-intrinsic/wave-multi-prefix.slang.expected.txt +++ /dev/null @@ -1,8 +0,0 @@ -0 -1 -3 -7 -F -1F -3F -7F diff --git a/tests/language-feature/capability/capabilitySimplification1.slang b/tests/language-feature/capability/capabilitySimplification1.slang index b694673e98..1d781a45e3 100644 --- a/tests/language-feature/capability/capabilitySimplification1.slang +++ b/tests/language-feature/capability/capabilitySimplification1.slang @@ -6,9 +6,9 @@ // CHECK: error 36107 // CHECK-SAME: entrypoint 'computeMain' does not support compilation target 'glsl' with stage 'compute' -// CHECK: capabilitySimplification1.slang(21): note: see using of 'WaveMultiPrefixProduct' -// CHECK-NOT: see using of 'WaveMultiPrefixProduct' -// CHECK: {{.*}}.meta.slang({{.*}}): note: see definition of 'WaveMultiPrefixProduct' +// CHECK: capabilitySimplification1.slang(21): note: see using of 'WaveMultiPrefixCountBits' +// CHECK-NOT: see using of 'WaveMultiPrefixCountBits' +// CHECK: {{.*}}.meta.slang({{.*}}): note: see definition of 'WaveMultiPrefixCountBits' // CHECK: {{.*}}.meta.slang({{.*}}): note: see declaration of 'require' void nestedSafeCall() @@ -18,7 +18,7 @@ void nestedSafeCall() void nestedBadCall() { - WaveMultiPrefixProduct(1, 0); + WaveMultiPrefixCountBits(true, 0); } void nestedCall() diff --git a/tests/language-feature/capability/capabilitySimplification3.slang b/tests/language-feature/capability/capabilitySimplification3.slang index faf161d15c..808c19bf60 100644 --- a/tests/language-feature/capability/capabilitySimplification3.slang +++ 
b/tests/language-feature/capability/capabilitySimplification3.slang
@@ -5,13 +5,13 @@
 // CHECK_IGNORE_CAPS-NOT: error 36107
 // CHECK: error 36107: entrypoint 'computeMain' does not support compilation target 'glsl' with stage 'compute'
-// CHECK: capabilitySimplification3.slang(16): note: see using of 'WaveMultiPrefixProduct'
-// CHECK-NOT: see using of 'WaveMultiPrefixProduct'
-// CHECK: {{.*}}.meta.slang({{.*}}): note: see definition of 'WaveMultiPrefixProduct'
+// CHECK: capabilitySimplification3.slang(16): note: see using of 'WaveMultiPrefixCountBits'
+// CHECK-NOT: see using of 'WaveMultiPrefixCountBits'
+// CHECK: {{.*}}.meta.slang({{.*}}): note: see definition of 'WaveMultiPrefixCountBits'
 // CHECK: {{.*}}.meta.slang({{.*}}): note: see declaration of 'require'

 [numthreads(1,1,1)]
 void computeMain()
 {
-    WaveMultiPrefixProduct(1, 0);
+    WaveMultiPrefixCountBits(true, 0);
 }
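
The diff above only contains the compiler-side changes, so here is a minimal usage sketch (not part of this change) of the pattern the `WaveMultiPrefix*` intrinsics now support on the GLSL and SPIR-V targets: partition the wave with `WaveMatch`, then run a multi-prefix scan inside each partition. The buffer name, thread-group size, and key computation below are illustrative assumptions.

```hlsl
// Illustrative only: partitioned prefix sum over lanes that share a key.
// Buffer name, group size, and the key computation are assumptions for
// this sketch; they are not taken from the change above.
RWStructuredBuffer<uint> perLanePrefix;

[numthreads(32, 1, 1)]
void exampleMain(uint3 tid : SV_DispatchThreadID)
{
    // Lanes that compute the same key form one partition of the wave.
    uint key = tid.x / 4;
    uint4 mask = WaveMatch(key);

    // Exclusive prefix sum evaluated independently inside each partition.
    // With this change, the call lowers to subgroupPartitionedExclusiveAddNV
    // on GLSL and to OpGroupNonUniformIAdd with PartitionedExclusiveScanNV
    // on SPIR-V, instead of being rejected for those targets.
    perLanePrefix[tid.x] = WaveMultiPrefixSum(1u, mask);
}
```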
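For reviewers checking the new expected values, the following sketch spells out the per-lane semantics that `PartitionedExclusiveScanNV` / `WaveMultiPrefixSum` implement: each lane reduces the inputs of lanes in its own `mask` partition with a lower lane index. The helper, its name, and its parameters are hypothetical and exist only to state the semantics; they are not part of the change.

```hlsl
// Reference emulation of a partitioned (multi-prefix) exclusive sum.
// `laneValues` is assumed to hold every lane's input and `mask` is the
// WaveMatch ballot for the partition that `lane` belongs to.
uint multiPrefixSumReference(uint laneValues[32], uint4 mask, uint lane)
{
    uint sum = 0;
    // Exclusive scan: accumulate only lanes before `lane` that sit in the
    // same partition (i.e. whose bit is set in `mask`).
    for (uint i = 0; i < lane; ++i)
    {
        if ((mask[i / 32] & (1u << (i % 32))) != 0)
            sum += laneValues[i];
    }
    return sum;
}
```

Under this definition, lanes 0..3 contributing `1 << index` yield 0x0, 0x1, 0x3, 0x7, and lanes 4..7 yield 0x0, 0x10, 0x30, 0x70, which matches the ordering recorded in wave-multi-prefix-scalar-functional.slang.expected.txt above.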