Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sync develop for ROCm 6.4 #479

Merged
merged 48 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
c39b14a
Bump rocm-docs-core from 1.8.3 to 1.9.0 in /docs/sphinx
dependabot[bot] Nov 25, 2024
c4e8d43
Merge pull request #465 from ROCm/dependabot/pip/docs/sphinx/rocm-doc…
cgmillette Nov 28, 2024
a455d3f
Serialize GEMM kernel runs
cgmillette May 7, 2024
5337ba8
First working interleaved 128x128 macro kernel
cgmillette Jun 27, 2024
ce0f116
Initial layout classes refactor
cgmillette Sep 9, 2024
cdd13bc
Refactor layout and traits organization
cgmillette Sep 12, 2024
e15f5b0
Remove unused file
cgmillette Sep 12, 2024
c27e4e1
Refactor layout traits
cgmillette Sep 26, 2024
f27ed38
Fixes build after layout folder refactor
cgmillette Oct 15, 2024
f3622e3
Update interleaving function
cgmillette Oct 24, 2024
2c550c3
Update is_layout_same and is_layout_orthogonal and matrix layouts logic
cgmillette Nov 9, 2024
64da221
Introduce register formats and refactor is_layout_same and is_layout_…
cgmillette Nov 11, 2024
b222868
Fixup interleaved layouts logic bugs. Add layout formats to fit all w…
cgmillette Nov 15, 2024
a5c23e4
Fix compiler unroll issue with function arg
cgmillette Nov 15, 2024
678a2d3
Deploy new mma workflow
cgmillette Nov 16, 2024
622a256
Fix include issues and io_shape test
cgmillette Nov 19, 2024
2c08b36
Add interleaved layout IOLayoutInt
cgmillette Nov 19, 2024
d89b536
Fixes for interleaved layout compatibility
cgmillette Nov 22, 2024
f6ff3e4
Add initial non-interleaved layout traits test
cgmillette Nov 25, 2024
9cc7ebc
Add DataT to Mma layout interface. Add checks for data size comparison
cgmillette Nov 25, 2024
368c6dc
Fixes f64 tests. Adds all block sizes tests.
cgmillette Nov 27, 2024
d88e353
Start implementing interleaved layout traits tests
cgmillette Nov 27, 2024
31e2f5a
Add interleaved and emulation tests
cgmillette Nov 28, 2024
ef05816
Fix build of layout unit tests
cgmillette Nov 28, 2024
5e8d2f6
Bump rocm-docs-core from 1.9.0 to 1.9.2 in /docs/sphinx (#467)
dependabot[bot] Nov 28, 2024
365b02e
Update CODEOWNERS
cgmillette Dec 2, 2024
b875471
Fix gfx11 implementation
cgmillette Dec 2, 2024
b035289
Restore perf_hgemm
cgmillette Dec 2, 2024
43c96fe
Skip tests on invalid layout condition for BlockK
cgmillette Dec 3, 2024
1aeb382
Add a softer warning for unsupported transform attempts
cgmillette Dec 3, 2024
8ee5c17
Updated changelog and set version to 1.7.0
CongMa13 Dec 2, 2024
0055868
Updated docs
CongMa13 Dec 2, 2024
656972e
Improve docs formatting
CongMa13 Dec 3, 2024
abb085f
Adjust MaxVWSelector to fit more layout constraints
cgmillette Dec 4, 2024
a420f9b
Update / correct non-interleaved layout tests
cgmillette Dec 4, 2024
e494ec1
Prevent sgemm kernel from building on unsupported targets
cgmillette Dec 4, 2024
8c75e88
Fixes: remove default Format argument to avoid usage mistakes; fix te…
cgmillette Dec 5, 2024
73fb53c
Fixup interleaved tests on gfx11
cgmillette Dec 5, 2024
d2d13e3
Minor CMake Changes (#466)
dlangbe Dec 10, 2024
fc3a273
Bump rocm-docs-core from 1.9.2 to 1.11.0 in /docs/sphinx (#474)
dependabot[bot] Dec 10, 2024
4d6fab8
Allow acc post mma xform to convert gfx11 mma acc quirk into configur…
cgmillette Dec 11, 2024
e3c61a3
Fixup MmaDim calculator
cgmillette Dec 11, 2024
61274d0
Removed WMMA_ACC_INT* formats
cgmillette Dec 17, 2024
45cf20e
removes std min reference for hipRTC
cgmillette Dec 17, 2024
d13b2f3
Update CMakeLists.txt
cgmillette Dec 17, 2024
ada4c03
Merge pull request #478 from ROCm/cgmillette-patch-2
cgmillette Dec 18, 2024
19af315
Update perf_hgemm.cpp
cgmillette Dec 18, 2024
c831ff5
Merge pull request #472 from cgmillette/interleave-dev
cgmillette Dec 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
* @cgmillette @congma13 @bragadeesh @mkarunan @dlangbe
* @cgmillette @congma13 @bragadeesh @dlangbe
# Documentation files
docs/* @ROCm/rocm-documentation
*.md @ROCm/rocm-documentation
Expand Down
18 changes: 17 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,23 @@
Documentation for rocWMMA is available at
[https://rocm.docs.amd.com/projects/rocWMMA/en/latest](https://rocm.docs.amd.com/projects/rocWMMA/en/latest).

## (Unreleased) rocWMMA 1.6.0 for ROCm 6.3.0
## (Unreleased) rocWMMA 1.7.0 for ROCm 6.4.0

### Added

* Added interleaved layouts that enhance the performance of GEMM operations
* Added emulation test suites. These suites are lightweight and well-suited for execution on emulator platforms

### Changed

* Used GPU_TARGETS instead of AMDGPU_TARGETS in `cmakelists.txt`
* Used `--offload-compress` flag for supported compilers

### Resolved issues

* For a CMake bug workaround, set `CMAKE_NO_BUILTIN_CHRPATH` when `BUILD_OFFLOAD_COMPRESS` is unset

## rocWMMA 1.6.0 for ROCm 6.3.0

### Added

Expand Down
37 changes: 23 additions & 14 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ if( CMAKE_PROJECT_NAME STREQUAL "rocwmma" )
option( ROCWMMA_BUILD_TESTS "Build rocWMMA tests" ON )
option( ROCWMMA_BUILD_SAMPLES "Build rocWMMA samples" ON )
option( ROCWMMA_BUILD_ASSEMBLY "Output assembly files" OFF )
option( BUILD_OFFLOAD_COMPRESS "Build rocWMMA with offload compression" ON )
endif()

# set( AMDGPU_TARGETS "gfx908:xnack-" ) # User variable
if( CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT )
set( CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "Install path prefix, prepended onto install directories" FORCE )
endif()
Expand All @@ -77,7 +77,7 @@ include(ROCMCheckTargetIds)
include(ROCMClients)

# Versioning via rocm-cmake
set ( VERSION_STRING "1.6.0" )
set ( VERSION_STRING "1.7.0" )
rocm_setup_version( VERSION ${VERSION_STRING} )

# configure a header file to pass the CMake version settings to the source
Expand All @@ -96,28 +96,37 @@ endif()

if (ADDRESS_SANITIZER_ENABLED)
#TODO: Remove next line when rocm-cmake fix is available
set(CMAKE_NO_BUILTIN_CHRPATH ON)
rocm_check_target_ids(DEFAULT_AMDGPU_TARGETS
rocm_check_target_ids(DEFAULT_GPU_TARGETS
TARGETS "gfx90a:xnack+;gfx942:xnack+" )
else()
#TODO: Remove next line when rocm-cmake fix is available
set(CMAKE_NO_BUILTIN_CHRPATH ON)
rocm_check_target_ids(DEFAULT_AMDGPU_TARGETS
rocm_check_target_ids(DEFAULT_GPU_TARGETS
TARGETS "gfx908;gfx90a;gfx942;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" )
endif()

# Check if offload compression is supported
include(CheckCXXCompilerFlag)
if (BUILD_OFFLOAD_COMPRESS)
check_cxx_compiler_flag("--offload-compress" CXX_COMPILER_SUPPORTS_OFFLOAD_COMPRESS)
endif()

# TODO: Remove next line when rocm-cmake fix is available
# Currently fixes linking issues with large executables
set(CMAKE_NO_BUILTIN_CHRPATH ON)

# Variable AMDGPU_TARGET must be a cached variable and must be specified before calling find_package(hip)
# This is because hip-config.cmake sets --offload-arch via AMDGPU_TARGET cached variable __after__ setting
# default cached variable AMDGPU_TARGET to DEFAULT_AMDGPU_TARGETS, where not all archs are compatible with MFMA instructions
# Variable GPU_TARGET must be a cached variable and must be specified before calling find_package(hip)
# This is because hip-config.cmake sets --offload-arch via GPU_TARGET cached variable __after__ setting
# default cached variable GPU_TARGET to DEFAULT_GPU_TARGETS, where not all archs are compatible with MFMA instructions
#
# By rule, once cached variable is set, it cannot be overridden unless we use the FORCE option
if(AMDGPU_TARGETS)
set(AMDGPU_TARGETS "${AMDGPU_TARGETS}" CACHE STRING "List of specific machine types for library to target")
if(GPU_TARGETS)
set(GPU_TARGETS "${GPU_TARGETS}" CACHE STRING "List of specific machine types for library to target")
elseif(AMDGPU_TARGETS)
set(GPU_TARGETS "${AMDGPU_TARGETS}" CACHE STRING "List of specific machine types for library to target")
message(STATUS "WARNING: AMDGPU_TARGETS use is deprecated. Use GPU_TARGETS.")
else()
set(AMDGPU_TARGETS "${DEFAULT_AMDGPU_TARGETS}" CACHE STRING "List of specific machine types for library to target")
set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING "List of specific machine types for library to target")
endif()
message( VERBOSE "AMDGPU_TARGETS=${AMDGPU_TARGETS}")
message( VERBOSE "GPU_TARGETS=${GPU_TARGETS}")

find_package( hip REQUIRED )
find_package( hiprtc REQUIRED )
Expand Down
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ The test suite includes validation and benchmarking projects that focus on unit

## Requirements

rocWMMA currently supports the following AMDGPU architectures:
rocWMMA currently supports the following AMD GPU architectures:

* CDNA class GPU featuring matrix core support: gfx908, gfx90a, gfx940, gfx940, gfx942 as 'gfx9'
* CDNA class GPU featuring matrix core support: gfx908, gfx90a, gfx940, gfx941, gfx942 as 'gfx9'
* RDNA3 class GPU featuring AI acceleration support: gfx1100, gfx1101, gfx1102 as 'gfx11'

Dependencies:

* Minimum ROCm version support is 6.3.
* Minimum ROCm version support is 6.4.
* Minimum cmake version support is 3.14.
* Minimum ROCm-cmake version support is 0.8.0.
* Minimum rocBLAS version support is rocBLAS 4.0.0 for ROCm 6.0* (or ROCm packages rocblas and rocblas-dev).
Expand All @@ -47,7 +47,8 @@ For more detailed information, please refer to the [rocWMMA installation guide](

|Option|Description|Default value|
|---|---|---|
|AMDGPU_TARGETS|Build code for specific GPU target(s)|gfx908:xnack-;gfx90a:xnack-;gfx90a:xnack+;gfx1100;gfx1101;gfx1102|
|GPU_TARGETS|Build code for specific GPU target(s)|gfx908:xnack-;gfx90a:xnack-;gfx90a:xnack+;gfx942;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201|
|AMDGPU_TARGETS|(Deprecated) Build code for specific GPU target(s)|gfx908:xnack-;gfx90a:xnack-;gfx90a:xnack+;gfx942;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201|
|ROCWMMA_BUILD_TESTS|Build Tests|ON|
|ROCWMMA_BUILD_SAMPLES|Build Samples|ON|
|ROCWMMA_BUILD_DOCS|Build doxygen documentation from code|OFF|
Expand All @@ -67,7 +68,7 @@ results. Here are some configuration examples:
|Configuration|Command|
|---|---|
|Basic|`CC=/opt/rocm/bin/amdclang CXX=/opt/rocm/bin/amdclang++ cmake -B<build_dir> .`|
|Targeting gfx908|`CC=/opt/rocm/bin/amdclang CXX=/opt/rocm/bin/amdclang++ cmake -B<build_dir> . -DAMDGPU_TARGETS=gfx908:xnack-` |
|Targeting gfx908|`CC=/opt/rocm/bin/amdclang CXX=/opt/rocm/bin/amdclang++ cmake -B<build_dir> . -DGPU_TARGETS=gfx908:xnack-` |
|Debug build|`CC=/opt/rocm/bin/amdclang CXX=/opt/rocm/bin/amdclang++ cmake -B<build_dir> . -DCMAKE_BUILD_TYPE=Debug` |
|Build without rocBLAS (default on)|`CC=/opt/rocm/bin/amdclang CXX=/opt/rocm/bin/amdclang++ cmake -B<build_dir> . -DROCWMMA_VALIDATE_WITH_ROCBLAS=OFF -DROCWMMA_BENCHMARK_WITH_ROCBLAS=OFF` |

Expand Down
12 changes: 11 additions & 1 deletion docs/api-reference/api-reference-guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,6 @@ layout_t
^^^^^^^^

.. doxygenenum:: rocwmma::layout_t
:members:


rocWMMA API functions
Expand Down Expand Up @@ -315,3 +314,14 @@ Sample programs

See a sample code for calling rocWMMA functions ``load_matrix_sync``, ``store_matrix_sync``, ``fill_fragment``, and ``mma_sync`` `here <https://github.com/ROCm/rocWMMA/blob/develop/samples/simple_hgemm.cpp>`_.
For more such sample programs, refer to the `Samples directory <https://github.com/ROCm/rocWMMA/tree/develop/samples>`_.

Emulation tests
---------------

The emulation test is a smaller test suite specifically designed for emulators. It comprises a selection of test cases from the full ROCWMM test set, allowing for significantly faster execution on emulated platforms. Despite its concise nature, the emulation test supports ``smoke``, ``regression``, and ``extended`` modes.

For example, run a smoke test.

.. code-block:: bash

rtest.py --install_dir <build_dir> --emulation smoke
14 changes: 9 additions & 5 deletions docs/install/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ To install rocWMMA on SLES, use:

Once installed, rocWMMA can be used just like any other library with a C++ API.

.. note::
The prebuilt package supports the following targets: ``gfx908``; ``gfx90a``; ``gfx942``; ``gfx1100``; ``gfx1101``; ``gfx1102``; ``gfx1200``; ``gfx1201``


Once rocWMMA is installed, you can see the ``rocwmma.hpp`` header file in the ``/opt/rocm/include/rocwmma`` directory.
You must include only ``rocwmma.hpp``, ``rocwmma_coop.hpp`` and ``rocwmma_transforms.hpp`` in the user code to make calls into rocWMMA.
Don't directly include other rocWMMA files that are found in ``/opt/rocm/include/internal``.
Expand Down Expand Up @@ -90,7 +94,7 @@ Dependencies
^^^^^^^^^^^^
rocWMMA is designed to have minimal external dependencies such that it is light-weight and portable.

* Minimum ROCm version support is 6.0.
* Minimum ROCm version support is 6.4.
* Minimum cmake version support is 3.14.
* Minimum ROCm-cmake version support is 0.8.0.
* Minimum rocBLAS version support is rocBLAS 4.0.0 for ROCm 6.0* (or ROCm packages rocblas and rocblas-dev).
Expand Down Expand Up @@ -183,9 +187,9 @@ Below are the project options available to build rocWMMA library with or without
* - **Option**
- **Description**
- **Default Value**
* - AMDGPU_TARGETS
* - GPU_TARGETS
- Build code for specific GPU target(s)
- ``gfx908:xnack-``; ``gfx90a:xnack-``; ``gfx90a:xnack+``; ``gfx940``; ``gfx941``; ``gfx942``; ``gfx1100``; ``gfx1101``; ``gfx1102``
- ``gfx908``; ``gfx90a``; ``gfx942``; ``gfx1100``; ``gfx1101``; ``gfx1102``; ``gfx1200``; ``gfx1201``
* - ROCWMMA_BUILD_TESTS
- Build Tests
- ON
Expand Down Expand Up @@ -235,7 +239,7 @@ Here are some other example project configurations:
+===================================+================================================================================================================================================================+
| Basic | :code:`CC=/opt/rocm/bin/amdclang CXX=/opt/rocm/bin/amdclang++ cmake -B <build_dir>` |
+-----------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Targeting gfx908 | :code:`CC=/opt/rocm/bin/amdclang CXX=/opt/rocm/bin/amdclang++ cmake -B <build_dir> . -DAMDGPU_TARGETS=gfx908:xnack-` |
| Targeting gfx908 | :code:`CC=/opt/rocm/bin/amdclang CXX=/opt/rocm/bin/amdclang++ cmake -B <build_dir> . -DGPU_TARGETS=gfx908:xnack-` |
+-----------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Debug build | :code:`CC=/opt/rocm/bin/amdclang CXX=/opt/rocm/bin/amdclang++ cmake -B <build_dir> . -DCMAKE_BUILD_TYPE=Debug` |
+-----------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------+
Expand Down Expand Up @@ -481,7 +485,7 @@ Build performance

Depending on the resources available to the build machine and the build configuration selected, rocWMMA build times can be on the order of an hour or more. Here are some things you can do to reduce build times:

* Target a specific GPU (e.g., ``-D AMDGPU_TARGETS=gfx908:xnack-``)
* Target a specific GPU (e.g., ``-D GPU_TARGETS=gfx908:xnack-``)
* Use lots of threads (e.g., ``-j32``)
* Select ``ROCWMMA_BUILD_ASSEMBLY=OFF``
* Select ``ROCWMMA_BUILD_DOCS=OFF``.
Expand Down
2 changes: 1 addition & 1 deletion docs/sphinx/requirements.in
Original file line number Diff line number Diff line change
@@ -1 +1 @@
rocm-docs-core==1.8.3
rocm-docs-core==1.11.0
2 changes: 1 addition & 1 deletion docs/sphinx/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ requests==2.32.3
# via
# pygithub
# sphinx
rocm-docs-core==1.8.3
rocm-docs-core==1.11.0
# via -r requirements.in
smmap==5.0.1
# via gitdb
Expand Down
15 changes: 10 additions & 5 deletions library/include/rocwmma/internal/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
*
* MIT License
*
* Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -146,6 +146,11 @@ static_assert(0, "Unsupported architecture");
#define ROCWMMA_ARCH_GFX94X 1
#endif

#if ROCWMMA_ARCH_HOST
#define ROCWMMA_BLOCK_DIM_16_SUPPORTED 1
#define ROCWMMA_BLOCK_DIM_32_SUPPORTED 1
#endif

#if !defined(ROCWMMA_ARCH_GFX9)
#define ROCWMMA_ARCH_GFX9 0
#endif
Expand Down Expand Up @@ -201,10 +206,10 @@ static_assert((bool)(ROCWMMA_BLOCK_DIM_16_SUPPORTED) && !(bool)(ROCWMMA_BLOCK_DI
#endif

#if ROCWMMA_ARCH_GFX12
static_assert((bool)(ROCWMMA_WAVE32_MODE) && !(bool)(ROCWMMA_WAVE64_MODE),
"rocWMMA supports only wave32 for gfx12 arch");
static_assert((bool)(ROCWMMA_BLOCK_DIM_16_SUPPORTED) && !(bool)(ROCWMMA_BLOCK_DIM_32_SUPPORTED),
"rocWMMA supports only block size of 16 for gfx12 arch");
static_assert((bool)(ROCWMMA_WAVE32_MODE) && !(bool)(ROCWMMA_WAVE64_MODE),
"rocWMMA supports only wave32 for gfx12 arch");
static_assert((bool)(ROCWMMA_BLOCK_DIM_16_SUPPORTED) && !(bool)(ROCWMMA_BLOCK_DIM_32_SUPPORTED),
"rocWMMA supports only block size of 16 for gfx12 arch");
#endif

///
Expand Down
12 changes: 11 additions & 1 deletion library/include/rocwmma/internal/coop_io_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
*
* MIT License
*
* Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -26,6 +26,7 @@
#ifndef ROCWMMA_COOP_IO_CONFIG_HPP
#define ROCWMMA_COOP_IO_CONFIG_HPP

#include "./layout/register_layout_transforms.hpp"
#include "coop_load.hpp"
#include "coop_store.hpp"
#include "io_layout.hpp"
Expand Down Expand Up @@ -85,6 +86,15 @@ namespace rocwmma
typename IOLayout::MatrixLayout,
IOLayout::VW>;

using PostLoadXForm = register_layout_transform<typename IOLayout::StorageLayout,
typename IOLayout::FragmentLayout>;

using PreMmaXForm = register_layout_transform<typename IOLayout::FragmentLayout,
typename IOLayout::MmaLayout>;

using PreStoreXForm = register_layout_transform<typename IOLayout::FragmentLayout,
typename IOLayout::StorageLayout>;

using Storer = CooperativeStore<IOShape::BlockDim,
IOShape::KDim,
DataT,
Expand Down
1 change: 0 additions & 1 deletion library/include/rocwmma/internal/coop_load.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#define ROCWMMA_COOP_LOAD_HPP

#include "io_traits.hpp"
#include "layout.hpp"
#include "opaque_load.hpp"
#include "types.hpp"
#include "utils.hpp"
Expand Down
1 change: 0 additions & 1 deletion library/include/rocwmma/internal/coop_store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#define ROCWMMA_COOP_STORE_HPP

#include "io_traits.hpp"
#include "layout.hpp"
#include "opaque_store.hpp"
#include "types.hpp"
#include "utils.hpp"
Expand Down
32 changes: 27 additions & 5 deletions library/include/rocwmma/internal/io_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
*
* MIT License
*
* Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -26,6 +26,7 @@
#ifndef ROCWMMA_IO_CONFIG_HPP
#define ROCWMMA_IO_CONFIG_HPP

#include "./layout/register_layout_transforms.hpp"
#include "broadcast.hpp"
#include "coop_load.hpp"
#include "coop_store.hpp"
Expand All @@ -37,7 +38,6 @@

namespace rocwmma
{

/**
* \defgroup Rocwmma_ioconf ROCWMMA IOConfig
* @brief ROCWMMA fragment input and output configurations
Expand Down Expand Up @@ -88,6 +88,21 @@ namespace rocwmma
typename IOLayout::MatrixLayout,
IOLayout::VW>;

using PostLoadXForm = register_layout_transform<typename IOLayout::StorageLayout,
typename IOLayout::FragmentLayout>;

using PreMmaXForm = register_layout_transform<typename IOLayout::FragmentLayout,
typename IOLayout::MmaLayout>;

// Currently, only makes sense to have a post-mma transform on acc layouts
using PostMmaXForm = conditional_t<is_same_v<MatrixT, accumulator>,
register_layout_transform<typename IOLayout::MmaLayout,
typename IOLayout::FragmentLayout>,
register_layout_transform_nop>;

using PreStoreXForm = register_layout_transform<typename IOLayout::FragmentLayout,
typename IOLayout::StorageLayout>;

using Storer = OpaqueStore<IOShape::BlockDim,
IOShape::KDim,
DataT,
Expand All @@ -106,10 +121,17 @@ namespace rocwmma
template <uint32_t BlockM, uint32_t BlockN, uint32_t BlockK, typename DataT>
struct IOConfig<accumulator, BlockM, BlockN, BlockK, DataT, void>
{
using IOShape = IOShape<accumulator, BlockM, BlockN, BlockK>;
using IOTraits = IOTraits<IOShape::BlockDim, IOShape::KDim, DataT>;
using PackUtil = PackUtil<DataT>;
using IOShape = IOShape<accumulator, BlockM, BlockN, BlockK>;
using IOLayout = IOLayout<accumulator, IOShape::BlockDim, IOShape::KDim, DataT, void, 1u>;
using IOTraits = IOTraits<IOShape::BlockDim, IOShape::KDim, DataT>;
using PackUtil = PackUtil<DataT>;
using Broadcaster = Broadcast<DataT, IOTraits::UnpackedSize>;

using PreMmaXForm = register_layout_transform<typename IOLayout::FragmentLayout,
typename IOLayout::MmaLayout>;

using PostMmaXForm = register_layout_transform<typename IOLayout::MmaLayout,
typename IOLayout::FragmentLayout>;
};
/** @}*/

Expand Down
Loading
Loading