From 4ad5a6959ffaae2e22490dc6ffea2b2c09091d16 Mon Sep 17 00:00:00 2001 From: Chris Bieneman Date: Thu, 22 Jun 2023 09:51:00 -0500 Subject: [PATCH 1/5] [SM68] Initial proposal for Wave Matrix This proposal adds new WaveMatrix data types that support lane cooperative matrix multiplication. A preview implementation of this proposal is available as part of the DXC 1.8.2306 preview release. --- proposals/xxxx-wave-matrix.md | 396 ++++++++++++++++++++++++++++++++++ 1 file changed, 396 insertions(+) create mode 100644 proposals/xxxx-wave-matrix.md diff --git a/proposals/xxxx-wave-matrix.md b/proposals/xxxx-wave-matrix.md new file mode 100644 index 00000000..4c7d7bbb --- /dev/null +++ b/proposals/xxxx-wave-matrix.md @@ -0,0 +1,396 @@ + + +# Wave Matrix + +* Proposal: [NNNN](NNNN-wave-matrix.md) +* Author(s): [Chris Bieneman](https://github.com/llvm-beanz) +* Sponsor: [Greg Roth](https://github.com/pow2clk), [Tex Riddell](https://github.com/tex3d) +* Status: **Under Consideration** +* Planned Version: Shader Model 6.8 + + +## Introduction + +This proposal adds HLSL support for a new DXIL feature WaveMatrix. The new data +types and operations add support for wave cooperative matrix multiplication and +accumulation. The underlying hardware and driver interfaces for this feature are +introduced in Shader Model 6.8. + +## Motivation + +GPUs have added dedicated hardware to support cooperative matrix multiplication +across SIMD lanes. This feature proposes adding new data types and built-in +functions to enable higher throughput matrix operations that fully utilize GPU +SIMD hardware. + +These higher throughput matrix operations are required for optimal performance +of many machine learning and image processing workloads. Adding native support +to HLSL will enable high-performance matrix operations across all supported +hardware with Shader Model 6.8 drivers. + +## Proposed solution + +WaveMatrix introduces new matrix templates to facilitate wave cooperative +operations: + +```c++ +// Matrix Depth (K) dimension is hardware dependent +// With TYPE_IN one of {float32_t, float16_t, uint8_t4_packed, int8_t4_packed} +WaveMatrixLeft ; // M x K +WaveMatrixRight ; // K x N + +// With TYPE_ACC one of {float32_t, float16_t, int32_t} +WaveMatrixAccumulator ; // M x N +// WaveMatrixLeftColAcc and WaveMatrixRightRowAcc are provided support for +// quantization algorithms. See Zero Point section + +// For accumulating columns from WaveMatrixLeft into a single column of sums +WaveMatrixLeftColAcc ; // M x 1 + +// For accumulating rows from WaveMatrixRight into a single row of sums +WaveMatrixRightRowAcc ; // 1 x N +``` + +WaveMatrix accumulator object methods support operating on corresponding Left +and Right operands. Results of operations can be stored or accumulated back into +the accumulator. A simple example of multiplication is: + +```c++ +[numthreads(64,1,1)] +void main(uint3 GTID : SV_GroupThreadID, uint GIDX : SV_GroupIndex) +{ + WaveMatrixLeft Left; + WaveMatrixRight Right; + WaveMatrixAccumulator Acc; + + Acc.Multiply(Left, Right); // Stores Left * Right into Acc + Acc.MultiplyAccumulate(Left, Right); // Adds Left * Right to Acc +} +``` + +## Detailed design + +### Reading This Spec + +The next few sections include HLSL object definitions written in HLSL 2021 +syntax and using inheritance to represent interface composition. The objects in +the `detail` namespace are not exposed in the HLSL runtime. They are provided +here to make the specification more concise to consume. Objects in no namespace +are exposed as public interfaces. + +### WaveMatrix Fill + +All WaveMatrix objects have a `Fill` method of the form `void Fill(ElTy Value)` +where `ElTy` is the element type. + +The `Fill` method fills the matrix or matrix fragment with the provided value. +All wave threads must provide the same value or the result is undefined. All +WaveMatrix objects have the same `Fill` method with the same behavior. + +### WaveMatrix Matrix Objects + +The code below approximately declares the base interface that WaveMatrix matrix +objects implement. + +```c++ +namespace detail { +template +class WaveMatrixBase { + void Fill(ElTy Val); + void Load(ByteAddressBuffer Res, uint StartOffset, uint Stride, bool ColMajor, + uint Align = 0); + void Load(RWByteAddressBuffer Res, uint StartOffset, uint Stride, + bool ColMajor, uint Align = sizeof(ElTy)); + + void Load(groupshared ElTy Arr[], uint StartIdx, uint Stride, bool ColMajor); + + void Store(RWByteAddressBuffer Res, uint StartOffset, uint Stride, + bool ColMajor, uint Align = sizeof(ElTy)); + + void Store(groupshared ElTy Arr[], uint StartIdx, uint Stride, bool ColMajor); +}; +} // namespace detail +``` + +#### Loading and Storing WaveMatrix Matrix Objects + +Values for WaveMatrix matrix objects can be loaded from and stored to +`[RW]ByteAddressBuffer` or `groupshared` array of type `ElTy`. + +Loading from or storing to a `[RW]ByteAddressBuffer` takes a start offset in +bytes, row stride in bytes, and optional row alignment in bytes. Rows begin at +algined offsets from the `StartOffset` based on the `Stride` and `Alignment` +(`StartOffset + (RowIdx * [(Stride % Alignment) + Stride])`). + +The `Alignment` must be at least `sizeof(ElTy)` otherwise the load behavior is +undefined. + +Loading from or storing to a `groupshared` array takes the starting index, and +the row stride as a number of elements. + +##### Orientation + +When loading and storing WaveMatrix matrices a boolean parameter is provided to +indicate if the matrix being loaded or stored is column major. + +Matrices may be stored in row or column layout. Matrices are always loaded into +row-major orientation in memory for WaveMatrix objects. The `Load` and `Store` +methods perform matrix transposition when loading or storing column major +matrices. + +##### Stride + +When loading and storing from `groupshared` arrays, the stride is expressed in +number of elements. + +When loading and storing from `[RW]ByteAddressBuffer` types, the stride is +expressed in bytes. The row stride must be a multiple of the size of the +element, and greater than or equal to the size of the element multiplied by the +number of elements per row. Any value below the minimum legal values are +ignored. The behavior of row stride values that are not a multiple of element +stride is undefined. + +#### WaveMatrix Left & Right + +```c++ +template +class WaveMatrixLeft : detail::WaveMatrixBase { + uint MatrixDepth(); +}; + +template +class WaveMatrixRight : detail::WaveMatrixBase { + uint MatrixDepth(); +}; +``` + +`ElTy` must be either a 32 or 16 bit floating point type or an 8 bit packed +signed or unsigned integer type. `NRows` and `NCols` must be compile-time +constant expressions. + +##### WaveMatrix(Left|Right) MatrixDepth + +The `MatrixDepth` method returns the hardware-dependent depth for the matrix +multiplication unit. The resulting value must be an even multiple of 16. + +### WaveMatrix Matrix Fragment Objects + +The code below approximately declares the base interface that all WaveMatrix +fragment objects implement. A WaveMatrix fragment object stores a single row or +column of a WaveMatrix matrix. + +```c++ +namespace detail { +template +class WaveMatrixFragmentBase : WaveMatrixBase { + void Fill(ElTy Val); + + void Load(ByteAddressBuffer Res, uint StartOffset, uint Stride, + uint Align = 0); + void Load(RWByteAddressBuffer Res, uint StartOffset, uint Stride, + uint Align = 0); + + void Load(groupshared ElTy Arr[], uint StartIdx, uint Stride); + + void Store(RWByteAddressBuffer Res, uint StartOffset, uint Stride, + uint Align = 0); + + void Store(groupshared ElTy Arr[], uint StartIdx, uint Stride); +}; +} // namespace detail +``` + +#### Loading and Storing WaveMatrix Fragment Objects + +Values for WaveMatrix fragment objects can be loaded from and stored to +`[RW]ByteAddressBuffer` or `groupshared` array of type `ElTy`. + +Loading from or storing to a `[RW]ByteAddressBuffer` takes a start offset in +bytes, element stride in bytes, and optional alignment in bytes. The alignment +must be at least `sizeof(ElTy)` otherwise the load behavior is undefined. + +Loading from or storing to a `groupshared` array takes the starting index, and +the element stride as a number of elements. + +##### Stride + +When loading and storing from `groupshared` arrays, the stride is expressed in +number of elements. + +When loading and storing from `[RW]ByteAddressBuffer` types the stride must be +greater than or equal to the size of the element type. Any value below the +minimum legal values are ignored. The behavior of row stride values that are not +a multiple of element stride is undefined. + +### WaveMatrix Accumulator Objects + +The WaveMatrix Accumulator objects come in three forms which are represented by +two categories: matrix accumulators and fragment accumulators. All accumulators +implement the interface below: + +```c++ +namespace detail { +template +class WaveMatrixAccumulatorBase : BaseT { + + void ScalarMultiply(ElTy Value); + void ScalarDivide(ElTy Value); + void ScalarAdd(ElTy Value); + void ScalarSubtract(ElTy Value); +}; +} // namespace detail +``` + +Each of these operations performs the corresponding element-wise arithmetic +operation on the accumulator and the provided scalar value, storing the result +back into the accumulator. + +### WaveMatrixAccumulator + +```c++ +namespace detail { +template +struct is_8bit_packed_int_type = + std::enable_if_t<(std::is_same::value || + std::is_same::value), + std::true_type>; + +template struct is_8bit_packed_int_type = std::false_type; + +template +struct is_32bit_int_type = std::enable_if_t<(std::is_same::value || + std::is_same::value), + std::true_type>; + +template struct is_32bit_int_type = std::false_type; +} // namespace detail + +template +class WaveMatrixAccumulator + : WaveMatrixAccumulatorBase { + void Add(WaveMatrixLeftColAcc LeftMatrix); + void Add(WaveMatrixRightRowAcc RightMatrix); + void Add(WaveMatrixAccumulator Matrix); + + void Multiply(WaveMatrixLeft LHS, + WaveMatrixRight RHS); + void MultiplyAccumulate(WaveMatrixLeft LHS, + WaveMatrixRight RHS); + + template + std::enable_if_t && + detail::is_8bit_packed_int_type && + detail::is_8bit_packed_int_type, + void>::type + Multiply(WaveMatrixLeft LHS, + WaveMatrixRight RHS); + + template + std::enable_if_t && + detail::is_8bit_packed_int_type && + detail::is_8bit_packed_int_type, + void>::type + MultiplyAccumulate(WaveMatrixLeft LHS, + WaveMatrixRight RHS); +}; + +``` + +#### Broadcast Add + +The `WaveMatrixAccumulator::Add` methods perform an element-wise addition of an +accumulator matrix and the provided matrix or fragment accumulator. + +If the argument is a fragment object, the `Add` method broadcasts the argument +accumulator up from an Mx1 or 1xN matrix to an MxN matrix then performs +element-wise addition. + +#### WaveMatrixAccumulator::Multiply + +The `WaveMatrixAccumulator::Multiply` method performs multiplication of the left +and right arguments and stores the result back into the `WaveMatrixAccumulator`. +This is a wave-level operation and cannot be used inside divergent control flow. + +For `Multiply` operations the matrix element types must match the accumulator +type unless the accumulator is a signed or unsigned 32-bit integer type. For +signed or unsigned 32-bit integer accumulators, signed or unsigned 8-bit packed +integers are also supported and can be mixed interchangeably. + +#### WaveMatrixAccumulator::MultiplyAccumulate + +The `WaveMatrixAccumulator::MultiplyAccumulate` method performs multiplication +of the left and right arguments and adds the result back into the +`WaveMatrixAccumulator`. This is a wave-level operation and cannot be used +inside divergent control flow. + +For `MultiplyAccumulate` operations the matrix element types must match the +accumulator type unless the accumulator is a signed or unsigned 32-bit integer +type. For signed or unsigned 32-bit integer accumulators, signed or unsigned +8-bit packed integers are also supported and can be mixed interchangeably. + +### WaveMatrix Fragment Accumulators + +WaveMatrix intrinsics are defined to support quantization calculations. +Including calculating a sum for the rows of the left matrix and a sum of the +columns of the right matrix. The `WaveMatrixRightRowAcc` and +`WaveMatrixLeftColAcc` fragment accumulators perform this operation. + +```c++ +template +class WaveMatrixLeftColAcc + : WaveMatrixAccumulatorBase { + template + void SumAccumulate(WaveMatrixLeft LeftMatrix); +}; + +template +class WaveMatrixRightRowAcc + : WaveMatrixAccumulatorBase { + void SumAccumulate(WaveMatrixRight RightMatrix); +}; +``` + +#### Zero Point + +The following is the equation for matrix multiplication with zero point +adjustment included: + +$C_{[x,y]} = (\sum_{i=0}^{K} A_{[x,i]} * B_{[i,y]}) - Z_a * (\sum_{i=0}^{K} B_{[i,y]}) - Z_b * (\sum_{i=0}^{K} A_{[x,i]}) + Z_a * Z_b * K$ + +$(\sum_{i=0}^{K} A_{[x,i]} * B_{[i,y]})$ is basic matrix multiplication. + +$- Z_a * (\sum_{i=0}^{K} B_{[i,y]})$ is the zero point adjustment for matrix $A$ + +$- Z_b * (\sum_{i=0}^{K} A_{[x,i]})$ is the zero point adjustment for matrix $B$ + +$+ Z_a * Z_b * K$ is the static zero point adjustment for both matrix $A$ and $B$ + +$Z_*$ are constant zero points values + +#### Wave Matrix SumAccumulate + +The `SumAccumulate` methods accumulate the values of the argument matrix into +the WaveMatrix fragment accumulator. The fragment WaveMatrix must have the same +data type as the fragment accumulator. + +This intrinsics is used to calculated +$(\sum_{i=0}^{K} A_{[x,i]})$ and $(\sum_{i=0}^{K} B_{[i,y]})$ from our above +equation. + + +## Acknowledgments + +This spec was developed as an extensive collaboration between the Microsoft HLSL +and Direct3D teams and IHV partners. + +Special thanks to: + +Claire Andrews +Nick Feeney +Amar Patel +Tex Riddell +Greg Roth + + From 3cea3831025bd0efd39b93aef209745f8442f902 Mon Sep 17 00:00:00 2001 From: Chris Bieneman Date: Tue, 22 Aug 2023 12:10:30 -0500 Subject: [PATCH 2/5] Renormalize line endings. --- proposals/xxxx-wave-matrix.md | 792 +++++++++++++++++----------------- 1 file changed, 396 insertions(+), 396 deletions(-) diff --git a/proposals/xxxx-wave-matrix.md b/proposals/xxxx-wave-matrix.md index 4c7d7bbb..b57e01ea 100644 --- a/proposals/xxxx-wave-matrix.md +++ b/proposals/xxxx-wave-matrix.md @@ -1,396 +1,396 @@ - - -# Wave Matrix - -* Proposal: [NNNN](NNNN-wave-matrix.md) -* Author(s): [Chris Bieneman](https://github.com/llvm-beanz) -* Sponsor: [Greg Roth](https://github.com/pow2clk), [Tex Riddell](https://github.com/tex3d) -* Status: **Under Consideration** -* Planned Version: Shader Model 6.8 - - -## Introduction - -This proposal adds HLSL support for a new DXIL feature WaveMatrix. The new data -types and operations add support for wave cooperative matrix multiplication and -accumulation. The underlying hardware and driver interfaces for this feature are -introduced in Shader Model 6.8. - -## Motivation - -GPUs have added dedicated hardware to support cooperative matrix multiplication -across SIMD lanes. This feature proposes adding new data types and built-in -functions to enable higher throughput matrix operations that fully utilize GPU -SIMD hardware. - -These higher throughput matrix operations are required for optimal performance -of many machine learning and image processing workloads. Adding native support -to HLSL will enable high-performance matrix operations across all supported -hardware with Shader Model 6.8 drivers. - -## Proposed solution - -WaveMatrix introduces new matrix templates to facilitate wave cooperative -operations: - -```c++ -// Matrix Depth (K) dimension is hardware dependent -// With TYPE_IN one of {float32_t, float16_t, uint8_t4_packed, int8_t4_packed} -WaveMatrixLeft ; // M x K -WaveMatrixRight ; // K x N - -// With TYPE_ACC one of {float32_t, float16_t, int32_t} -WaveMatrixAccumulator ; // M x N -// WaveMatrixLeftColAcc and WaveMatrixRightRowAcc are provided support for -// quantization algorithms. See Zero Point section - -// For accumulating columns from WaveMatrixLeft into a single column of sums -WaveMatrixLeftColAcc ; // M x 1 - -// For accumulating rows from WaveMatrixRight into a single row of sums -WaveMatrixRightRowAcc ; // 1 x N -``` - -WaveMatrix accumulator object methods support operating on corresponding Left -and Right operands. Results of operations can be stored or accumulated back into -the accumulator. A simple example of multiplication is: - -```c++ -[numthreads(64,1,1)] -void main(uint3 GTID : SV_GroupThreadID, uint GIDX : SV_GroupIndex) -{ - WaveMatrixLeft Left; - WaveMatrixRight Right; - WaveMatrixAccumulator Acc; - - Acc.Multiply(Left, Right); // Stores Left * Right into Acc - Acc.MultiplyAccumulate(Left, Right); // Adds Left * Right to Acc -} -``` - -## Detailed design - -### Reading This Spec - -The next few sections include HLSL object definitions written in HLSL 2021 -syntax and using inheritance to represent interface composition. The objects in -the `detail` namespace are not exposed in the HLSL runtime. They are provided -here to make the specification more concise to consume. Objects in no namespace -are exposed as public interfaces. - -### WaveMatrix Fill - -All WaveMatrix objects have a `Fill` method of the form `void Fill(ElTy Value)` -where `ElTy` is the element type. - -The `Fill` method fills the matrix or matrix fragment with the provided value. -All wave threads must provide the same value or the result is undefined. All -WaveMatrix objects have the same `Fill` method with the same behavior. - -### WaveMatrix Matrix Objects - -The code below approximately declares the base interface that WaveMatrix matrix -objects implement. - -```c++ -namespace detail { -template -class WaveMatrixBase { - void Fill(ElTy Val); - void Load(ByteAddressBuffer Res, uint StartOffset, uint Stride, bool ColMajor, - uint Align = 0); - void Load(RWByteAddressBuffer Res, uint StartOffset, uint Stride, - bool ColMajor, uint Align = sizeof(ElTy)); - - void Load(groupshared ElTy Arr[], uint StartIdx, uint Stride, bool ColMajor); - - void Store(RWByteAddressBuffer Res, uint StartOffset, uint Stride, - bool ColMajor, uint Align = sizeof(ElTy)); - - void Store(groupshared ElTy Arr[], uint StartIdx, uint Stride, bool ColMajor); -}; -} // namespace detail -``` - -#### Loading and Storing WaveMatrix Matrix Objects - -Values for WaveMatrix matrix objects can be loaded from and stored to -`[RW]ByteAddressBuffer` or `groupshared` array of type `ElTy`. - -Loading from or storing to a `[RW]ByteAddressBuffer` takes a start offset in -bytes, row stride in bytes, and optional row alignment in bytes. Rows begin at -algined offsets from the `StartOffset` based on the `Stride` and `Alignment` -(`StartOffset + (RowIdx * [(Stride % Alignment) + Stride])`). - -The `Alignment` must be at least `sizeof(ElTy)` otherwise the load behavior is -undefined. - -Loading from or storing to a `groupshared` array takes the starting index, and -the row stride as a number of elements. - -##### Orientation - -When loading and storing WaveMatrix matrices a boolean parameter is provided to -indicate if the matrix being loaded or stored is column major. - -Matrices may be stored in row or column layout. Matrices are always loaded into -row-major orientation in memory for WaveMatrix objects. The `Load` and `Store` -methods perform matrix transposition when loading or storing column major -matrices. - -##### Stride - -When loading and storing from `groupshared` arrays, the stride is expressed in -number of elements. - -When loading and storing from `[RW]ByteAddressBuffer` types, the stride is -expressed in bytes. The row stride must be a multiple of the size of the -element, and greater than or equal to the size of the element multiplied by the -number of elements per row. Any value below the minimum legal values are -ignored. The behavior of row stride values that are not a multiple of element -stride is undefined. - -#### WaveMatrix Left & Right - -```c++ -template -class WaveMatrixLeft : detail::WaveMatrixBase { - uint MatrixDepth(); -}; - -template -class WaveMatrixRight : detail::WaveMatrixBase { - uint MatrixDepth(); -}; -``` - -`ElTy` must be either a 32 or 16 bit floating point type or an 8 bit packed -signed or unsigned integer type. `NRows` and `NCols` must be compile-time -constant expressions. - -##### WaveMatrix(Left|Right) MatrixDepth - -The `MatrixDepth` method returns the hardware-dependent depth for the matrix -multiplication unit. The resulting value must be an even multiple of 16. - -### WaveMatrix Matrix Fragment Objects - -The code below approximately declares the base interface that all WaveMatrix -fragment objects implement. A WaveMatrix fragment object stores a single row or -column of a WaveMatrix matrix. - -```c++ -namespace detail { -template -class WaveMatrixFragmentBase : WaveMatrixBase { - void Fill(ElTy Val); - - void Load(ByteAddressBuffer Res, uint StartOffset, uint Stride, - uint Align = 0); - void Load(RWByteAddressBuffer Res, uint StartOffset, uint Stride, - uint Align = 0); - - void Load(groupshared ElTy Arr[], uint StartIdx, uint Stride); - - void Store(RWByteAddressBuffer Res, uint StartOffset, uint Stride, - uint Align = 0); - - void Store(groupshared ElTy Arr[], uint StartIdx, uint Stride); -}; -} // namespace detail -``` - -#### Loading and Storing WaveMatrix Fragment Objects - -Values for WaveMatrix fragment objects can be loaded from and stored to -`[RW]ByteAddressBuffer` or `groupshared` array of type `ElTy`. - -Loading from or storing to a `[RW]ByteAddressBuffer` takes a start offset in -bytes, element stride in bytes, and optional alignment in bytes. The alignment -must be at least `sizeof(ElTy)` otherwise the load behavior is undefined. - -Loading from or storing to a `groupshared` array takes the starting index, and -the element stride as a number of elements. - -##### Stride - -When loading and storing from `groupshared` arrays, the stride is expressed in -number of elements. - -When loading and storing from `[RW]ByteAddressBuffer` types the stride must be -greater than or equal to the size of the element type. Any value below the -minimum legal values are ignored. The behavior of row stride values that are not -a multiple of element stride is undefined. - -### WaveMatrix Accumulator Objects - -The WaveMatrix Accumulator objects come in three forms which are represented by -two categories: matrix accumulators and fragment accumulators. All accumulators -implement the interface below: - -```c++ -namespace detail { -template -class WaveMatrixAccumulatorBase : BaseT { - - void ScalarMultiply(ElTy Value); - void ScalarDivide(ElTy Value); - void ScalarAdd(ElTy Value); - void ScalarSubtract(ElTy Value); -}; -} // namespace detail -``` - -Each of these operations performs the corresponding element-wise arithmetic -operation on the accumulator and the provided scalar value, storing the result -back into the accumulator. - -### WaveMatrixAccumulator - -```c++ -namespace detail { -template -struct is_8bit_packed_int_type = - std::enable_if_t<(std::is_same::value || - std::is_same::value), - std::true_type>; - -template struct is_8bit_packed_int_type = std::false_type; - -template -struct is_32bit_int_type = std::enable_if_t<(std::is_same::value || - std::is_same::value), - std::true_type>; - -template struct is_32bit_int_type = std::false_type; -} // namespace detail - -template -class WaveMatrixAccumulator - : WaveMatrixAccumulatorBase { - void Add(WaveMatrixLeftColAcc LeftMatrix); - void Add(WaveMatrixRightRowAcc RightMatrix); - void Add(WaveMatrixAccumulator Matrix); - - void Multiply(WaveMatrixLeft LHS, - WaveMatrixRight RHS); - void MultiplyAccumulate(WaveMatrixLeft LHS, - WaveMatrixRight RHS); - - template - std::enable_if_t && - detail::is_8bit_packed_int_type && - detail::is_8bit_packed_int_type, - void>::type - Multiply(WaveMatrixLeft LHS, - WaveMatrixRight RHS); - - template - std::enable_if_t && - detail::is_8bit_packed_int_type && - detail::is_8bit_packed_int_type, - void>::type - MultiplyAccumulate(WaveMatrixLeft LHS, - WaveMatrixRight RHS); -}; - -``` - -#### Broadcast Add - -The `WaveMatrixAccumulator::Add` methods perform an element-wise addition of an -accumulator matrix and the provided matrix or fragment accumulator. - -If the argument is a fragment object, the `Add` method broadcasts the argument -accumulator up from an Mx1 or 1xN matrix to an MxN matrix then performs -element-wise addition. - -#### WaveMatrixAccumulator::Multiply - -The `WaveMatrixAccumulator::Multiply` method performs multiplication of the left -and right arguments and stores the result back into the `WaveMatrixAccumulator`. -This is a wave-level operation and cannot be used inside divergent control flow. - -For `Multiply` operations the matrix element types must match the accumulator -type unless the accumulator is a signed or unsigned 32-bit integer type. For -signed or unsigned 32-bit integer accumulators, signed or unsigned 8-bit packed -integers are also supported and can be mixed interchangeably. - -#### WaveMatrixAccumulator::MultiplyAccumulate - -The `WaveMatrixAccumulator::MultiplyAccumulate` method performs multiplication -of the left and right arguments and adds the result back into the -`WaveMatrixAccumulator`. This is a wave-level operation and cannot be used -inside divergent control flow. - -For `MultiplyAccumulate` operations the matrix element types must match the -accumulator type unless the accumulator is a signed or unsigned 32-bit integer -type. For signed or unsigned 32-bit integer accumulators, signed or unsigned -8-bit packed integers are also supported and can be mixed interchangeably. - -### WaveMatrix Fragment Accumulators - -WaveMatrix intrinsics are defined to support quantization calculations. -Including calculating a sum for the rows of the left matrix and a sum of the -columns of the right matrix. The `WaveMatrixRightRowAcc` and -`WaveMatrixLeftColAcc` fragment accumulators perform this operation. - -```c++ -template -class WaveMatrixLeftColAcc - : WaveMatrixAccumulatorBase { - template - void SumAccumulate(WaveMatrixLeft LeftMatrix); -}; - -template -class WaveMatrixRightRowAcc - : WaveMatrixAccumulatorBase { - void SumAccumulate(WaveMatrixRight RightMatrix); -}; -``` - -#### Zero Point - -The following is the equation for matrix multiplication with zero point -adjustment included: - -$C_{[x,y]} = (\sum_{i=0}^{K} A_{[x,i]} * B_{[i,y]}) - Z_a * (\sum_{i=0}^{K} B_{[i,y]}) - Z_b * (\sum_{i=0}^{K} A_{[x,i]}) + Z_a * Z_b * K$ - -$(\sum_{i=0}^{K} A_{[x,i]} * B_{[i,y]})$ is basic matrix multiplication. - -$- Z_a * (\sum_{i=0}^{K} B_{[i,y]})$ is the zero point adjustment for matrix $A$ - -$- Z_b * (\sum_{i=0}^{K} A_{[x,i]})$ is the zero point adjustment for matrix $B$ - -$+ Z_a * Z_b * K$ is the static zero point adjustment for both matrix $A$ and $B$ - -$Z_*$ are constant zero points values - -#### Wave Matrix SumAccumulate - -The `SumAccumulate` methods accumulate the values of the argument matrix into -the WaveMatrix fragment accumulator. The fragment WaveMatrix must have the same -data type as the fragment accumulator. - -This intrinsics is used to calculated -$(\sum_{i=0}^{K} A_{[x,i]})$ and $(\sum_{i=0}^{K} B_{[i,y]})$ from our above -equation. - - -## Acknowledgments - -This spec was developed as an extensive collaboration between the Microsoft HLSL -and Direct3D teams and IHV partners. - -Special thanks to: - -Claire Andrews -Nick Feeney -Amar Patel -Tex Riddell -Greg Roth - - + + +# Wave Matrix + +* Proposal: [NNNN](NNNN-wave-matrix.md) +* Author(s): [Chris Bieneman](https://github.com/llvm-beanz) +* Sponsor: [Greg Roth](https://github.com/pow2clk), [Tex Riddell](https://github.com/tex3d) +* Status: **Under Consideration** +* Planned Version: Shader Model 6.8 + + +## Introduction + +This proposal adds HLSL support for a new DXIL feature WaveMatrix. The new data +types and operations add support for wave cooperative matrix multiplication and +accumulation. The underlying hardware and driver interfaces for this feature are +introduced in Shader Model 6.8. + +## Motivation + +GPUs have added dedicated hardware to support cooperative matrix multiplication +across SIMD lanes. This feature proposes adding new data types and built-in +functions to enable higher throughput matrix operations that fully utilize GPU +SIMD hardware. + +These higher throughput matrix operations are required for optimal performance +of many machine learning and image processing workloads. Adding native support +to HLSL will enable high-performance matrix operations across all supported +hardware with Shader Model 6.8 drivers. + +## Proposed solution + +WaveMatrix introduces new matrix templates to facilitate wave cooperative +operations: + +```c++ +// Matrix Depth (K) dimension is hardware dependent +// With TYPE_IN one of {float32_t, float16_t, uint8_t4_packed, int8_t4_packed} +WaveMatrixLeft ; // M x K +WaveMatrixRight ; // K x N + +// With TYPE_ACC one of {float32_t, float16_t, int32_t} +WaveMatrixAccumulator ; // M x N +// WaveMatrixLeftColAcc and WaveMatrixRightRowAcc are provided support for +// quantization algorithms. See Zero Point section + +// For accumulating columns from WaveMatrixLeft into a single column of sums +WaveMatrixLeftColAcc ; // M x 1 + +// For accumulating rows from WaveMatrixRight into a single row of sums +WaveMatrixRightRowAcc ; // 1 x N +``` + +WaveMatrix accumulator object methods support operating on corresponding Left +and Right operands. Results of operations can be stored or accumulated back into +the accumulator. A simple example of multiplication is: + +```c++ +[numthreads(64,1,1)] +void main(uint3 GTID : SV_GroupThreadID, uint GIDX : SV_GroupIndex) +{ + WaveMatrixLeft Left; + WaveMatrixRight Right; + WaveMatrixAccumulator Acc; + + Acc.Multiply(Left, Right); // Stores Left * Right into Acc + Acc.MultiplyAccumulate(Left, Right); // Adds Left * Right to Acc +} +``` + +## Detailed design + +### Reading This Spec + +The next few sections include HLSL object definitions written in HLSL 2021 +syntax and using inheritance to represent interface composition. The objects in +the `detail` namespace are not exposed in the HLSL runtime. They are provided +here to make the specification more concise to consume. Objects in no namespace +are exposed as public interfaces. + +### WaveMatrix Fill + +All WaveMatrix objects have a `Fill` method of the form `void Fill(ElTy Value)` +where `ElTy` is the element type. + +The `Fill` method fills the matrix or matrix fragment with the provided value. +All wave threads must provide the same value or the result is undefined. All +WaveMatrix objects have the same `Fill` method with the same behavior. + +### WaveMatrix Matrix Objects + +The code below approximately declares the base interface that WaveMatrix matrix +objects implement. + +```c++ +namespace detail { +template +class WaveMatrixBase { + void Fill(ElTy Val); + void Load(ByteAddressBuffer Res, uint StartOffset, uint Stride, bool ColMajor, + uint Align = 0); + void Load(RWByteAddressBuffer Res, uint StartOffset, uint Stride, + bool ColMajor, uint Align = sizeof(ElTy)); + + void Load(groupshared ElTy Arr[], uint StartIdx, uint Stride, bool ColMajor); + + void Store(RWByteAddressBuffer Res, uint StartOffset, uint Stride, + bool ColMajor, uint Align = sizeof(ElTy)); + + void Store(groupshared ElTy Arr[], uint StartIdx, uint Stride, bool ColMajor); +}; +} // namespace detail +``` + +#### Loading and Storing WaveMatrix Matrix Objects + +Values for WaveMatrix matrix objects can be loaded from and stored to +`[RW]ByteAddressBuffer` or `groupshared` array of type `ElTy`. + +Loading from or storing to a `[RW]ByteAddressBuffer` takes a start offset in +bytes, row stride in bytes, and optional row alignment in bytes. Rows begin at +algined offsets from the `StartOffset` based on the `Stride` and `Alignment` +(`StartOffset + (RowIdx * [(Stride % Alignment) + Stride])`). + +The `Alignment` must be at least `sizeof(ElTy)` otherwise the load behavior is +undefined. + +Loading from or storing to a `groupshared` array takes the starting index, and +the row stride as a number of elements. + +##### Orientation + +When loading and storing WaveMatrix matrices a boolean parameter is provided to +indicate if the matrix being loaded or stored is column major. + +Matrices may be stored in row or column layout. Matrices are always loaded into +row-major orientation in memory for WaveMatrix objects. The `Load` and `Store` +methods perform matrix transposition when loading or storing column major +matrices. + +##### Stride + +When loading and storing from `groupshared` arrays, the stride is expressed in +number of elements. + +When loading and storing from `[RW]ByteAddressBuffer` types, the stride is +expressed in bytes. The row stride must be a multiple of the size of the +element, and greater than or equal to the size of the element multiplied by the +number of elements per row. Any value below the minimum legal values are +ignored. The behavior of row stride values that are not a multiple of element +stride is undefined. + +#### WaveMatrix Left & Right + +```c++ +template +class WaveMatrixLeft : detail::WaveMatrixBase { + uint MatrixDepth(); +}; + +template +class WaveMatrixRight : detail::WaveMatrixBase { + uint MatrixDepth(); +}; +``` + +`ElTy` must be either a 32 or 16 bit floating point type or an 8 bit packed +signed or unsigned integer type. `NRows` and `NCols` must be compile-time +constant expressions. + +##### WaveMatrix(Left|Right) MatrixDepth + +The `MatrixDepth` method returns the hardware-dependent depth for the matrix +multiplication unit. The resulting value must be an even multiple of 16. + +### WaveMatrix Matrix Fragment Objects + +The code below approximately declares the base interface that all WaveMatrix +fragment objects implement. A WaveMatrix fragment object stores a single row or +column of a WaveMatrix matrix. + +```c++ +namespace detail { +template +class WaveMatrixFragmentBase : WaveMatrixBase { + void Fill(ElTy Val); + + void Load(ByteAddressBuffer Res, uint StartOffset, uint Stride, + uint Align = 0); + void Load(RWByteAddressBuffer Res, uint StartOffset, uint Stride, + uint Align = 0); + + void Load(groupshared ElTy Arr[], uint StartIdx, uint Stride); + + void Store(RWByteAddressBuffer Res, uint StartOffset, uint Stride, + uint Align = 0); + + void Store(groupshared ElTy Arr[], uint StartIdx, uint Stride); +}; +} // namespace detail +``` + +#### Loading and Storing WaveMatrix Fragment Objects + +Values for WaveMatrix fragment objects can be loaded from and stored to +`[RW]ByteAddressBuffer` or `groupshared` array of type `ElTy`. + +Loading from or storing to a `[RW]ByteAddressBuffer` takes a start offset in +bytes, element stride in bytes, and optional alignment in bytes. The alignment +must be at least `sizeof(ElTy)` otherwise the load behavior is undefined. + +Loading from or storing to a `groupshared` array takes the starting index, and +the element stride as a number of elements. + +##### Stride + +When loading and storing from `groupshared` arrays, the stride is expressed in +number of elements. + +When loading and storing from `[RW]ByteAddressBuffer` types the stride must be +greater than or equal to the size of the element type. Any value below the +minimum legal values are ignored. The behavior of row stride values that are not +a multiple of element stride is undefined. + +### WaveMatrix Accumulator Objects + +The WaveMatrix Accumulator objects come in three forms which are represented by +two categories: matrix accumulators and fragment accumulators. All accumulators +implement the interface below: + +```c++ +namespace detail { +template +class WaveMatrixAccumulatorBase : BaseT { + + void ScalarMultiply(ElTy Value); + void ScalarDivide(ElTy Value); + void ScalarAdd(ElTy Value); + void ScalarSubtract(ElTy Value); +}; +} // namespace detail +``` + +Each of these operations performs the corresponding element-wise arithmetic +operation on the accumulator and the provided scalar value, storing the result +back into the accumulator. + +### WaveMatrixAccumulator + +```c++ +namespace detail { +template +struct is_8bit_packed_int_type = + std::enable_if_t<(std::is_same::value || + std::is_same::value), + std::true_type>; + +template struct is_8bit_packed_int_type = std::false_type; + +template +struct is_32bit_int_type = std::enable_if_t<(std::is_same::value || + std::is_same::value), + std::true_type>; + +template struct is_32bit_int_type = std::false_type; +} // namespace detail + +template +class WaveMatrixAccumulator + : WaveMatrixAccumulatorBase { + void Add(WaveMatrixLeftColAcc LeftMatrix); + void Add(WaveMatrixRightRowAcc RightMatrix); + void Add(WaveMatrixAccumulator Matrix); + + void Multiply(WaveMatrixLeft LHS, + WaveMatrixRight RHS); + void MultiplyAccumulate(WaveMatrixLeft LHS, + WaveMatrixRight RHS); + + template + std::enable_if_t && + detail::is_8bit_packed_int_type && + detail::is_8bit_packed_int_type, + void>::type + Multiply(WaveMatrixLeft LHS, + WaveMatrixRight RHS); + + template + std::enable_if_t && + detail::is_8bit_packed_int_type && + detail::is_8bit_packed_int_type, + void>::type + MultiplyAccumulate(WaveMatrixLeft LHS, + WaveMatrixRight RHS); +}; + +``` + +#### Broadcast Add + +The `WaveMatrixAccumulator::Add` methods perform an element-wise addition of an +accumulator matrix and the provided matrix or fragment accumulator. + +If the argument is a fragment object, the `Add` method broadcasts the argument +accumulator up from an Mx1 or 1xN matrix to an MxN matrix then performs +element-wise addition. + +#### WaveMatrixAccumulator::Multiply + +The `WaveMatrixAccumulator::Multiply` method performs multiplication of the left +and right arguments and stores the result back into the `WaveMatrixAccumulator`. +This is a wave-level operation and cannot be used inside divergent control flow. + +For `Multiply` operations the matrix element types must match the accumulator +type unless the accumulator is a signed or unsigned 32-bit integer type. For +signed or unsigned 32-bit integer accumulators, signed or unsigned 8-bit packed +integers are also supported and can be mixed interchangeably. + +#### WaveMatrixAccumulator::MultiplyAccumulate + +The `WaveMatrixAccumulator::MultiplyAccumulate` method performs multiplication +of the left and right arguments and adds the result back into the +`WaveMatrixAccumulator`. This is a wave-level operation and cannot be used +inside divergent control flow. + +For `MultiplyAccumulate` operations the matrix element types must match the +accumulator type unless the accumulator is a signed or unsigned 32-bit integer +type. For signed or unsigned 32-bit integer accumulators, signed or unsigned +8-bit packed integers are also supported and can be mixed interchangeably. + +### WaveMatrix Fragment Accumulators + +WaveMatrix intrinsics are defined to support quantization calculations. +Including calculating a sum for the rows of the left matrix and a sum of the +columns of the right matrix. The `WaveMatrixRightRowAcc` and +`WaveMatrixLeftColAcc` fragment accumulators perform this operation. + +```c++ +template +class WaveMatrixLeftColAcc + : WaveMatrixAccumulatorBase { + template + void SumAccumulate(WaveMatrixLeft LeftMatrix); +}; + +template +class WaveMatrixRightRowAcc + : WaveMatrixAccumulatorBase { + void SumAccumulate(WaveMatrixRight RightMatrix); +}; +``` + +#### Zero Point + +The following is the equation for matrix multiplication with zero point +adjustment included: + +$C_{[x,y]} = (\sum_{i=0}^{K} A_{[x,i]} * B_{[i,y]}) - Z_a * (\sum_{i=0}^{K} B_{[i,y]}) - Z_b * (\sum_{i=0}^{K} A_{[x,i]}) + Z_a * Z_b * K$ + +$(\sum_{i=0}^{K} A_{[x,i]} * B_{[i,y]})$ is basic matrix multiplication. + +$- Z_a * (\sum_{i=0}^{K} B_{[i,y]})$ is the zero point adjustment for matrix $A$ + +$- Z_b * (\sum_{i=0}^{K} A_{[x,i]})$ is the zero point adjustment for matrix $B$ + +$+ Z_a * Z_b * K$ is the static zero point adjustment for both matrix $A$ and $B$ + +$Z_*$ are constant zero points values + +#### Wave Matrix SumAccumulate + +The `SumAccumulate` methods accumulate the values of the argument matrix into +the WaveMatrix fragment accumulator. The fragment WaveMatrix must have the same +data type as the fragment accumulator. + +This intrinsics is used to calculated +$(\sum_{i=0}^{K} A_{[x,i]})$ and $(\sum_{i=0}^{K} B_{[i,y]})$ from our above +equation. + + +## Acknowledgments + +This spec was developed as an extensive collaboration between the Microsoft HLSL +and Direct3D teams and IHV partners. + +Special thanks to: + +Claire Andrews +Nick Feeney +Amar Patel +Tex Riddell +Greg Roth + + From eb479ced08375285cd6132574b43724b461527ae Mon Sep 17 00:00:00 2001 From: Chris Bieneman Date: Tue, 22 Aug 2023 15:40:45 -0500 Subject: [PATCH 3/5] Update based on review feedback. --- proposals/xxxx-wave-matrix.md | 99 ++++++++++++++++++++--------------- 1 file changed, 58 insertions(+), 41 deletions(-) diff --git a/proposals/xxxx-wave-matrix.md b/proposals/xxxx-wave-matrix.md index b57e01ea..3541705f 100644 --- a/proposals/xxxx-wave-matrix.md +++ b/proposals/xxxx-wave-matrix.md @@ -24,16 +24,16 @@ functions to enable higher throughput matrix operations that fully utilize GPU SIMD hardware. These higher throughput matrix operations are required for optimal performance -of many machine learning and image processing workloads. Adding native support -to HLSL will enable high-performance matrix operations across all supported -hardware with Shader Model 6.8 drivers. +of many machine learning and image processing workloads. Adding support to HLSL +will enable high-performance matrix operations across Shader Model 6.8 drivers +when hardware support is available. ## Proposed solution WaveMatrix introduces new matrix templates to facilitate wave cooperative operations: -```c++ +```hlsl // Matrix Depth (K) dimension is hardware dependent // With TYPE_IN one of {float32_t, float16_t, uint8_t4_packed, int8_t4_packed} WaveMatrixLeft ; // M x K @@ -55,7 +55,7 @@ WaveMatrix accumulator object methods support operating on corresponding Left and Right operands. Results of operations can be stored or accumulated back into the accumulator. A simple example of multiplication is: -```c++ +```hlsl [numthreads(64,1,1)] void main(uint3 GTID : SV_GroupThreadID, uint GIDX : SV_GroupIndex) { @@ -78,15 +78,6 @@ the `detail` namespace are not exposed in the HLSL runtime. They are provided here to make the specification more concise to consume. Objects in no namespace are exposed as public interfaces. -### WaveMatrix Fill - -All WaveMatrix objects have a `Fill` method of the form `void Fill(ElTy Value)` -where `ElTy` is the element type. - -The `Fill` method fills the matrix or matrix fragment with the provided value. -All wave threads must provide the same value or the result is undefined. All -WaveMatrix objects have the same `Fill` method with the same behavior. - ### WaveMatrix Matrix Objects The code below approximately declares the base interface that WaveMatrix matrix @@ -112,6 +103,20 @@ class WaveMatrixBase { } // namespace detail ``` +The following sections will explain in detail what these methods do and what the +parameters represent. + +#### WaveMatrix Fill + +All WaveMatrix objects have a `Fill` method of the form `void Fill(ElTy Value)` +where `ElTy` is the element type. + +The `Fill` method assigns the given Value to every element in the matrix or +matrix fragment. All wave threads must provide the same value or the result is +undefined. All WaveMatrix objects have the same `Fill` method with the same +behavior. + + #### Loading and Storing WaveMatrix Matrix Objects Values for WaveMatrix matrix objects can be loaded from and stored to @@ -131,7 +136,8 @@ the row stride as a number of elements. ##### Orientation When loading and storing WaveMatrix matrices a boolean parameter is provided to -indicate if the matrix being loaded or stored is column major. +indicate if the matrix being loaded or stored is column major. When false row +major orientation is assumed. Matrices may be stored in row or column layout. Matrices are always loaded into row-major orientation in memory for WaveMatrix objects. The `Load` and `Store` @@ -141,17 +147,21 @@ matrices. ##### Stride When loading and storing from `groupshared` arrays, the stride is expressed in -number of elements. +number of elements, and signifies the number of elements between the first +element of each row for row major matrices or column for column major matrices. When loading and storing from `[RW]ByteAddressBuffer` types, the stride is expressed in bytes. The row stride must be a multiple of the size of the -element, and greater than or equal to the size of the element multiplied by the +element and greater than or equal to the size of the element multiplied by the number of elements per row. Any value below the minimum legal values are ignored. The behavior of row stride values that are not a multiple of element stride is undefined. #### WaveMatrix Left & Right +The code below approximately declares the base interface that WaveMatrixLeft and +WaveMatrixRight objects implement. + ```c++ template class WaveMatrixLeft : detail::WaveMatrixBase { @@ -166,12 +176,13 @@ class WaveMatrixRight : detail::WaveMatrixBase { `ElTy` must be either a 32 or 16 bit floating point type or an 8 bit packed signed or unsigned integer type. `NRows` and `NCols` must be compile-time -constant expressions. +constant expressions, and represent the number of rows and columns in the matrix +respectively. ##### WaveMatrix(Left|Right) MatrixDepth The `MatrixDepth` method returns the hardware-dependent depth for the matrix -multiplication unit. The resulting value must be an even multiple of 16. +multiplication unit. The resulting value must be an multiple of 16. ### WaveMatrix Matrix Fragment Objects @@ -182,7 +193,7 @@ column of a WaveMatrix matrix. ```c++ namespace detail { template -class WaveMatrixFragmentBase : WaveMatrixBase { +class WaveMatrixFragmentBase { void Fill(ElTy Val); void Load(ByteAddressBuffer Res, uint StartOffset, uint Stride, @@ -217,15 +228,15 @@ the element stride as a number of elements. When loading and storing from `groupshared` arrays, the stride is expressed in number of elements. -When loading and storing from `[RW]ByteAddressBuffer` types the stride must be -greater than or equal to the size of the element type. Any value below the -minimum legal values are ignored. The behavior of row stride values that are not -a multiple of element stride is undefined. +When loading and storing from `[RW]ByteAddressBuffer` types the stride is +expressed in bytes and must be greater than or equal to the size of the element +type. Any value below the minimum legal values are ignored. The behavior of +stride values that are not a multiple of element size is undefined. ### WaveMatrix Accumulator Objects -The WaveMatrix Accumulator objects come in three forms which are represented by -two categories: matrix accumulators and fragment accumulators. All accumulators +The WaveMatrix Accumulator objects come in three forms which fall into two +categories: matrix accumulators and fragment accumulators. All accumulators implement the interface below: ```c++ @@ -255,14 +266,14 @@ struct is_8bit_packed_int_type = std::is_same::value), std::true_type>; -template struct is_8bit_packed_int_type = std::false_type; +template using is_8bit_packed_int_type = std::false_type; template -struct is_32bit_int_type = std::enable_if_t<(std::is_same::value || +using is_32bit_int_type = std::enable_if_t<(std::is_same::value || std::is_same::value), std::true_type>; -template struct is_32bit_int_type = std::false_type; +template using is_32bit_int_type = std::false_type; } // namespace detail template @@ -309,7 +320,9 @@ element-wise addition. The `WaveMatrixAccumulator::Multiply` method performs multiplication of the left and right arguments and stores the result back into the `WaveMatrixAccumulator`. -This is a wave-level operation and cannot be used inside divergent control flow. +This operation does not accumulate into the result, it overwrites the result. +This is a wave-level operation and cannot be used inside divergent control flow: +doing so results in undefined behavior. For `Multiply` operations the matrix element types must match the accumulator type unless the accumulator is a signed or unsigned 32-bit integer type. For @@ -321,7 +334,7 @@ integers are also supported and can be mixed interchangeably. The `WaveMatrixAccumulator::MultiplyAccumulate` method performs multiplication of the left and right arguments and adds the result back into the `WaveMatrixAccumulator`. This is a wave-level operation and cannot be used -inside divergent control flow. +inside divergent control flow: doing so results in undefined behavior. For `MultiplyAccumulate` operations the matrix element types must match the accumulator type unless the accumulator is a signed or unsigned 32-bit integer @@ -352,6 +365,20 @@ class WaveMatrixRightRowAcc }; ``` +#### Wave Matrix Fragment SumAccumulate + +The `SumAccumulate` methods accumulate the values of the argument matrix into +the WaveMatrix fragment accumulator. The fragment WaveMatrix must have the same +data type as the fragment accumulator. + +This intrinsic is used to calculate +$(\sum_{i=0}^{K} A_{[x,i]})$ and $(\sum_{i=0}^{K} B_{[i,y]})$ from the +equation in the **Zero Point** section below. + +For accumulating into a `WaveMatrixLeftColAcc`, this results in a row-wise sum +into the accumulator. For accumulating into a `WaveMatrixRightRowAcc`, this +results in a column-wise sum into the accumulator. + #### Zero Point The following is the equation for matrix multiplication with zero point @@ -369,16 +396,6 @@ $+ Z_a * Z_b * K$ is the static zero point adjustment for both matrix $A$ and $B $Z_*$ are constant zero points values -#### Wave Matrix SumAccumulate - -The `SumAccumulate` methods accumulate the values of the argument matrix into -the WaveMatrix fragment accumulator. The fragment WaveMatrix must have the same -data type as the fragment accumulator. - -This intrinsics is used to calculated -$(\sum_{i=0}^{K} A_{[x,i]})$ and $(\sum_{i=0}^{K} B_{[i,y]})$ from our above -equation. - ## Acknowledgments From f31db93798013881d8be613c8c2a2062cbc023e8 Mon Sep 17 00:00:00 2001 From: Chris Bieneman Date: Tue, 16 Jan 2024 13:40:07 -0600 Subject: [PATCH 4/5] Move to SM 6.9 --- proposals/xxxx-wave-matrix.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/proposals/xxxx-wave-matrix.md b/proposals/xxxx-wave-matrix.md index 3541705f..5315c067 100644 --- a/proposals/xxxx-wave-matrix.md +++ b/proposals/xxxx-wave-matrix.md @@ -6,7 +6,7 @@ * Author(s): [Chris Bieneman](https://github.com/llvm-beanz) * Sponsor: [Greg Roth](https://github.com/pow2clk), [Tex Riddell](https://github.com/tex3d) * Status: **Under Consideration** -* Planned Version: Shader Model 6.8 +* Planned Version: Shader Model 6.9 ## Introduction @@ -14,7 +14,7 @@ This proposal adds HLSL support for a new DXIL feature WaveMatrix. The new data types and operations add support for wave cooperative matrix multiplication and accumulation. The underlying hardware and driver interfaces for this feature are -introduced in Shader Model 6.8. +introduced in Shader Model 6.9. ## Motivation @@ -25,7 +25,7 @@ SIMD hardware. These higher throughput matrix operations are required for optimal performance of many machine learning and image processing workloads. Adding support to HLSL -will enable high-performance matrix operations across Shader Model 6.8 drivers +will enable high-performance matrix operations across Shader Model 6.9 drivers when hardware support is available. ## Proposed solution From 08203e5aef646000587a1550f71ca4c41995a545 Mon Sep 17 00:00:00 2001 From: Chris Bieneman Date: Tue, 16 Jan 2024 13:47:13 -0600 Subject: [PATCH 5/5] Add DXIL changes --- proposals/xxxx-wave-matrix.md | 46 +++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/proposals/xxxx-wave-matrix.md b/proposals/xxxx-wave-matrix.md index 5315c067..e7b60ba7 100644 --- a/proposals/xxxx-wave-matrix.md +++ b/proposals/xxxx-wave-matrix.md @@ -396,6 +396,52 @@ $+ Z_a * Z_b * K$ is the static zero point adjustment for both matrix $A$ and $B $Z_*$ are constant zero points values +### DXIL Changes + +### New DXIL Types + +Wave Matrix adds a new opaque object type for interacting with Wave Matrix +objects. + +``` +%dx.types.waveMatrix = type { i8* } +``` + +Wave Matrix also adds a new property type for Wave Matrix objects. + +``` +%dx.types.waveMatProps = type { i8, i8, i32, i32 } +``` +# New DXIL Opcodes + +``` +┌───────────────────────────────┬────────┬───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ Name │ opcode │ IR Signature │ +├───────────────────────────────┼────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│waveMatrix_Annotate │ │void @dx.op.waveMatrix_Annotate(i32, %dx.types.waveMatrix*, %dx.types.waveMatProps) │ +├───────────────────────────────┼────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│waveMatrix_Depth │ │i32 @dx.op.waveMatrix_Depth(i32, %dx.types.waveMatProps) │ +├───────────────────────────────┼────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│waveMatrix_Fill │ │void @dx.op.waveMatrix_Fill.(i32, %dx.types.waveMatrix*, ) │ +├───────────────────────────────┼────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│waveMatrix_LoadRawBuf │ │void @dx.op.waveMatrix_LoadRawBuf(i32, %dx.types.waveMatrix*, %dx.types.Handle, i32, i32, i8, i1) │ +├───────────────────────────────┼────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│waveMatrix_LoadGroupShared │ │void @dx.op.waveMatrix_LoadGroupShared.(i32, %dx.types.waveMatrix*, addrspace(3)*, i32, i32, i1) │ +├───────────────────────────────┼────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│waveMatrix_StoreRawBuf │ │void @dx.op.waveMatrix_StoreRawBuf(i32, %dx.types.waveMatrix*, %dx.types.Handle, i32, i32, i8, i1) │ +├───────────────────────────────┼────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│waveMatrix_StoreGroupShared │ │void @dx.op.waveMatrix_StoreGroupShared.(i32, %dx.types.waveMatrix*, addrspace(3)*, i32, i32, i1) │ +├───────────────────────────────┼────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│waveMatrix_Multiply │ │void @dx.op.waveMatrix_Multiply(i32, %dx.types.waveMatrix*, %dx.types.waveMatrix*, %dx.types.waveMatrix*) │ +├───────────────────────────────┼────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│waveMatrix_MultiplyAccumulate │ │void @dx.op.waveMatrix_Accumulate(i32, %dx.types.waveMatrix*, %dx.types.waveMatrix*) │ +├───────────────────────────────┼────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│waveMatrix_ScalarOp │ │void @dx.op.waveMatrix_ScalarOp.(i32, %dx.types.waveMatrix*, i8, ) │ +├───────────────────────────────┼────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│waveMatrix_SumAccumulate │ │void @dx.op.waveMatrix_Accumulate(i32, %dx.types.waveMatrix*, %dx.types.waveMatrix*) │ +├───────────────────────────────┼────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│waveMatrix_Add │ │void @dx.op.waveMatrix_Accumulate(i32, %dx.types.waveMatrix*, %dx.types.waveMatrix*) │ +└───────────────────────────────┴────────┴───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ ## Acknowledgments