From 47d800ccd9449e1bbc255d64d794ae88d99b043d Mon Sep 17 00:00:00 2001
From: Anerudhan Gopal
Date: Thu, 13 Jun 2024 08:47:43 +0000
Subject: [PATCH] Release notes for cudnn-frontend 1.5.0: (#81)

[New feature] With cudnn backend 9.2.0 and above, `Graph::check_support` can determine support for runtime engines without invoking the nvrtc compiler. This allows users to check the support surface of cudnn without paying the cost of nvrtc compilation.
[New feature] The Python pip wheel now contains the necessary C++ development headers.
[New feature] Sliding window attention is now supported as an attribute on the SDPA forward and bprop nodes. Usage: `sdpa_attributes.set_sliding_window_length(window_length)` (see the usage sketch below).
[New feature] Bottom-right aligned causal masking is now supported as an attribute on the SDPA forward and bprop nodes. Usage: `sdpa_attributes.use_causal_mask_bottom_right(true)`
[New feature] SDPA bprop attributes can select a deterministic algorithm through the `use_deterministic_algorithm` API.
[New feature] Users can filter a graph's candidate execution plans by their shared memory usage in cudnn 9.2.0 and later.
[Bug fix] Fixed a runtime error that occurred when the chosen execution plan candidate was incorrectly set in the backend. This happened when `check_support` did not correctly filter by workspace size.
[Bug fix] Selecting/deselecting plans by behavioral and numerical notes now works as intended.
[Debugging] A new tool for easily reproducing a failure from the json representation of the graph can be found [here](tools/json_reproducer).
[Samples] Restructured the cpp samples into categories for easier navigation.
[Samples] Added a sample to showcase how different plans can be built in parallel in separate threads.
[Compilation enhancement] Added a new compile-time flag `CUDNN_FRONTEND_SKIP_JSON_LIB` to drop the nlohmann::json compilation dependency. When it is set, users lose access to API functions such as `print`, `key`, `serialize`, and `deserialize` that depend on that library.
[Enhancement] Serialization of the resample operation is now supported.
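As a quick illustration of how several of the items above fit together, here is a minimal, hedged sketch of an SDPA forward graph using the FE v1 API: it sets the new sliding-window attribute, filters candidate plans by shared memory usage, and then runs the support check (which no longer triggers nvrtc for runtime engines on cudnn 9.2.0+). The function name `build_sdpa_sliding_window`, the tensor shapes, the 128-token window, and the 64 KiB shared-memory cap are illustrative assumptions; the setter and filter names (`set_sliding_window_length`, `deselect_shared_mem_greater_than`, `check_support`) are taken from this patch, but the sketch is not one of the repository's samples.

```cpp
// Hypothetical usage sketch based on the 1.5.0 release notes; not taken from the samples.
#include <cstdint>
#include <vector>

#include <cudnn.h>
#include <cudnn_frontend.h>

namespace fe = cudnn_frontend;

// Build an SDPA forward graph that uses the new sliding-window attribute, then
// filter candidate plans by shared memory usage before checking support.
fe::error_t
build_sdpa_sliding_window(cudnnHandle_t handle, fe::graph::Graph& graph) {
    graph.set_io_data_type(fe::DataType_t::HALF)
        .set_intermediate_data_type(fe::DataType_t::FLOAT)
        .set_compute_data_type(fe::DataType_t::FLOAT);

    // Assumed example shapes: B=4, H=8, S=1024, D=64 with BHSD strides.
    std::vector<int64_t> const dim    = {4, 8, 1024, 64};
    std::vector<int64_t> const stride = {8 * 1024 * 64, 1024 * 64, 64, 1};

    auto Q = graph.tensor(fe::graph::Tensor_attributes().set_name("Q").set_dim(dim).set_stride(stride));
    auto K = graph.tensor(fe::graph::Tensor_attributes().set_name("K").set_dim(dim).set_stride(stride));
    auto V = graph.tensor(fe::graph::Tensor_attributes().set_name("V").set_dim(dim).set_stride(stride));

    auto sdpa_attributes = fe::graph::SDPA_attributes()
                               .set_name("sdpa")
                               .set_is_inference(true)
                               .set_causal_mask(true)
                               // New in 1.5.0: attend only to a trailing window for each position.
                               .set_sliding_window_length(128);

    auto [O, stats] = graph.sdpa(Q, K, V, sdpa_attributes);
    O->set_output(true);
    (void)stats;  // stats is null in inference mode

    if (auto status = graph.validate(); status.is_bad()) return status;
    if (auto status = graph.build_operation_graph(handle); status.is_bad()) return status;
    if (auto status = graph.create_execution_plans({fe::HeurMode_t::A}); status.is_bad()) return status;

    // New in 1.5.0 (cudnn 9.2.0+): drop candidate plans that need more than
    // 64 KiB of shared memory before checking support and building plans.
    graph.deselect_shared_mem_greater_than(64 * 1024);

    if (auto status = graph.check_support(handle); status.is_bad()) return status;
    return graph.build_plans(handle);
}
```

For the other new attributes, the notes above give `sdpa_attributes.use_causal_mask_bottom_right(true)` for bottom-right aligned causal masking, and the backward attributes in this patch expose `set_deterministic_algorithm(bool)` for the deterministic bprop path.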
[Enhancement] Bug template has been added for new github issues --- CMakeLists.txt | 6 +- README.FE.1.0.md | 40 +- README.md | 32 +- docs/operations/Attention.md | 264 ++-- docs/operations/Pointwise.md | 9 + include/cudnn_frontend.h | 2 +- .../backend/backend_descriptor.h | 2 + .../backend/execution_helpers.h | 2 + include/cudnn_frontend/backend/plan_helpers.h | 60 + include/cudnn_frontend/context.h | 13 + include/cudnn_frontend/cudnn_interface.h | 30 +- include/cudnn_frontend/graph_helpers.h | 14 +- include/cudnn_frontend/graph_interface.h | 195 ++- include/cudnn_frontend/graph_properties.h | 240 ++-- include/cudnn_frontend/node/batchnorm.h | 2 + .../cudnn_frontend/node/batchnorm_inference.h | 2 + include/cudnn_frontend/node/bn_finalize.h | 2 + include/cudnn_frontend/node/conv_dgrad.h | 2 + include/cudnn_frontend/node/conv_fprop.h | 2 + include/cudnn_frontend/node/conv_wgrad.h | 2 + include/cudnn_frontend/node/dbn.h | 2 + include/cudnn_frontend/node/dbn_weight.h | 2 + include/cudnn_frontend/node/dln.h | 2 + include/cudnn_frontend/node/genstats.h | 2 + include/cudnn_frontend/node/instancenorm.h | 4 + include/cudnn_frontend/node/layernorm.h | 2 + include/cudnn_frontend/node/matmul.h | 46 + include/cudnn_frontend/node/matmul_fp8.h | 2 + include/cudnn_frontend/node/pointwise.h | 56 +- include/cudnn_frontend/node/reduction.h | 2 + include/cudnn_frontend/node/resample.h | 2 + include/cudnn_frontend/node/reshape.h | 2 + include/cudnn_frontend/node/rmsnorm.h | 4 + include/cudnn_frontend/node/rng.h | 2 + .../node/scaled_dot_product_flash_attention.h | 590 ++++++--- include/cudnn_frontend/node/sdpa_fp8.h | 2 + include/cudnn_frontend/node/sdpa_fp8_bwd.h | 2 + include/cudnn_frontend/node/softmax.h | 2 + include/cudnn_frontend/node_interface.h | 28 +- include/cudnn_frontend/plans.h | 262 ++-- include/cudnn_frontend/utils/serialize.h | 13 +- include/cudnn_frontend_ConvDesc.h | 4 + include/cudnn_frontend_Errata.h | 5 + include/cudnn_frontend_MatMulDesc.h | 4 + include/cudnn_frontend_Operation.h | 4 + include/cudnn_frontend_PointWiseDesc.h | 5 + include/cudnn_frontend_ReductionDesc.h | 5 + include/cudnn_frontend_Resample.h | 7 + include/cudnn_frontend_Rng.h | 4 + include/cudnn_frontend_Tensor.h | 8 + include/cudnn_frontend_shim.h | 13 + include/cudnn_frontend_utils.h | 15 +- pyproject.toml | 10 +- python/cudnn/__init__.py | 3 +- python/properties.cpp | 11 +- python/pycudnn.cpp | 6 +- python/pygraph/pointwise.cpp | 58 +- python/pygraph/pygraph.cpp | 40 +- python/pygraph/pygraph.h | 19 +- python/pygraph/sdpa.cpp | 30 + samples/CMakeLists.txt | 43 +- samples/README.md | 112 ++ samples/cpp/{ => convolution}/dgrads.cpp | 2 +- samples/cpp/convolution/fp8_fprop.cpp | 131 ++ .../fprop.cpp} | 203 ++-- samples/cpp/convolution/int8_fprop.cpp | 100 ++ samples/cpp/{ => convolution}/wgrads.cpp | 2 +- samples/cpp/matmul/fp8_matmul.cpp | 128 ++ samples/cpp/matmul/int8_matmul.cpp | 114 ++ samples/cpp/{ => matmul}/matmuls.cpp | 330 +---- samples/cpp/matmul/mixed_matmul.cpp | 106 ++ samples/cpp/mha.cpp | 1057 ----------------- samples/cpp/{ => misc}/autotuning.cpp | 11 +- samples/cpp/misc/parallel_compilation.cpp | 152 +++ samples/cpp/{ => misc}/pointwise.cpp | 2 +- samples/cpp/{ => misc}/resample.cpp | 2 +- samples/cpp/{ => misc}/serialization.cpp | 2 +- samples/cpp/{ => norm}/batchnorm.cpp | 2 +- samples/cpp/{ => norm}/layernorm.cpp | 2 +- samples/cpp/{ => norm}/rmsnorm.cpp | 2 +- samples/cpp/sdpa/fp16_bwd.cpp | 274 +++++ samples/cpp/sdpa/fp16_cached.cpp | 175 +++ samples/cpp/sdpa/fp16_fwd.cpp | 219 ++++ 
samples/cpp/sdpa/fp8_bwd.cpp | 391 ++++++ samples/cpp/sdpa/fp8_fwd.cpp | 155 +++ samples/legacy_samples/fusion_sample.cpp | 4 +- samples/python/00_introduction.ipynb | 41 +- samples/python/01_matmul_bias.ipynb | 23 +- .../python/02_sdpa_graph_serialization.ipynb | 92 +- .../python/03_mixed_precision_matmul.ipynb | 30 +- .../50_scaled_dot_product_attention.ipynb | 26 +- ...caled_dot_product_attention_backward.ipynb | 38 +- samples/utils/helpers.h | 30 + test/python_fe/conftest.py | 5 + test/python_fe/test_apply_rope.py | 5 +- test/python_fe/test_batchnorm.py | 9 +- test/python_fe/test_conv_bias.py | 56 +- test/python_fe/test_conv_genstats.py | 4 +- test/python_fe/test_conv_reduction.py | 5 +- test/python_fe/test_instancenorm.py | 7 +- test/python_fe/test_layernorm.py | 5 +- test/python_fe/test_matmul_bias_relu.py | 12 +- test/python_fe/test_mhas.py | 641 ++++++---- test/python_fe/test_rmsnorm.py | 8 +- test/python_fe/test_silu_and_mul.py | 238 ++++ test/python_fe/test_utils.py | 2 +- test/python_fe/test_wgrads.py | 5 +- test/unit_tests/serialize.cpp | 110 +- test/unit_tests/validate.cpp | 6 +- tools/json_reproducer/README.md | 22 + tools/json_reproducer/json_parser.py | 42 + tools/json_reproducer/jsons/graph0.json | 92 ++ 112 files changed, 5033 insertions(+), 2443 deletions(-) create mode 100644 include/cudnn_frontend/backend/plan_helpers.h rename samples/cpp/{ => convolution}/dgrads.cpp (99%) create mode 100644 samples/cpp/convolution/fp8_fprop.cpp rename samples/cpp/{convolutions.cpp => convolution/fprop.cpp} (69%) create mode 100644 samples/cpp/convolution/int8_fprop.cpp rename samples/cpp/{ => convolution}/wgrads.cpp (99%) create mode 100644 samples/cpp/matmul/fp8_matmul.cpp create mode 100644 samples/cpp/matmul/int8_matmul.cpp rename samples/cpp/{ => matmul}/matmuls.cpp (64%) create mode 100644 samples/cpp/matmul/mixed_matmul.cpp delete mode 100644 samples/cpp/mha.cpp rename samples/cpp/{ => misc}/autotuning.cpp (95%) create mode 100644 samples/cpp/misc/parallel_compilation.cpp rename samples/cpp/{ => misc}/pointwise.cpp (99%) rename samples/cpp/{ => misc}/resample.cpp (99%) rename samples/cpp/{ => misc}/serialization.cpp (99%) rename samples/cpp/{ => norm}/batchnorm.cpp (99%) rename samples/cpp/{ => norm}/layernorm.cpp (99%) rename samples/cpp/{ => norm}/rmsnorm.cpp (99%) create mode 100644 samples/cpp/sdpa/fp16_bwd.cpp create mode 100644 samples/cpp/sdpa/fp16_cached.cpp create mode 100644 samples/cpp/sdpa/fp16_fwd.cpp create mode 100644 samples/cpp/sdpa/fp8_bwd.cpp create mode 100644 samples/cpp/sdpa/fp8_fwd.cpp create mode 100644 test/python_fe/test_silu_and_mul.py create mode 100644 tools/json_reproducer/README.md create mode 100644 tools/json_reproducer/json_parser.py create mode 100644 tools/json_reproducer/jsons/graph0.json diff --git a/CMakeLists.txt b/CMakeLists.txt index e64071c..f52185f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,8 +1,8 @@ cmake_minimum_required(VERSION 3.17) -project(cudnn_frontend VERSION 1.4.0) +project(cudnn_frontend VERSION 1.5.0) -option(CUDNN_FRONTEND_SKIP_NLOHMANN_JSON "Defines whether FE should not include nlohmann/json.hpp." OFF) +option(CUDNN_FRONTEND_SKIP_JSON_LIB "Defines whether FE should not include nlohmann/json.hpp." OFF) option(CUDNN_FRONTEND_BUILD_SAMPLES "Defines if samples are built or not." ON) option(CUDNN_FRONTEND_BUILD_UNIT_TESTS "Defines if unittests are built or not." 
ON) @@ -18,7 +18,7 @@ add_library(cudnn_frontend INTERFACE) target_compile_definitions( cudnn_frontend INTERFACE - $<$:CUDNN_FRONTEND_SKIP_NLOHMANN_JSON> + $<$:CUDNN_FRONTEND_SKIP_JSON_LIB> ) target_include_directories( diff --git a/README.FE.1.0.md b/README.FE.1.0.md index 44fde88..bcf9966 100644 --- a/README.FE.1.0.md +++ b/README.FE.1.0.md @@ -12,6 +12,11 @@ FE v1.0 API is aimed to extend functionality and usage exposed by the [cuDNN C backend API](https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnn-backend-api). Both C++ and python APIs are provided, and both have functional parity. For a general introduction to FE, please start with README.md. +In the frontend v1 API, you can describe multiple operations that form subgraphs through a persistent cudnn_frontend::graph::Graph object. Unlike the frontend v0.x API, you don't have to worry about specifying shapes and sizes of the intermediate virtual tensors. The frontend v1 API extends the groundwork of earlier versions and introduces a new set of APIs to further simplify the workflow. + +Additionally, the frontend v1 API provides Python bindings to all API. Refer to samples/cpp and samples/python for more details on its usage. +With the release of v1, we are bumping up the minimum supported cuDNN version to 8.5.0. + ## Workflow The steps involved in building and running a cudnn graph are as follows: 1. Create a cudnn graph and specify the global properties. The global properties like compute precision and input/output data type help infer properties that are not explicitly mentioned. @@ -20,10 +25,10 @@ The steps involved in building and running a cudnn graph are as follows: 4. Validate the operation graph. This step makes sure the graph is well built and does not have hanging tensors or node. 5. Build the cudnn operation graph. This step lowers the graph into cudnn dialect. 6. Create the execution plan, based on the heuristics type of your choice. -7. [Optional] Check support of the operation graph. +7. Check support of the operation graph. 8. [Optional] Filter out the plans by your custom criteria (Optional). 9. Build (one or all) the execution plans. -10. [Optional] Run autotuning on the filter plan (Optional). +10. [Optional] Run autotuning on the filtered plan (Optional). 11. Execute the graph with the relevant data pointers. ## APIs @@ -48,7 +53,7 @@ FE v1.0 API follows a functional style of building a graph. Operations take in i | [Scale dot product attention FP8](docs/operations/Attention.md) | sdpa_fp8
SDPA_fp8_attributes | sdpa_fp8 | | [Scale dot product attention backward FP8](docs/operations/Attention.md) | sdpa_fp8_backward
SDPA_fp8_backward_attributes | sdpa_fp8_backward | -### Create Graph +### Creating the Graph Instantiate an object of class `cudnn_frontend::graph::Graph` which will house tensors and operations. Optional graph level attributes can be set on the object: @@ -71,14 +76,14 @@ Tensor attributes is a lightweight structure with setters for each attribute. - `cudnn_frontend::graph::Tensor_attributes& set_reordering_type(cudnn_frontend::TensorReordering_t)` - `cudnn_frontend::graph::Tensor_attributes& set_name(std::string&)` -### Define Operations +### Defining Operations Operations take in mandatory input tensor via positional arguments. Optional input tensors are provided using corresponding setters in operation attributes. Operations return an ordered array of output tensors. Any optional outputs if not present will have their shared pointers pointing to `std::nullptr`. Please looks at [operations](#Operations) section for more details. -### Validate graph +### Validating the Graph Validate API ensures API usage is sound, checks against dangling tensors, etc. Internally, any unspecified properties like dimensions, strides, etc are inferred. @@ -86,21 +91,21 @@ Internally, any unspecified properties like dimensions, strides, etc are inferre cudnn_frontend::error_t cudnn_frontend::graph::Graph::validate() ``` -### Build cudnn backend graph +### Building the Backend Graph This method creates cudnn backend descriptors for all constituents of the graph. ``` cudnn_frontend::error_t cudnn_frontend::graph::Graph::build_operation_graph(cudnnHandle_t handle) ``` -### Create Execution plans +### Creating the Execution Plan This method internally queries the heuristics for engine configs for the given heuristics modes. ``` cudnn_frontend::error_t cudnn_frontend::graph::Graph::get_execution_plans(std::vector) ``` -### Get execution plan count +### Getting the Execution Plan Count This method returns the number of execution plans returned by cudnn heuristics. Each plan gets an index from 0 to #plans-1, with 0 having top priority. ``` @@ -108,16 +113,16 @@ cudnn_frontend::int64_t cudnn_frontend::Graph::get_execution_plan_count() const; ``` -### Check graph support +### Checking Graph Support This method guarantees that executing the graph using plans queried will succeed. ``` cudnn_frontend::error_t cudnn_frontend::graph::Graph::check_support(cudnnHandle_t h); ``` -### Build plans +### Building the Execution Plan -This function builds execution plans queried with `create_execution_plan(...)`` API. +This function builds execution plans queried with `create_execution_plan(...)` API. There are two flavours of this API: @@ -140,10 +145,7 @@ cudnn_frontend::Graph::build_plan_at_index( int64_t plan_index ); ``` - - - -### Filter plans (optional) +### Filtering Plans (Optional) Users can filter plans on numerical, behavioral notes, or plans that do not provide desired functional correctness. ``` @@ -155,15 +157,15 @@ cudnn_frontend::graph::Graph& cudnn_frontend::graph::Plans::deselect_behavior_no cudnn_frontend::graph::Graph& cudnn_frontend::graph::Plans::deselect_workspace_greater_than(int64_t const workspace); ``` -### Autotune +### Autotuning Autotuning provides a way to execute different execution plans for a given graph and measure their relative performance under run time conditions. This generally helps validate and improve upon the results provided by the heuristics. 
Please refer to [samples](samples/cpp/autotuning.cpp) -### Execute -Executing graph requires device pointers to all input output tensors and a user allocated device workspace pointer. +### Executing the Graph +Executing the graph requires device pointers to all input output tensors and a user allocated device workspace pointer. -Two flavours of execute exists, corresponding to `build_plans(...)`` API. +Two flavours of execute exists, corresponding to `build_plans(...)` API. This API already has a candidate execution plan set. Candidate execution plan get internally set either: - if build_policy_t::HEURISTIC_CHOICE is used, or diff --git a/README.md b/README.md index 7bd9021..dfd88a2 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,9 @@ In FE v1.0 API, users can describe multiple operations that form subgraph throug Additionally, FE v1.0 API provides python bindings to all API through pybind11. It is recommended that new users of cuDNN start with the frontend v1.0 API. See `samples/cpp` and `samples/python` for more details on its usage. ## Usage -In order to include the entire library, include the cudnn_frontend header file `include/cudnn_frontend.h` into your compilation unit. +For c++ users, in order to include the entire library, include the cudnn_frontend header file `include/cudnn_frontend.h` into your compilation unit. + +For Python users, run `import cudnn` ## Build: @@ -31,18 +33,23 @@ cudnn can be installed from Minimum python version needed 3.6 The python binding compilation requires development package which can be installed by running `apt-get install python-dev`. -To run the python samples, additionally, you will need the following python packages: -- pytest -- torch -- jupyter - +To run the Python samples, you will need the dependencies mentioned in `requirements.txt`. This can be be installed by running: +`pip install -r requirements.txt` ### Python API +#### pip wheel installation + +Download the pip wheel corresponding to your python installation. + +``` +pip install nvidia_cudnn_frontend +``` + #### Source installation: Install FE python API by running: ``` -pip install git+https://github.com/NVIDIA/cudnn-frontend.git +pip install -v git+https://github.com/NVIDIA/cudnn-frontend.git ``` Above command picks cuda and cudnn from default system paths. @@ -50,14 +57,6 @@ Above command picks cuda and cudnn from default system paths. To provide a custom CUDA installation path, use environment variable: `CUDAToolkit_ROOT`. To provide a custom CUDNN installation path, use environment variable: `CUDNN_PATH`. -#### pip wheel installation - -Download the pip wheel corresponding to your python installation. - -``` -pip install nvidia_cudnn_frontend-1.2.0-*.whl -``` - #### Checking the installation To test whether installation is successful, run: ``` @@ -66,7 +65,6 @@ pytest test/python_fe NOTE: Only v1.0 API is exposed via python bindings. - ### C++ API C++ API is header only library. @@ -74,7 +72,7 @@ C++ API is header only library. The root CMakeLists.txt can be used as reference to include the cudnn_frontend in your project's build system. #### Building samples -The following compilation steps are only required for building the samples and/or python bindings. +The following compilation steps are only required for building the samples. 
Provide CUDA installation path according to: https://cmake.org/cmake/help/latest/module/FindCUDAToolkit.html diff --git a/docs/operations/Attention.md b/docs/operations/Attention.md index 16263ca..9758a65 100644 --- a/docs/operations/Attention.md +++ b/docs/operations/Attention.md @@ -1,17 +1,13 @@ ## Table of Contents -1. [Scaled Dot Product Attention](#scaled-dot-product-attention) -2. [Scaled Dot Product Attention Backward](#scaled-dot-product-attention-backward) -3. [Scaled Dot Product Attention FP8](#scaled-dot-product-attention-fp8) -4. [Scaled Dot Product Attention Backward FP8](#scaled-dot-product-attention-backward-fp8) -5. Appendices - - [Tensor Layouts](#appendix-a) - - [Workspace limits and Performance](#appendix-b) - - [RNG dump](#appendix-c) -6. [Miscellaneous](#miscellaneous) +1. [Scaled Dot Product Attention FP16/BF16 Forward](#scaled-dot-product-attention-fp16bf16-forward) +2. [Scaled Dot Product Attention FP16/BF16 Backward](#scaled-dot-product-attention-fp16bf16-backward) +3. [Scaled Dot Product Attention FP8 Forward](#scaled-dot-product-attention-fp8-forward) +4. [Scaled Dot Product Attention FP8 Backward](#scaled-dot-product-attention-fp8-backward) +5. [Supported Tensor Layouts](#supported-tensor-layouts) -### Scaled Dot Product Attention +### Scaled Dot Product Attention FP16/BF16 Forward -This operation computes the scaled dot product attention, as +This operation computes the scaled dot product attention (SDPA), as $\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V$ @@ -19,42 +15,53 @@ using the FlashAttention-2 algorithm as described in the paper [FlashAttention-2 - Python sample: [samples/python/50_scaled_dot_product_attention.ipynb](https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/python/50_scaled_dot_product_attention.ipynb) -- C++ sample: [samples/cpp/mha.cpp](https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/cpp/mha.cpp) +- C++ sample: [samples/cpp/sdpa](https://github.com/NVIDIA/cudnn-frontend/tree/main/samples/cpp/sdpa) - Python tests: [test/python_fe/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python_fe/test_mhas.py) #### Configurable Options: - Attention scale (`attn_scale`): Applies a scaling factor to attention scores before the softmax, such as $\frac{1}{\sqrt{\text{d}}}$. Set to 1.0 by default. -- Bias mask: Applies an additive bias mask to attention scores. Users must pass a bias tensor as specified in the tensors section below. +- Bias mask: Applies an additive bias mask to attention scores. Users must pass a bias tensor as specified in the tensors section below. The dimensions that are passed as 1 will apply a broadcasted mask over attention scores. - Alibi mask: Attention with Linear Biases (ALiBi) is an additive mask applied to the attention scores as described in the paper [Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation](https://arxiv.org/abs/2108.12409). - Padding mask: Also called variable sequence length, this option masks out padded time steps to ignore them in computation. Users must pass a per-batch sequence length as specified in the tensors section below. - Causal mask: Fills the upper triangular matrix of attention scores with negative infinity. +- Sliding window mask: Allows computation of attention scores from \(pos-sliding_window_length, pos\] for every position `pos`. Fills rest of the entries in the matrix with negative infinity. 
- Dropout: Randomly zeros some of the attention weights after the softmax as a form of regularization. Users can configure dropout in two ways: - To use the more performant Philox RNG dropout implementation, users must provide: - An RNG seed, passed as a cudnn tensor. - An RNG offset, passed as a cudnn tensor. - A float representing the dropout probability, which is the probability that any given weight is set to zero. + - (Debug only) Output RNG dump generated by the Philox RNG, passed as a cuDNN tensor. - To use an user-provided dropout mask, users must provide: - - `dropout mask` that matches the attention weights' dimensions, indicating which weights to drop. + - `dropout mask` that matches the attention weights' dimensions, indicating which weights to drop. The dimensions that are passed as 1 will apply a broadcasted dropout mask. - `dropout scale` used to adjust the scale of the remaining weights accordingly, such as $1 / (1 - \text{dropout probability})$. -- Ragged tensor: allows the query, key, value, and output tensor to be [ragged tensors](https://www.tensorflow.org/guide/ragged_tensor), which are tensors with nested variable length lists as inner dimensions. Users must pass another tensor called ragged offset tensor using the `Tensor_attributes.set_ragged_offset()` method as specified in the tensors section below. +- Packed layout: With packed layout, the query, key, value, and output tensor should be [ragged tensors](https://www.tensorflow.org/guide/ragged_tensor), which are tensors with nested variable length lists as inner dimensions. Users must pass another tensor called ragged offset tensor using the `Tensor_attributes.set_ragged_offset()` method. the ragged offset tensor must be a tensor of size $(B + 1, 1, 1, 1)$ that contains the nested tensor's offset in terms of number of elements (not bytes). The last value of the offset tensor specifies the offset of the past-the-end element of the ragged tensor. See Appendix A for more information on the supported layouts. + +##### Input Tensors: + +| Tensor Name | Device | Data Type | Dimensions | +|-------------------------------------|------------|----------------|----------------------------------------------------------------------------------------------------------------| +| Q | GPU | FP16 or BF16 | $(B, H_{q}, S_{q}, D_{qk})$ | +| K | GPU | FP16 or BF16 | $(B, H_{k}, S_{kv}, D_{qk})$ | +| V | GPU | FP16 or BF16 | $(B, H_{v}, S_{kv}, D_{v})$ | +| (Bias mask) Bias Mask | GPU | FP16 or BF16 | $(1, 1, S_{q}, S_{kv})$, $(1, H_{q}, S_{q}, S_{kv})$, $(B, 1, S_{q}, S_{kv})$, or $(B, H_{q}, S_{q}, S_{kv})$ | +| (Padding mask) Sequence Length Q | GPU | INT32 | $(B, 1, 1, 1)$ | +| (Padding mask) Sequence Length KV | GPU | INT32 | $(B, 1, 1, 1)$ | +| (Philoc RNG Dropout) Seed | CPU or GPU | INT32 or INT64 | $(1, 1, 1, 1)$ | +| (Philoc RNG Dropout) Offset | CPU or GPU | INT32 or INT64 | $(1, 1, 1, 1)$ | +| (Custom Dropout Mask) Mask | GPU | FP16 or BF16 | $(1, 1, S_{q}, S_{kv})$, $(1, H_{q}, S_{q}, S_{kv})$, $(B, 1, S_{q}, S_{kv})$, or $(B, H_{q}, S_{q}, S_{kv})$ | +| (Custom Dropout Mask) Scale | GPU | FP32 | $(1, 1, 1, 1)$ | +| (Packed Layout) Ragged Offset | GPU | INT32 | $(B + 1, 1, 1, 1)$ | -#### Tensors: +##### Output Tensors -- Query tensor should have dimensions $(B, H_{q}, S_{q}, D_{qk})$ with input/output datatype. -- Key tensor should have dimensions $(B, H_{k}, S_{kv}, D_{qk})$ with input/output datatype. -- Value tensor should have dimensions $(B, H_{v}, S_{kv}, D_{v})$ with input/output datatype. 
-- Output tensor should have dimensions $(B, H_{q}, S_{q}, D_{v})$ with input/output datatype. -- (Optional) When `is_inference` is false, the stats tensor should have dimensions $(B, H_{q}, S_{q}, 1)$ with float32 datatype. -- (Optional) When bias mask is enabled, the bias tensor has dimensions $(1, 1, S_{q}, S_{kv})$, $(1, H_{q}, S_{q}, S_{kv})$, $(B, 1, S_{q}, S_{kv})$, or $(B, H_{q}, S_{q}, S_{kv})$ with input/output datatype. -The dimensions that are passed as 1 will apply a broadcasted mask over attention scores. -- (Optional) When padding mask is enabled, the sequence length q, and sequence length kv tensors should have shape $(B, 1, 1, 1)$ with int32 datatype. -- (Optional) When philox RNG dropout mask is enabled, the RNG seed and offset tensors should have size $(1, 1, 1, 1)$ with int32 or int64 datatype as either a CPU or GPU tensor. -- (Optional) When a user provided dropout mask is enabled, a dropout mask tensor should have shape $(1, 1, S_{q}, S_{kv})$, $(1, H_{q}, S_{q}, S_{kv})$, $(B, 1, S_{q}, S_{kv})$, or $(B, H_{q}, S_{q}, S_{kv})$ with input/output datatype. -The dimensions that are passed as 1 will apply a broadcasted mask over attention weights. -- (Optional) When query, key, value, and output tensors are ragged tensors, the ragged offset tensor must be a tensor of size $(B + 1, 1, 1, 1)$ that contains the nested tensor's offset in terms of number of elements (not bytes). The last value of the offset tensor specifies the offset of the past-the-end element of the ragged tensor. +| Tensor Name | Device | Data Type | Dimensions | +|-------------------------------------|------------|--------------|------------------------------| +| O | GPU | FP16 or BF16 | $(B, H_{q}, S_{q}, D_{v})$ | +| Stats (training only) | GPU | FP32 | $(B, H_{q}, S_{q}, 1)$ | +| (Philoc RNG Dropout) RNG Dump | GPU | FP32 | $(B, H_{q}, S_{q}, S_{kv})$ | Where, @@ -121,11 +128,18 @@ set_seq_len_kv(std::shared_ptr value); SDPA_attributes& set_causal_mask(bool const value); -SDPA_attributes& +SDPA_attributes & +set_sliding_window_length(int const value); + +SDPA_attributes & set_dropout(float const probability, std::shared_ptr seed, std::shared_ptr offset); +// for debugging dropout mask +SDPA_attributes& +set_rng_dump(std::shared_ptr value); + SDPA_attributes& set_dropout(std::shared_ptr mask, std::shared_ptr scale); @@ -149,7 +163,9 @@ Args: seq_len_q (Optional[cudnn_tensor]): The sequence length of the query. seq_len_kv (Optional[cudnn_tensor]): The sequence length of the key. use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False. + use_causal_mask_bottom_right (Optional[bool]): Whether to use bottom right aligned causal mask. Default is False. dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None. + rng_dump (Optional[cudnn_tensor]): Debug tensor used to output the Philox RNG dropout mask compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. name (Optional[str]): The name of the operation. @@ -158,13 +174,13 @@ Returns: stats (Optional[cudnn_tensor]): The softmax statistics in case the operation is in a training step. 
``` -### Scaled Dot Product Attention Backward +### Scaled Dot Product Attention FP16/BF16 Backward -This operation computes gradient tensors for scaled dot product attention using the FlashAttention-2 algorithm as described in the paper [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691). The user is required to pass the stats tensor from the forward operation to the backward operation as input. +This operation computes gradient tensors for scaled dot product attention (SDPA) using the FlashAttention-2 algorithm as described in the paper [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691). The user is required to pass the stats tensor from the forward operation to the backward operation as input. - Python sample: [samples/python/51_scaled_dot_product_attention_backward.ipynb](https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/python/51_scaled_dot_product_attention_backward.ipynb) -- C++ sample: [samples/cpp/mha.cpp](https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/cpp/mha.cpp) +- C++ sample: [samples/cpp/sdpa](https://github.com/NVIDIA/cudnn-frontend/tree/main/samples/cpp/sdpa) - Python tests: [test/python_fe/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python_fe/test_mhas.py) @@ -176,6 +192,20 @@ All the options mentioned in the forward operation, including ragged tensors and All the tensor requirements described in the forward operation are applicable in the backward operation as well. The gradient tensors for query, key, value, output, and bias should have the same properties as their non-gradient counterparts. +##### Input Tensors: + +| Tensor Name | Device | Data Type | Dimensions | +|-----------------------|------------|----------------|----------------------------| +| dO | GPU | FP16 or BF16 | $(B, H_{q}, S_{q}, D_{v})$ | + +##### Output Tensors + +| Tensor Name | Device | Data Type | Dimensions | +|-----------------------|------------|--------------|------------------------------| +| dQ | GPU | FP16 or BF16 | $(B, H_{q}, S_{q}, D_{qk})$ | +| dK | GPU | FP16 or BF16 | $(B, H_{k}, S_{kv}, D_{qk})$ | +| dV | GPU | FP16 or BF16 | $(B, H_{v}, S_{kv}, D_{v})$ | + #### Limitations: All the limitations mentioned in the forward operation are applicable in the backward operation as well. @@ -223,16 +253,26 @@ set_seq_len_kv(std::shared_ptr value); SDPA_backward_attributes& set_causal_mask(bool const value); +SDPA_backward_attributes & +set_sliding_window_length(int const value); + SDPA_backward_attributes& set_dropout(float const probability, std::shared_ptr seed, std::shared_ptr offset); +// for debugging dropout mask +SDPA_backward_attributes& +set_rng_dump(std::shared_ptr value); + SDPA_backward_attributes& set_dropout(std::shared_ptr mask, std::shared_ptr scale, std::shared_ptr scale_inv); +SDPA_backward_attributes& +set_deterministic_algorithm(bool const value); + SDPA_backward_attributes& set_compute_data_type(DataType_t const value); ``` @@ -255,7 +295,12 @@ Args: seq_len_q (Optional[cudnn_tensor]): The sequence length of the query. seq_len_kv (Optional[cudnn_tensor]): The sequence length of the key. use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False. - dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None. 
+ use_causal_mask_bottom_right (Optional[bool]): Whether to use bottom right aligned causal mask. Default is False. + sliding_window_length (Optional[int]): The length of sliding window. Default is None. + dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], + Tuple[mask: cudnn_tensor, scale: cudnn_tensor, scale_inv: cudnn_tensor]]]): + Whether to do dropout. Default is None. + rng_dump (Optional[cudnn_tensor]): Debug tensor used to output the Philox RNG dropout mask compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. name (Optional[str]): The name of the operation. @@ -265,9 +310,9 @@ Returns: dV (cudnn_tensor): The value gradient data. ``` -### Scaled Dot Product Attention FP8 +### Scaled Dot Product Attention FP8 Forward -This operation computes the scaled dot product attention in the FP8 (8-bit floating point) datatype, using the FlashAttention-2 algorithm as described in the paper [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691). It is applicable for both training and inference phases, with an option to generate a stats tensor to be used for backwards training computation. +This operation computes the scaled dot product attention (SDPA) in the 8-bit floating point (FP8) datatype, using the FlashAttention-2 algorithm as described in the paper [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691). It is applicable for both training and inference phases, with an option to generate a stats tensor to be used for backwards training computation. The FP8 datatype consists of two encodings: - `FP8_E4M3` (1 sign bit, 4 exponent bits, and 3 mantissa bits) @@ -283,6 +328,8 @@ The suggested value for the descale factor is the reciprocal of the scale factor Since scaling and descaling are critical for convergence with FP8 datatype, users are required to pass scaling and descaling input tensors, as well as amax output tensors. +- C++ sample: [samples/cpp/sdpa](https://github.com/NVIDIA/cudnn-frontend/tree/main/samples/cpp/sdpa) + #### Configurable Options The current FP8 support is a subset of the options supported in FP16 and BF16 support. We are actively working on expanding the support for FP8. @@ -401,9 +448,11 @@ Returns: amax_o (cudnn_tensor): The absolute maximum of output tensor. ``` -### Scaled Dot Product Attention Backward FP8 +### Scaled Dot Product Attention FP8 Backward + +This operation computes the gradients for scaled dot product attention (SDPA) 8-bit floating point (FP8) datatype, using the FlashAttention-2 algorithm as described in the paper [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691). The user is required to pass the stats tensor from the forward operation to the backward operation as input. -This operation computes the gradients for scaled dot product attention in the FP8 (8-bit floating point) datatype, using the FlashAttention-2 algorithm as described in the paper [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691). The user is required to pass the stats tensor from the forward operation to the backward operation as input. 
+- C++ sample: [samples/cpp/sdpa](https://github.com/NVIDIA/cudnn-frontend/tree/main/samples/cpp/sdpa) #### Configurable Options: @@ -432,6 +481,7 @@ $dK = QdP$ | V | GPU | E4M3 or E5M2 | $(B, H_{v}, S_{kv}, D_{v})$ | | O | GPU | E4M3 or E5M2 | $(B, H_{q}, S_{q}, D_{v})$ | | dO | GPU | E4M3 or E5M2 | $(B, H_{q}, S_{q}, D_{v})$ | +| Stats | GPU | FP32 | $(B, H_{q}, S_{q}, 1)$ | | Descale Q | GPU | FP32 | $(1, 1, 1, 1)$ | | Descale K | GPU | FP32 | $(1, 1, 1, 1)$ | | Descale V | GPU | FP32 | $(1, 1, 1, 1)$ | @@ -499,7 +549,7 @@ Graph::sdpa_fp8_backward(std::shared_ptr q, The `options` parameter of type `SDPA_fp8_backward_attributes` is used to control the attributes of the forward operation, as detailed below: -``` +```cpp SDPA_fp8_backward_attributes& set_attn_scale(std::shared_ptr value); @@ -546,78 +596,66 @@ Returns: amax_dP (cudnn_tensor): The absolute maximum of dP tensor. ``` -### Appendix A -Tensor Layouts: -Q, K, V, O and corresponding gradients layout support. cuDNN API expresses the layout of tensors based on strides. - -For example, let Q have dimensions = [5, 7, 4, 3], and strides = [84, 12, 3, 1] -An element at index [i, j, k, l] can be accessed at the position of Q_ptr + i * 84 + j * 12 + k * 3 + l * 1 - -Notice how the strides are multiplied to the indices to get the position of all elements. -Below we will go through the standard usage of the attention tensors and how they can be expressed in cuDNN. - - 1. Q, K, V are different matrices with strided layout - This is the basic case where the user can specify dims and strides for each of Q, K and V and it works as the example given above. - The only limitation is that stride corresponding to the hidden dimension per head (d, last dim in Q) needs to be 1. - - 2. Q, K, V are interleaved - This is a special case of (1) and can be described in a strided layout as well. - For example, Q, K and V can be a single matrix of dims (batch (b), number_of_heads (h), sequence_length (s), 3, hidden_dim_per_head(d)) - Strides of Q can be defined as [h * s * 3 * d, s * 3 * d, 3 * d, 1] - Notice how the 3 is multiplied to the strides corresponding to b, h and s because of the interleaving. - - 3. There are some special cases when all tokens are not valid and Q, K, V can be in special layouts - Let Q tensor have two sequences (i.e batch = 2, number_of_heads = 1) with max_seq_len = 8 and actual_seq_len = [2, 3] - Consider two tokens "aa" & "bbb". - - Fully padded layout - - aa000000 - bbb00000 - Dims = [b=2, h=1, s=8, d=64] - Strides = [512, 512, 64, 1] - - CUDNN gets indication of the actual sequence lengths using the seq_len_q and the seq_len_kv and cuts the computation at these values. Please enable use_padding_mask also for this case. CUDNN reads the data based on the strides. - - - Fully packed layout - aabbb000 - 00000000 - Dims = [b=2, h=1, s=8, d=64] - Strides = [512, 512, 64, 1] - - The strides remain the same but they are incorrect as the second batch begins at 64*2. Therefore, we have an API called "ragged_offset" which is a b+1 size tensor telling where each batch begins. The b+1 element is where the last batch ends. - Users can set .set_ragged_offset() - For this example ragged_offset = [0, 128, 320] - Actual sequence length still have to be provided with padding mask. - - - Valid tokens in a batch are packed together - aa00bbb0 - 00000000 - - User just needs to update the ragged offset to = [0, 256, 448] - - - Valid tokens are not packed together - a0abbb00 - bb000000 - - Ragged offset is insufficient to represent this. 
This case is NOT supported. - -### Appendix B -Workspace limit: -Scaled Dot Product Attention Backward improves performance by using an optional dP workspace tensor. This tensor's memory consumption increases quadratically with the sequence length. The following describes the behavior of the `CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT` environment variable, which allows the user to change the GPU memory limit for this workspace tensor: - - `CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT = unset` - The optimization will utilize workspace memory until reaching the default limit of 256MB. - - `CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT = -1` - Workspace optimization is always enabled, regardless of memory usage. - - `CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT = 0` - Workspace optimization is always disabled, avoiding the additional memory usage. - - `CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT = n` - Allows workspace optimization up to a user-defined limit of n bytes, accommodating systems with varying GPU memory capacities. - -### Appendix C -To dump the dropout mask generated by the Philox RNG dropout implementation for debugging purposes, users can use the `rng_dump` option. This option requires users to pass a tensor of dimensions $(B, H_{q}, S_{q}, S_{kv})$ - -### Miscellaneous -- FE provides shadow enums which help avoid users to workaround having different enums for different cudnn versions. -- The cudnn backend enums are changed as follows: - - `cudnnBackend` -> `cudnn_frontend::` - - `cudnn` -> `cudnn_frontend::` \ No newline at end of file +### Supported Tensor Layouts + +cuDNN API expresses the layout of $Q$, $K$, $V$, $O$ tensors corresponding gradients based on strides. + +For example, let $Q$ have dimensions = $[5, 7, 4, 3]$, and strides = $[84, 12, 3, 1]$. An element at index $[i, j, k, l]$ can be accessed at the position of $Q_{ptr} + i * 84 + j * 12 + k * 3 + l * 1$ +. Notice how the strides are multiplied to the indices to get the position of all elements. + +Below we will go through the standard usage of the attention tensors and how they can be expressed in cuDNN.\ +Using the notation below:\ +$B$ is the batch size\ +$H_{q}$ is the number of query heads\ +$H_{k}$ is the number of key heads\ +$H_{v}$ is the number of value heads\ +$S_{q}$ is the sequence length of the query\ +$S_{kv}$ is the sequence length of the key and value\ +$D_{qk}$ is the embedding dimension per head of query and key\ +$D_{v}$ is the embedding dimension per head of value + +- Case 1: $Q$, $K$, $V$, $O$ are tensors in dense non-overlapping memory\ +This is the basic case where the user can specify dims and strides for each of $Q$, $K$, $V$, $O$ in any stride order. The only limitation is that the stride of the last dimension, embedding dimension per head $D_{qk}$ and $D_v$, be 1.\ +For instance for $Q$ with dimensions = $[B, H_q, S_q, D_{qk}]$, cuDNN support includes (but is not limited to): + - stride = $[S_q \times H_q \times D_{qk}, D_{qk}, H_q \times D_{qk}, 1]$ aka. BSHD layout + - stride = $[H_q \times D_{qk}, D_{qk}, B \times H_q \times D_{qk}, 1]$ aka. SBHD layout + +- Case 2: $Q$, $K$, $V$ are are tensors in dense interleaved layout\ +In some cases, users may need to interleave $Q$, $K$, $V$ tensors together to simplify the matrix multiplication preceding the scaled-dot-product operation. 
For instance, users can allocate a single tensor of size = $3 \times B \times H \times S \times D$, specify the $Q$, $K$, $V$ dimensions = $[B, H, S, D]$, and cuDNN support includes (but is not limited to): + - stride = $[S \times 3 \times H \times D, D, 3 \times H \times D, 1]$ aka. BS3HD \ + with $QKV$ variant pack pointers offset as\ + $Q_{ptr}$ = $Storage_{ptr}$\ + $K_{ptr}$ = $Storage_{ptr} + 1 \times H \times D$\ + $V_{ptr}$ = $Storage_{ptr} + 2 \times H \times D$ + - stride = $[H \times 3 \times D, 3 \times D, B \times H \times 3 \times D, 1]$ aka. SBH3D \ + with $QKV$ variant pack pointers offset as\ + $Q_{ptr}$ = $Storage_{ptr}$\ + $K_{ptr}$ = $Storage_{ptr} + 1 \times D$\ + $V_{ptr}$ = $Storage_{ptr} + 2 \times D$ + +- Case 3: $Q$, $K$, $V$ are are tensors where not all tokens are valid\ +Consider Q tensor with two batches ($B$ = 2) of sequences of different lengths ["aa", "bbb"]. Let maximum sequence length $S$ = 8, and number of heads $H = 1$. In this case, users should indicate the actual sequence lengths for each batch using the sequence length tensor `seq_len = [2, 3]`, and pass it to the SDPA node using `set_seq_len_q()` and `set_seq_len_kv()`. Note that every element in the sequence length tensor should always be smaller than the maximum sequence length $S$.\ +\ +cuDNN layout support for variable sequence length includes (but is not limited to): + - Fully padded layout\ + `Q[b=0] = aa000000`\ + `Q[b=1] = bbb00000`\ + dimension = $[B=2, H=1, S=8, D=64]$\ + stride = $[SHD=512, D=64, HD=64, 1]$\ + \ + cuDNN reads the data based on the strides. + + - Fully packed layout aka. THD, where T = sum(seq_len)\ + `Q = aabbb000`\ + dimension = $[B=2, H=1, S=8, D=64]$\ + stride = $[SHD=512, D=64, HD=64, 1]$\ + \ + The strides remain the same but they are incorrect as the second batch begins at 64*2. Therefore, users must set **ragged_offset** tensor using `.set_ragged_offset()` api, which is a $B + 1$ sized integer tensor telling where each batch begins. The b+1 element is where the last batch ends. For this case, ragged_offset should be `[0, 2 * H * D, (2+3) * H * D] = [0, 128, 320]` + + - Valid tokens in a batch are packed together\ + `Q = aa00bbb0`\ + For this case, ragged offset to `[0, 4 * H * D, (4+3) * H * D] = [0, 256, 448]` + + - Valid tokens are not packed together\ + `Q = a0abbb00bb000000`\ + Ragged offset is insufficient to represent this. This case is NOT supported. 
diff --git a/docs/operations/Pointwise.md b/docs/operations/Pointwise.md index 21e7063..c98cbdd 100644 --- a/docs/operations/Pointwise.md +++ b/docs/operations/Pointwise.md @@ -33,6 +33,15 @@ set_mode(PointwiseMode_t) Pointwise_attributes& set_axis(int64_t) +Pointwise_attributes& +set_relu_lower_clip(float) + +Pointwise_attributes& +set_relu_upper_clip(float) + +Pointwise_attributes& +set_relu_lower_clip_slope(float) + Pointwise_attributes& set_name(std::string const&) diff --git a/include/cudnn_frontend.h b/include/cudnn_frontend.h index 1fe14b0..114c8a9 100644 --- a/include/cudnn_frontend.h +++ b/include/cudnn_frontend.h @@ -124,7 +124,7 @@ #include "cudnn_frontend/utils/serialize.h" #define CUDNN_FRONTEND_MAJOR_VERSION 1 -#define CUDNN_FRONTEND_MINOR_VERSION 4 +#define CUDNN_FRONTEND_MINOR_VERSION 5 #define CUDNN_FRONTEND_PATCH_VERSION 0 #define CUDNN_FRONTEND_VERSION \ ((CUDNN_FRONTEND_MAJOR_VERSION * 10000) + (CUDNN_FRONTEND_MINOR_VERSION * 100) + CUDNN_FRONTEND_PATCH_VERSION) diff --git a/include/cudnn_frontend/backend/backend_descriptor.h b/include/cudnn_frontend/backend/backend_descriptor.h index 410eccd..dc7ad25 100644 --- a/include/cudnn_frontend/backend/backend_descriptor.h +++ b/include/cudnn_frontend/backend/backend_descriptor.h @@ -1,3 +1,5 @@ +#pragma once + #include #include "cudnn.h" diff --git a/include/cudnn_frontend/backend/execution_helpers.h b/include/cudnn_frontend/backend/execution_helpers.h index e52abc6..334ffde 100644 --- a/include/cudnn_frontend/backend/execution_helpers.h +++ b/include/cudnn_frontend/backend/execution_helpers.h @@ -1,3 +1,5 @@ +#pragma once + #include #include "cudnn.h" diff --git a/include/cudnn_frontend/backend/plan_helpers.h b/include/cudnn_frontend/backend/plan_helpers.h new file mode 100644 index 0000000..1fa458d --- /dev/null +++ b/include/cudnn_frontend/backend/plan_helpers.h @@ -0,0 +1,60 @@ +#pragma once + +#include + +#include "cudnn.h" + +#include "backend_descriptor.h" + +namespace cudnn_frontend::detail { +/** + * @brief Creates a CUDNN backend variant pack descriptor. + * + * This function creates a `backend_descriptor` object representing a CUDNN backend variant pack + * descriptor. The variant pack descriptor is configured with the provided device pointers, unique + * IDs, and a workspace pointer. + * + * @param[out] variant_pack The created `backend_descriptor` object representing the variant pack. + * @param device_ptrs A vector of device pointers to be associated with the variant pack. + * @param uids A vector of unique IDs to be associated with the variant pack. + * @param workspace_ptr A pointer to the workspace memory to be associated with the variant pack. + * @return `error_t` A tuple containing the error code and an optional error message. + * The error code is `error_code_t::OK` on success, or an appropriate error code on failure. 
+ */ +inline error_t +get_workspace_size(ManagedOpaqueDescriptor& engine_config, int64_t& workspace) { +#if CUDNN_VERSION >= 90200 + CHECK_CUDNN_ERROR(detail::get_attribute(engine_config->get_backend_descriptor(), + CUDNN_ATTR_ENGINECFG_WORKSPACE_SIZE, + CUDNN_TYPE_INT64, + 1, + nullptr, + &workspace)); + return {error_code_t::OK, ""}; +#else + (void)engine_config; + (void)workspace; + return {error_code_t::CUDNN_BACKEND_API_FAILED, + "CUDNN_ATTR_ENGINECFG_WORKSPACE_SIZE is only available starting 9.2."}; +#endif +} + +inline error_t +get_shared_memory_size(ManagedOpaqueDescriptor& engine_config, int32_t& shared_memory_size) { +#if CUDNN_VERSION >= 90200 + CHECK_CUDNN_ERROR(detail::get_attribute(engine_config->get_backend_descriptor(), + CUDNN_ATTR_ENGINECFG_SHARED_MEMORY_USED, + CUDNN_TYPE_INT32, + 1, + nullptr, + &shared_memory_size)); + return {error_code_t::OK, ""}; +#else + (void)engine_config; + (void)shared_memory_size; + return {error_code_t::CUDNN_BACKEND_API_FAILED, + "CUDNN_ATTR_ENGINECFG_SHARED_MEMORY_USED is only available starting 9.2."}; +#endif +} + +} // namespace cudnn_frontend::detail diff --git a/include/cudnn_frontend/context.h b/include/cudnn_frontend/context.h index 641db11..b50c1e1 100644 --- a/include/cudnn_frontend/context.h +++ b/include/cudnn_frontend/context.h @@ -9,6 +9,8 @@ class Context { DataType_t intermediate_data_type = DataType_t::NOT_SET; DataType_t io_data_type = DataType_t::NOT_SET; + std::string name = ""; + public: Context& set_intermediate_data_type(DataType_t const type) { @@ -43,6 +45,17 @@ class Context { return compute_data_type; } + Context& + set_name(std::string const& name_) { + name = name_; + return *this; + } + + std::string + get_name() const { + return name; + } + Context& fill_missing_properties(Context const& global_context) { if (get_compute_data_type() == DataType_t::NOT_SET) { diff --git a/include/cudnn_frontend/cudnn_interface.h b/include/cudnn_frontend/cudnn_interface.h index b54838a..f60dc0a 100644 --- a/include/cudnn_frontend/cudnn_interface.h +++ b/include/cudnn_frontend/cudnn_interface.h @@ -130,17 +130,9 @@ class ICudnn { public: error_t get_cudnn_workspace_size_node(int64_t const plan_index, int64_t& cudnn_workspace_size) const { - int64_t candidate = plan_index != -1 ? 
plan_index : plans.candidate; + CHECK_CUDNN_FRONTEND_ERROR(plans.is_plan_index_executable(plan_index)); - RETURN_CUDNN_FRONTEND_ERROR_IF( - (candidate < 0) && (static_cast(plans.execution_plans.size()) <= candidate), - error_code_t::GRAPH_EXECUTION_FAILED, - "Plan index is invalid."); - - RETURN_CUDNN_FRONTEND_ERROR_IF(!(plans.execution_plans[candidate]), - error_code_t::GRAPH_EXECUTION_FAILED, - "No candidate plan found for graph to query worksapce for."); - cudnn_workspace_size = std::max(cudnn_workspace_size, plans.execution_plans[candidate]->getWorkspaceSize()); + cudnn_workspace_size = std::max(cudnn_workspace_size, plans.execution_plans[plan_index]->getWorkspaceSize()); return {error_code_t::OK, ""}; } @@ -154,7 +146,7 @@ class ICudnn { execute_cudnn_plan_with_uid(cudnnHandle_t handle, std::unordered_map const& tensor_uid_to_pointer_map, void* workspace_ptr, - int64_t plan_index = -1) const { + int64_t plan_index) const { // Make sure device pointer is provided for all uids expected for this plan std::vector device_ptrs; std::vector uids; @@ -163,24 +155,16 @@ class ICudnn { RETURN_CUDNN_FRONTEND_ERROR_IF(search == tensor_uid_to_pointer_map.end(), error_code_t::INVALID_VARIANT_PACK, "Uid " + std::to_string(uid) + " does not exist in variant pack."); - device_ptrs.push_back(tensor_uid_to_pointer_map.at(uid)); + device_ptrs.push_back(search->second); uids.push_back(uid); } - int64_t candidate = plan_index != -1 ? plan_index : plans.candidate; - RETURN_CUDNN_FRONTEND_ERROR_IF( - (candidate < 0) && (static_cast(plans.execution_plans.size()) <= candidate), - error_code_t::GRAPH_EXECUTION_FAILED, - "Plan index is invalid."); - - RETURN_CUDNN_FRONTEND_ERROR_IF(!(plans.execution_plans[candidate]), - error_code_t::GRAPH_EXECUTION_FAILED, - "Plan index does not correspond to a valid plan."); + CHECK_CUDNN_FRONTEND_ERROR(plans.is_plan_index_executable(plan_index)); - getLogger() << "[cudnn_frontend] INFO: Executing plan at index " << candidate << "." << std::endl; + getLogger() << "[cudnn_frontend] INFO: Executing plan at index " << plan_index << "." 
<< std::endl; CHECK_CUDNN_FRONTEND_ERROR( - detail::execute(handle, plans.execution_plans[candidate].get(), device_ptrs, uids, workspace_ptr)); + detail::execute(handle, plans.execution_plans[plan_index].get(), device_ptrs, uids, workspace_ptr)); return {error_code_t::OK, ""}; } diff --git a/include/cudnn_frontend/graph_helpers.h b/include/cudnn_frontend/graph_helpers.h index 823b81d..e773478 100644 --- a/include/cudnn_frontend/graph_helpers.h +++ b/include/cudnn_frontend/graph_helpers.h @@ -24,7 +24,8 @@ enum class [[nodiscard]] error_code_t { CUDA_API_FAILED, CUDNN_BACKEND_API_FAILED, INVALID_CUDA_DEVICE, - HANDLE_ERROR + HANDLE_ERROR, + INVALID_VALUE }; typedef struct [[nodiscard]] error_object { @@ -100,10 +101,8 @@ typedef struct [[nodiscard]] error_object { do { \ if (auto cudnn_retval = x; cudnn_retval != CUDNN_STATUS_SUCCESS) { \ std::stringstream error_msg; \ - char last_error[1024]; \ - detail::get_last_error_string(last_error, sizeof(last_error)); \ - error_msg << #x << " failed with code: " << detail::get_error_string(cudnn_retval) \ - << ", and message: " << last_error; \ + error_msg << #x << " failed with code: " << detail::get_last_error_string_() \ + << ", and message: " << detail::get_error_string(cudnn_retval); \ getLogger() << "[cudnn_frontend] ERROR: " << error_msg.str() << " at " << __FILE__ << ":" << __LINE__ \ << std::endl; \ return {error_code_t::CUDNN_BACKEND_API_FAILED, error_msg.str()}; \ @@ -140,11 +139,16 @@ NLOHMANN_JSON_SERIALIZE_ENUM(error_code_t, {error_code_t::INVALID_CUDA_DEVICE, "INVALID_CUDA_DEVICE"}, {error_code_t::UNSUPPORTED_GRAPH_FORMAT, "UNSUPPORTED_GRAPH_FORMAT"}, {error_code_t::HANDLE_ERROR, "HANDLE_ERROR"}, + {error_code_t::INVALID_VALUE, "INVALID_VALUE"}, }) static inline std::ostream& operator<<(std::ostream& os, const error_code_t& mode) { +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB os << json{mode}; +#else + os << int(mode); +#endif return os; } diff --git a/include/cudnn_frontend/graph_interface.h b/include/cudnn_frontend/graph_interface.h index c797145..800c303 100644 --- a/include/cudnn_frontend/graph_interface.h +++ b/include/cudnn_frontend/graph_interface.h @@ -106,6 +106,12 @@ class Graph : public INode { Graph & set_compute_data_type(DataType_t type); + Graph & + set_name(std::string const &name) { + context.set_name(name); + return *this; + } + std::shared_ptr tensor(Tensor_attributes const &tensor); @@ -261,6 +267,12 @@ class Graph : public INode { return {error_code_t::OK, ""}; } + error_t + build(cudnnHandle_t const &handle, + std::vector const &mode, + BuildPlanPolicy_t const policy = BuildPlanPolicy_t::HEURISTICS_CHOICE, + bool const do_multithreaded_builds = false); + error_t build_plans(cudnnHandle_t const &handle, BuildPlanPolicy_t const policy = BuildPlanPolicy_t::HEURISTICS_CHOICE, @@ -275,6 +287,12 @@ class Graph : public INode { return *this; } + Graph & + deselect_shared_mem_greater_than(int64_t const workspace) { + plans.set_max_shared_mem_allowed(workspace); + return *this; + } + Graph & deselect_engines(std::vector const &engine_names) { plans.set_barred_names(engine_names); @@ -320,59 +338,203 @@ class Graph : public INode { using INode::deserialize; using INode::serialize; +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json &j) const override final { // Different from serialization of other INodes. // Go over each subnode and serialize them. 
- j["nodes"]; + json full_json; + full_json["context"]["name"] = context.get_name(); + full_json["context"]["compute_data_type"] = context.get_compute_data_type(); + full_json["context"]["intermediate_data_type"] = context.get_intermediate_data_type(); + full_json["context"]["io_data_type"] = context.get_io_data_type(); + + full_json.update(R"( {"tag": "GRAPH"})"_json); + full_json["nodes"]; for (auto const &sub_node : sub_nodes) { json j_sub_node; sub_node->serialize(j_sub_node); - j["nodes"].push_back(j_sub_node); + full_json["nodes"].push_back(j_sub_node); + } + + j["context"] = full_json["context"]; + j["nodes"]; + j["tensors"]; + std::unordered_set tensors; + + for (const auto &sub_node : full_json["nodes"]) { + // Create a short version of the node + auto short_node = sub_node; + short_node["inputs"] = {}; + short_node["outputs"] = {}; + + // Process node inputs + for (const auto &input : sub_node["inputs"]) { + // Extract port_name and tensor_name + auto port_name = input[0].get(); + auto tensor_info = input[1]; + + if (tensor_info.is_null()) { + continue; + } + + std::string tensor_name = tensor_info["name"].get(); + + // Update short_node inputs + short_node["inputs"][port_name] = tensor_name; + + // Check if the tensor is already in the tensors map + if (tensors.find(tensor_name) == tensors.end()) { + // If not, add it to the j["tensors"] + j["tensors"][tensor_name] = tensor_info; + } + } + + // Process node outputs + for (const auto &output : sub_node["outputs"]) { + // Extract port_name and tensor_name + auto port_name = output[0].get(); + auto tensor_info = output[1]; + + if (tensor_info.is_null()) { + continue; + } + + std::string tensor_name = tensor_info["name"].get(); + + // Update short_node outputs + short_node["outputs"][port_name] = tensor_name; + + // Check if the tensor is already in the tensors map + if (tensors.find(tensor_name) == tensors.end()) { + // If not, add it to the j["tensors"] + j["tensors"][tensor_name] = tensor_info; + } + } + + // Add the short_node to j["nodes"] + j["nodes"].push_back(short_node); } }; +#endif // TODO: temparorily placed in graphs class. This function needs to be a free standing function. 
+#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB error_t deserialize(const json &j) { + if (j.contains("context")) { + const auto &j_context = j["context"]; + if (j_context["compute_data_type"].is_null() == false) { + context.set_compute_data_type(j_context["compute_data_type"].get()); + } + if (j_context["intermediate_data_type"].is_null() == false) { + context.set_intermediate_data_type(j_context["intermediate_data_type"].get()); + } + if (j_context["io_data_type"].is_null() == false) { + context.set_io_data_type(j_context["io_data_type"].get()); + } + if (j_context["name"].is_null() == false) { + context.set_name(j_context["name"].get()); + } + } + + std::map> created_tensors; + // Iterate through each sub-node in the full JSON if (j.contains("nodes") && j["nodes"].is_array()) { - for (const auto &j_sub_node : j["nodes"]) { + for (auto j_sub_node : j["nodes"]) { + // Create a JSON object for inputs + json inputs; + + // Iterate through each input of the sub-node + for (auto &[port_name, tensor_name] : j_sub_node["inputs"].items()) { + // Add the input to the inputs JSON object + inputs.push_back({port_name, j["tensors"][tensor_name]}); + } + + // Create a JSON object for outputs + json outputs; + + // Iterate through each output of the sub-node + for (auto &[port_name, tensor_name] : j_sub_node["outputs"].items()) { + // Add the output to the outputs JSON object + outputs.push_back({port_name, j["tensors"][tensor_name]}); + } + + // Replace the original inputs and outputs of the sub-node with the new JSON objects + j_sub_node["inputs"] = inputs; + j_sub_node["outputs"] = outputs; + + auto check_if_pre_created_tensor = [&created_tensors](std::shared_ptr t) { + if (t == nullptr) { + return t; + } + + if (created_tensors.find(t->get_name()) == created_tensors.end()) { + created_tensors.insert({t->get_name(), t}); + return t; + } else { + return created_tensors[t->get_name()]; + } + }; + +#define CHECK_TENSORS(attributes) \ + for (const auto &[key, tensor] : attributes.inputs) { \ + attributes.inputs[key] = check_if_pre_created_tensor(tensor); \ + } \ + for (const auto &[key, tensor] : attributes.outputs) { \ + attributes.outputs[key] = check_if_pre_created_tensor(tensor); \ + } + if (j_sub_node.contains("tag") && j_sub_node["tag"].is_string()) { auto tag = j_sub_node["tag"].get(); if (tag == "CONV_FPROP") { auto conv_fprop_attributes = j_sub_node.get(); + CHECK_TENSORS(conv_fprop_attributes); sub_nodes.emplace_back( - std::make_unique(std::move(conv_fprop_attributes), detail::Context())); + std::make_unique(std::move(conv_fprop_attributes), context)); } else if (tag == "POINTWISE") { auto pointwise_attributes = j_sub_node.get(); + CHECK_TENSORS(pointwise_attributes); sub_nodes.emplace_back( - std::make_unique(std::move(pointwise_attributes), detail::Context())); + std::make_unique(std::move(pointwise_attributes), context)); } else if (tag == "REDUCTION") { auto reduction_attributes = j_sub_node.get(); + CHECK_TENSORS(reduction_attributes); sub_nodes.emplace_back( - std::make_unique(std::move(reduction_attributes), detail::Context())); + std::make_unique(std::move(reduction_attributes), context)); } else if (tag == "SDPA_FWD") { auto sdpa_attributes = j_sub_node.get(); - sub_nodes.emplace_back( - std::make_unique(std::move(sdpa_attributes), detail::Context())); + CHECK_TENSORS(sdpa_attributes); + sub_nodes.emplace_back(std::make_unique(std::move(sdpa_attributes), context)); } else if (tag == "SDPA_BWD") { auto sdpa_bwd_attributes = j_sub_node.get(); + CHECK_TENSORS(sdpa_bwd_attributes); 
sub_nodes.emplace_back( - std::make_unique(std::move(sdpa_bwd_attributes), detail::Context())); + std::make_unique(std::move(sdpa_bwd_attributes), context)); + } else if (tag == "MATMUL") { + auto matmul_attributes = j_sub_node.get(); + CHECK_TENSORS(matmul_attributes); + sub_nodes.emplace_back(std::make_unique(std::move(matmul_attributes), context)); } } +#undef CHECK_TENSORS } } return {error_code_t::OK, ""}; } +#endif std::string print(void) const { +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB std::stringstream ss; json j = *this; - ss << j.dump(4); + ss << j; return ss.str(); +#else + return "print is unavailable when compiled with CUDNN_FRONTEND_SKIP_JSON_LIB"; +#endif } }; @@ -409,6 +571,19 @@ Graph::build_plans(cudnnHandle_t const &handle, BuildPlanPolicy_t const policy, return {error_code_t::OK, ""}; } +inline error_t +Graph::build(cudnnHandle_t const &handle, + std::vector const &modes, + BuildPlanPolicy_t const policy, + bool const do_multithreaded_builds) { + CHECK_CUDNN_FRONTEND_ERROR(this->validate()); + CHECK_CUDNN_FRONTEND_ERROR(this->build_operation_graph(handle)); + CHECK_CUDNN_FRONTEND_ERROR(this->create_execution_plans(modes)); + CHECK_CUDNN_FRONTEND_ERROR(this->check_support(handle)); + CHECK_CUDNN_FRONTEND_ERROR(this->build_plans(handle, policy, do_multithreaded_builds)); + return {error_code_t::OK, ""}; +} + inline Graph & Graph::set_intermediate_data_type(DataType_t const type) { context.set_intermediate_data_type(type); diff --git a/include/cudnn_frontend/graph_properties.h b/include/cudnn_frontend/graph_properties.h index 279dac8..eaa9d07 100644 --- a/include/cudnn_frontend/graph_properties.h +++ b/include/cudnn_frontend/graph_properties.h @@ -55,7 +55,7 @@ class Tensor_attributes { stride.empty(), error_code_t::ATTRIBUTE_NOT_SET, "Tensor '" + name + "' strides not set."); RETURN_CUDNN_FRONTEND_ERROR_IF(dim.size() != stride.size(), error_code_t::ATTRIBUTE_NOT_SET, - "Tensor '" + name + "' does not equal dimensinoality in dim and stride."); + "Tensor '" + name + "' does not equal dimensionality in dim and stride."); RETURN_CUDNN_FRONTEND_ERROR_IF( is_virtual && is_pass_by_value, error_code_t::ATTRIBUTE_NOT_SET, @@ -83,10 +83,12 @@ class Tensor_attributes { public: // Serialization functions +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB friend void to_json(nlohmann::json& j, const Tensor_attributes& ta); friend void from_json(const nlohmann::json& j, Tensor_attributes& ta); +#endif Tensor_attributes() = default; @@ -521,11 +523,11 @@ class BN_finalize_attributes : public Attributes { PREV_RUNNING_VAR, MOMENTUM }; - std::map> inputs; + std::unordered_map> inputs; enum class output_names { EQ_SCALE, EQ_BIAS, MEAN, INV_VARIANCE, NEXT_RUNNING_MEAN, NEXT_RUNNING_VAR }; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(BN_finalize_attributes, name, inputs, outputs) - std::map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(BN_finalize_attributes, name, compute_data_type, inputs, outputs) + std::unordered_map> outputs; BN_finalize_attributes& set_previous_running_stats(std::shared_ptr& mean, @@ -545,11 +547,11 @@ class Genstats_attributes : public Attributes { public: enum class input_names { X }; - std::map> inputs; + std::unordered_map> inputs; enum class output_names { SUM, SQ_SUM }; - std::map> outputs; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(Genstats_attributes, name, inputs, outputs) + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Genstats_attributes, name, compute_data_type, inputs, outputs) }; class Conv_fprop_attributes : public Attributes { @@ -564,11 +566,12 @@ class Conv_fprop_attributes : 
public Attributes { public: enum class input_names { X, W }; - std::map> inputs; + std::unordered_map> inputs; enum class output_names { Y }; - std::map> outputs; + std::unordered_map> outputs; NLOHMANN_DEFINE_TYPE_INTRUSIVE(Conv_fprop_attributes, name, + compute_data_type, inputs, outputs, pre_padding, @@ -635,12 +638,12 @@ class Batchnorm_backward_attributes : public Attributes> inputs; + std::unordered_map> inputs; // Only special case where one of the inputs is a vector. std::vector> peer_stats; enum class output_names { DX, DSCALE, DBIAS }; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(Batchnorm_backward_attributes, name, inputs, peer_stats, outputs) - std::map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Batchnorm_backward_attributes, name, compute_data_type, inputs, peer_stats, outputs) + std::unordered_map> outputs; Batchnorm_backward_attributes& set_saved_mean_and_inv_variance(std::shared_ptr mean, @@ -664,10 +667,10 @@ class DBN_weight_attributes : public Attributes { public: enum class input_names { DY, X, SCALE, MEAN, INV_VARIANCE }; - std::map> inputs; + std::unordered_map> inputs; enum class output_names { DSCALE, DBIAS, EQ_BIAS, EQ_SCALE_DY, EQ_SCALE_X }; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(DBN_weight_attributes, name, inputs, outputs) - std::map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(DBN_weight_attributes, name, compute_data_type, inputs, outputs) + std::unordered_map> outputs; }; class Conv_dgrad_attributes : public Attributes { @@ -682,11 +685,12 @@ class Conv_dgrad_attributes : public Attributes { public: enum class input_names { DY, W }; - std::map> inputs; + std::unordered_map> inputs; enum class output_names { DX }; - std::map> outputs; + std::unordered_map> outputs; NLOHMANN_DEFINE_TYPE_INTRUSIVE(Conv_dgrad_attributes, name, + compute_data_type, inputs, outputs, pre_padding, @@ -755,10 +759,10 @@ class Matmul_attributes : public Attributes { public: enum class input_names { A, B, M_override, N_override, K_override }; - std::map> inputs; + std::unordered_map> inputs; enum class output_names { C }; - std::map> outputs; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(Matmul_attributes, name, inputs, outputs) + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Matmul_attributes, name, compute_data_type, inputs, outputs, padding_value) Matmul_attributes& set_m_override(std::shared_ptr const& value) { @@ -794,10 +798,10 @@ class Matmul_fp8_attributes : public Attributes { public: enum class input_names { Descale_A, Descale_B, A, B, Scale_C }; - std::map> inputs; + std::unordered_map> inputs; enum class output_names { C, Amax_C }; - std::map> outputs; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(Matmul_fp8_attributes, name, inputs, outputs) + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Matmul_fp8_attributes, name, compute_data_type, inputs, outputs) Matmul_fp8_attributes& set_padding(double const padding_val) { @@ -813,16 +817,28 @@ class Pointwise_attributes : public Attributes { friend class INode; PointwiseMode_t mode = PointwiseMode_t::NOT_SET; + std::optional axis; + std::optional relu_lower_clip; + std::optional relu_upper_clip; std::optional relu_lower_clip_slope; public: enum class input_names { IN_0, IN_1, IN_2 }; - std::map> inputs; + std::unordered_map> inputs; enum class output_names { OUT_0 }; - std::map> outputs; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(Pointwise_attributes, name, inputs, outputs, mode, axis) + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Pointwise_attributes, + name, + compute_data_type, + inputs, + outputs, + mode, + axis, + relu_lower_clip, + 
relu_upper_clip, + relu_lower_clip_slope) Pointwise_attributes& set_mode(PointwiseMode_t const value) { @@ -846,6 +862,18 @@ class Pointwise_attributes : public Attributes { this->relu_lower_clip_slope = negative_slope; return *this; } + + Pointwise_attributes& + set_relu_lower_clip(float const value) { + this->relu_lower_clip = value; + return *this; + } + + Pointwise_attributes& + set_relu_upper_clip(float const value) { + this->relu_upper_clip = value; + return *this; + } }; class Instancenorm_backward_attributes : public Attributes { @@ -855,10 +883,10 @@ class Instancenorm_backward_attributes : public Attributes> inputs; + std::unordered_map> inputs; enum class output_names { DX, DSCALE, DBIAS }; - std::map> outputs; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(Instancenorm_backward_attributes, name, inputs, outputs) + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Instancenorm_backward_attributes, name, compute_data_type, inputs, outputs) Instancenorm_backward_attributes& set_saved_mean_and_inv_variance(std::shared_ptr mean, @@ -876,10 +904,10 @@ class Layernorm_backward_attributes : public Attributes> inputs; + std::unordered_map> inputs; enum class output_names { DX, DSCALE, DBIAS }; - std::map> outputs; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(Layernorm_backward_attributes, name, inputs, outputs) + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Layernorm_backward_attributes, name, compute_data_type, inputs, outputs) Layernorm_backward_attributes& set_saved_mean_and_inv_variance(std::shared_ptr mean, @@ -899,10 +927,10 @@ class Layernorm_attributes : public Attributes { public: enum class input_names { X, SCALE, BIAS, EPSILON }; - std::map> inputs; + std::unordered_map> inputs; enum class output_names { Y, MEAN, INV_VARIANCE }; - std::map> outputs; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(Layernorm_attributes, name, inputs, outputs, forward_phase) + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Layernorm_attributes, name, compute_data_type, inputs, outputs, forward_phase) Layernorm_attributes& set_forward_phase(NormFwdPhase_t const value) { @@ -926,10 +954,10 @@ class Instancenorm_attributes : public Attributes { public: enum class input_names { X, SCALE, BIAS, EPSILON }; - std::map> inputs; + std::unordered_map> inputs; enum class output_names { Y, MEAN, INV_VARIANCE }; - std::map> outputs; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(Instancenorm_attributes, name, inputs, outputs, forward_phase) + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Instancenorm_attributes, name, compute_data_type, inputs, outputs, forward_phase) Instancenorm_attributes& set_forward_phase(NormFwdPhase_t const value) { @@ -951,12 +979,12 @@ class Batchnorm_attributes : public Attributes { public: enum class input_names { X, SCALE, BIAS, PREV_RUNNING_MEAN, PREV_RUNNING_VAR, EPSILON, MOMENTUM }; - std::map> inputs; + std::unordered_map> inputs; // Only special case where one of the inputs is a vector. 
std::vector> peer_stats; enum class output_names { Y, MEAN, INV_VARIANCE, NEXT_RUNNING_MEAN, NEXT_RUNNING_VAR }; - std::map> outputs; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(Batchnorm_attributes, name, inputs, peer_stats, outputs) + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Batchnorm_attributes, name, compute_data_type, inputs, peer_stats, outputs) Batchnorm_attributes& set_previous_running_stats(std::shared_ptr& mean, @@ -988,10 +1016,10 @@ class Batchnorm_inference_attributes : public Attributes> inputs; + std::unordered_map> inputs; enum class output_names { Y }; - std::map> outputs; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(Batchnorm_inference_attributes, name, inputs, outputs) + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Batchnorm_inference_attributes, name, compute_data_type, inputs, outputs) }; class Reduction_attributes : public Attributes { @@ -1003,10 +1031,10 @@ class Reduction_attributes : public Attributes { public: enum class input_names { X }; - std::map> inputs; + std::unordered_map> inputs; enum class output_names { Y }; - std::map> outputs; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(Reduction_attributes, name, inputs, outputs, mode) + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Reduction_attributes, name, compute_data_type, inputs, outputs, mode) std::optional get_mode() const { @@ -1033,9 +1061,9 @@ class Rng_attributes : public Attributes { public: enum class input_names { Seed, Offset }; - std::map> inputs; + std::unordered_map> inputs; enum class output_names { Y }; - std::map> outputs; + std::unordered_map> outputs; NLOHMANN_DEFINE_TYPE_INTRUSIVE(Rng_attributes, name, inputs, @@ -1226,10 +1254,10 @@ class Reshape_attributes : public Attributes { public: enum class input_names { X }; - std::map> inputs; + std::unordered_map> inputs; enum class output_names { Y }; - std::map> outputs; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(Reshape_attributes, name, inputs, outputs, dim, stride) + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Reshape_attributes, name, compute_data_type, inputs, outputs, dim, stride) std::vector get_dim() const { @@ -1263,10 +1291,10 @@ class Rmsnorm_attributes : public Attributes { public: enum class input_names { X, SCALE, BIAS, EPSILON }; - std::map> inputs; + std::unordered_map> inputs; enum class output_names { Y, INV_VARIANCE }; - std::map> outputs; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(Rmsnorm_attributes, name, inputs, outputs, forward_phase) + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Rmsnorm_attributes, name, compute_data_type, inputs, outputs, forward_phase) Rmsnorm_attributes& set_forward_phase(NormFwdPhase_t const value) { @@ -1296,10 +1324,10 @@ class Rmsnorm_backward_attributes : public Attributes> inputs; + std::unordered_map> inputs; enum class output_names { DX, DSCALE, DBIAS }; - std::map> outputs; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(Rmsnorm_backward_attributes, name, inputs, outputs) + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Rmsnorm_backward_attributes, name, compute_data_type, inputs, outputs) Rmsnorm_backward_attributes& has_dbias(bool value) { @@ -1431,9 +1459,11 @@ class SDPA_attributes : public Attributes { friend class Graph; std::optional is_inference; - bool alibi_mask = false; - bool padding_mask = false; - bool causal_mask = false; + bool alibi_mask = false; + bool padding_mask = false; + bool causal_mask = false; + bool causal_mask_bottom_right = false; + std::optional sliding_window_length; std::optional dropout_probability; std::optional 
attn_scale_value; @@ -1449,11 +1479,11 @@ class SDPA_attributes : public Attributes { Seed, Offset, Dropout_mask, - Dropout_scale + Dropout_scale, }; - std::map> inputs; + std::unordered_map> inputs; enum class output_names { O, Stats, RNG_DUMP }; - std::map> outputs; + std::unordered_map> outputs; NLOHMANN_DEFINE_TYPE_INTRUSIVE(SDPA_attributes, name, inputs, @@ -1462,8 +1492,10 @@ class SDPA_attributes : public Attributes { alibi_mask, padding_mask, causal_mask, + causal_mask_bottom_right, dropout_probability, - attn_scale_value) + attn_scale_value, + sliding_window_length) SDPA_attributes& set_is_inference(bool const value) { @@ -1519,6 +1551,18 @@ class SDPA_attributes : public Attributes { return *this; } + SDPA_attributes& + set_causal_mask_bottom_right(bool const value) { + causal_mask_bottom_right = value; + return *this; + } + + SDPA_attributes& + set_sliding_window_length(int const value) { + sliding_window_length = value; + return *this; + } + SDPA_attributes& set_dropout(float const probability, std::shared_ptr seed, @@ -1566,10 +1610,10 @@ class SDPA_fp8_attributes : public Attributes { Scale_S, Scale_O, }; - std::map> inputs; + std::unordered_map> inputs; enum class output_names { O, Stats, Amax_S, Amax_O }; - std::map> outputs; + std::unordered_map> outputs; NLOHMANN_DEFINE_TYPE_INTRUSIVE(SDPA_fp8_attributes, name, @@ -1609,13 +1653,17 @@ class SDPA_backward_attributes : public Attributes { friend class SDPABackwardNode; friend class Graph; - bool alibi_mask = false; - bool padding_mask = false; - bool causal_mask = false; + bool alibi_mask = false; + bool padding_mask = false; + bool causal_mask = false; + bool causal_mask_bottom_right = false; + std::optional sliding_window_length; std::optional dropout_probability; std::optional attn_scale_value; + bool is_deterministic_algorithm = false; + public: enum class input_names { Q, @@ -1632,11 +1680,11 @@ class SDPA_backward_attributes : public Attributes { Offset, Dropout_mask, Dropout_scale, - Dropout_scale_inv + Dropout_scale_inv, }; - std::map> inputs; + std::unordered_map> inputs; enum class output_names { dQ, dK, dV, dBias, RNG_DUMP }; - std::map> outputs; + std::unordered_map> outputs; NLOHMANN_DEFINE_TYPE_INTRUSIVE(SDPA_backward_attributes, name, inputs, @@ -1644,8 +1692,11 @@ class SDPA_backward_attributes : public Attributes { alibi_mask, padding_mask, causal_mask, + causal_mask_bottom_right, dropout_probability, - attn_scale_value) + attn_scale_value, + sliding_window_length, + is_deterministic_algorithm) SDPA_backward_attributes& set_attn_scale(std::shared_ptr value) { @@ -1701,6 +1752,18 @@ class SDPA_backward_attributes : public Attributes { return *this; } + SDPA_backward_attributes& + set_causal_mask_bottom_right(bool const value) { + causal_mask_bottom_right = value; + return *this; + } + + SDPA_backward_attributes& + set_sliding_window_length(int const value) { + sliding_window_length = value; + return *this; + } + SDPA_backward_attributes& set_dropout(float const probability, std::shared_ptr seed, @@ -1727,6 +1790,12 @@ class SDPA_backward_attributes : public Attributes { outputs[SDPA_backward_attributes::output_names::RNG_DUMP] = value; return *this; } + + SDPA_backward_attributes& + set_deterministic_algorithm(bool const value) { + is_deterministic_algorithm = value; + return *this; + } }; class SDPA_fp8_backward_attributes : public Attributes { @@ -1759,12 +1828,18 @@ class SDPA_fp8_backward_attributes : public Attributes> inputs; + std::unordered_map> inputs; enum class output_names { dQ, dK, dV, Amax_dQ, 
Amax_dK, Amax_dV, Amax_dP }; - std::map> outputs; + std::unordered_map> outputs; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(SDPA_fp8_backward_attributes, name, inputs, outputs, causal_mask, attn_scale_value) + NLOHMANN_DEFINE_TYPE_INTRUSIVE(SDPA_fp8_backward_attributes, + name, + compute_data_type, + inputs, + outputs, + causal_mask, + attn_scale_value) SDPA_fp8_backward_attributes& set_attn_scale(std::shared_ptr value) { @@ -1798,10 +1873,10 @@ class Softmax_attributes : public Attributes { public: enum class input_names { P }; - std::map> inputs; + std::unordered_map> inputs; enum class output_names { S, Stats, M, Zinv }; - std::map> outputs; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(Softmax_attributes, name, inputs, outputs, use_stats, use_M_Zinv) + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Softmax_attributes, name, compute_data_type, inputs, outputs, use_stats, use_M_Zinv) Softmax_attributes& has_stats(bool const value) { @@ -1828,12 +1903,13 @@ class Conv_wgrad_attributes : public Attributes { public: enum class input_names { DY, X }; - std::map> inputs; + std::unordered_map> inputs; enum class output_names { DW }; - std::map> outputs; + std::unordered_map> outputs; NLOHMANN_DEFINE_TYPE_INTRUSIVE(Conv_wgrad_attributes, name, + compute_data_type, inputs, outputs, pre_padding, diff --git a/include/cudnn_frontend/node/batchnorm.h b/include/cudnn_frontend/node/batchnorm.h index 96a75f4..f0cdcec 100644 --- a/include/cudnn_frontend/node/batchnorm.h +++ b/include/cudnn_frontend/node/batchnorm.h @@ -188,11 +188,13 @@ class BatchNormNode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { j = attributes; j.update(R"( {"tag": "BATCHNORM"})"_json); } +#endif }; } // namespace graph diff --git a/include/cudnn_frontend/node/batchnorm_inference.h b/include/cudnn_frontend/node/batchnorm_inference.h index 0d99cd2..06eb63a 100644 --- a/include/cudnn_frontend/node/batchnorm_inference.h +++ b/include/cudnn_frontend/node/batchnorm_inference.h @@ -129,11 +129,13 @@ class BatchnormInferenceNode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { j = attributes; j.update(R"( {"tag": "BATCHNORM_INFERENCE"})"_json); } +#endif }; } // namespace graph diff --git a/include/cudnn_frontend/node/bn_finalize.h b/include/cudnn_frontend/node/bn_finalize.h index 6593569..c1a89d4 100644 --- a/include/cudnn_frontend/node/bn_finalize.h +++ b/include/cudnn_frontend/node/bn_finalize.h @@ -156,11 +156,13 @@ class BatchNormFinalizeNode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { j = attributes; j.update(R"( {"tag": "BN_FINALIZE"})"_json); } +#endif }; } // namespace graph diff --git a/include/cudnn_frontend/node/conv_dgrad.h b/include/cudnn_frontend/node/conv_dgrad.h index e6d2c8f..5efedf5 100644 --- a/include/cudnn_frontend/node/conv_dgrad.h +++ b/include/cudnn_frontend/node/conv_dgrad.h @@ -141,11 +141,13 @@ class DgradNode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { j = attributes; j.update(R"( {"tag": "CONV_DGRAD"})"_json); } +#endif }; } // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/include/cudnn_frontend/node/conv_fprop.h b/include/cudnn_frontend/node/conv_fprop.h index f1584e9..2b9c701 100644 --- 
a/include/cudnn_frontend/node/conv_fprop.h +++ b/include/cudnn_frontend/node/conv_fprop.h @@ -158,11 +158,13 @@ class ConvolutionNode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { j = attributes; j.update(R"({"tag": "CONV_FPROP"})"_json); } +#endif }; } // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/include/cudnn_frontend/node/conv_wgrad.h b/include/cudnn_frontend/node/conv_wgrad.h index 93d04a7..f9e8879 100644 --- a/include/cudnn_frontend/node/conv_wgrad.h +++ b/include/cudnn_frontend/node/conv_wgrad.h @@ -141,11 +141,13 @@ class WgradNode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { j = attributes; j.update(R"( {"tag": "CONV_WGRAD"})"_json); } +#endif }; } // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/include/cudnn_frontend/node/dbn.h b/include/cudnn_frontend/node/dbn.h index 16be121..3fb589f 100644 --- a/include/cudnn_frontend/node/dbn.h +++ b/include/cudnn_frontend/node/dbn.h @@ -154,11 +154,13 @@ class DBNNode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { j = attributes; j.update(R"( {"tag": "DBN"})"_json); } +#endif }; } // namespace graph diff --git a/include/cudnn_frontend/node/dbn_weight.h b/include/cudnn_frontend/node/dbn_weight.h index a5462a4..9dba2c2 100644 --- a/include/cudnn_frontend/node/dbn_weight.h +++ b/include/cudnn_frontend/node/dbn_weight.h @@ -153,11 +153,13 @@ class DBNWeightNode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { j = attributes; j.update(R"( {"tag": "DBN_WEIGHT"})"_json); } +#endif }; } // namespace graph diff --git a/include/cudnn_frontend/node/dln.h b/include/cudnn_frontend/node/dln.h index c78e3a8..029073f 100644 --- a/include/cudnn_frontend/node/dln.h +++ b/include/cudnn_frontend/node/dln.h @@ -176,11 +176,13 @@ class DLNNode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { j = attributes; j.update(R"( {"tag": "LAYER_NORM_BPROP"})"_json); } +#endif }; } // namespace graph diff --git a/include/cudnn_frontend/node/genstats.h b/include/cudnn_frontend/node/genstats.h index d7e46bf..579b6c9 100644 --- a/include/cudnn_frontend/node/genstats.h +++ b/include/cudnn_frontend/node/genstats.h @@ -125,11 +125,13 @@ class GenstatsNode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { j = attributes; j.update(R"( {"tag": "GENSTATS"})"_json); } +#endif }; } // namespace graph diff --git a/include/cudnn_frontend/node/instancenorm.h b/include/cudnn_frontend/node/instancenorm.h index 6c3cfbb..ac9df2c 100644 --- a/include/cudnn_frontend/node/instancenorm.h +++ b/include/cudnn_frontend/node/instancenorm.h @@ -162,11 +162,13 @@ class InstanceNormNode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { j = attributes; j.update(R"( {"tag": "INSTANCE_NORM"})"_json); } +#endif }; class DINNode : public NodeCRTP { @@ -331,11 +333,13 @@ class DINNode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB 
virtual void serialize(json& j) const override final { j = attributes; j.update(R"( {"tag": "INSTANCE_NORM_BPROP"})"_json); } +#endif }; } // namespace graph diff --git a/include/cudnn_frontend/node/layernorm.h b/include/cudnn_frontend/node/layernorm.h index ffdc6e3..f3635d6 100644 --- a/include/cudnn_frontend/node/layernorm.h +++ b/include/cudnn_frontend/node/layernorm.h @@ -204,11 +204,13 @@ class LayerNormNode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { j = attributes; j.update(R"( {"tag": "LAYER_NORM"})"_json); } +#endif }; } // namespace graph diff --git a/include/cudnn_frontend/node/matmul.h b/include/cudnn_frontend/node/matmul.h index 90d3886..f3efbcc 100644 --- a/include/cudnn_frontend/node/matmul.h +++ b/include/cudnn_frontend/node/matmul.h @@ -63,6 +63,16 @@ class MatmulNode : public NodeCRTP { c_tensor_dim[1] = a_tensor_dim[1]; // M c_tensor_dim[2] = b_tensor_dim[2]; // N } + + int64_t gemm_start_dim = a_tensor_dim.size() - 2; + c_tensor_dim[gemm_start_dim] = a_tensor_dim[gemm_start_dim]; // M + c_tensor_dim[gemm_start_dim + 1] = b_tensor_dim[gemm_start_dim + 1]; // N + + // Broadcast the batches + for (int64_t i = 0; i < gemm_start_dim; ++i) { + c_tensor_dim[i] = std::max(a_tensor_dim[i], b_tensor_dim[i]); + } + c_tensor->set_dim(c_tensor_dim); } if (c_tensor->get_stride().empty()) { @@ -151,11 +161,13 @@ class MatmulNode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { j = attributes; j.update(R"( {"tag": "MATMUL"})"_json); } +#endif }; inline void @@ -173,8 +185,42 @@ inline std::shared_ptr INode::matmul(std::shared_ptr a, std::shared_ptr b, Matmul_attributes attributes) { + if (attributes.name.empty()) { + attributes.name += std::to_string(sub_nodes.size()); + } attributes.inputs[Matmul_attributes::input_names::A] = a; attributes.inputs[Matmul_attributes::input_names::B] = b; + + if (a->get_name().empty()) { + a->set_name(attributes.name + "::A"); + }; + if (b->get_name().empty()) { + b->set_name(attributes.name + "::B"); + }; + + auto m_override = attributes.inputs.find(Matmul_attributes::input_names::M_override); + auto n_override = attributes.inputs.find(Matmul_attributes::input_names::N_override); + auto k_override = attributes.inputs.find(Matmul_attributes::input_names::K_override); + + if (m_override != attributes.inputs.end()) { + auto tensor = m_override->second; + if (tensor && tensor->get_name().empty()) { + tensor->set_name(attributes.name + "::M_override"); + } + } + if (n_override != attributes.inputs.end()) { + auto tensor = n_override->second; + if (tensor && tensor->get_name().empty()) { + tensor->set_name(attributes.name + "::N_override"); + } + } + if (k_override != attributes.inputs.end()) { + auto tensor = k_override->second; + if (tensor && tensor->get_name().empty()) { + tensor->set_name(attributes.name + "::K_override"); + } + } + auto C = attributes.outputs[Matmul_attributes::output_names::C] = output_tensor(attributes.name + "::C"); sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); diff --git a/include/cudnn_frontend/node/matmul_fp8.h b/include/cudnn_frontend/node/matmul_fp8.h index e53f6d5..c745cb7 100644 --- a/include/cudnn_frontend/node/matmul_fp8.h +++ b/include/cudnn_frontend/node/matmul_fp8.h @@ -104,11 +104,13 @@ class MatmulFP8Node : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB 
virtual void serialize(json& j) const override final { j = attributes; j.update(R"( {"tag": "MATMUL_FP8"})"_json); } +#endif }; inline void INode::matmul_fp8(std::shared_ptr a, diff --git a/include/cudnn_frontend/node/pointwise.h b/include/cudnn_frontend/node/pointwise.h index 9ae2879..976b0de 100644 --- a/include/cudnn_frontend/node/pointwise.h +++ b/include/cudnn_frontend/node/pointwise.h @@ -88,12 +88,27 @@ class PointwiseNode : public NodeCRTP { getLogger() << "[cudnn_frontend] INFO: " << "Building PointwiseNode operations " << attributes.name << "..." << std::endl; - auto pointwise_descriptor = cudnn_frontend::PointwiseDescBuilder() - .setAxis(attributes.get_axis().value_or(-1)) - .setReluLowerClipSlope(attributes.relu_lower_clip_slope.value_or(0.0)) - .setComputeType(attributes.compute_data_type) - .setMode(attributes.mode) - .build(); + auto&& pointwise_descriptor_builder = cudnn_frontend::PointwiseDescBuilder(); + + if (attributes.get_axis().has_value()) { + pointwise_descriptor_builder.setAxis(attributes.get_axis().value()); + } + + if (attributes.relu_lower_clip_slope.has_value()) { + pointwise_descriptor_builder.setReluLowerClipSlope(attributes.relu_lower_clip_slope.value()); + } + + if (attributes.relu_lower_clip.has_value()) { + pointwise_descriptor_builder.setReluLowerClip(attributes.relu_lower_clip.value()); + } + + if (attributes.relu_upper_clip.has_value()) { + pointwise_descriptor_builder.setReluUpperClip(attributes.relu_upper_clip.value()); + } + + pointwise_descriptor_builder.setComputeType(attributes.compute_data_type); + pointwise_descriptor_builder.setMode(attributes.mode); + auto pointwise_descriptor = pointwise_descriptor_builder.build(); auto const port_count = get_pointwise_mode_port_count(attributes.mode); @@ -153,11 +168,13 @@ class PointwiseNode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { j = attributes; j.update(R"({"tag": "POINTWISE"})"_json); } +#endif }; inline void @@ -182,7 +199,13 @@ INode::pointwise(std::shared_ptr a, inline std::shared_ptr INode::pointwise(std::shared_ptr a, Pointwise_attributes attributes) { + if (attributes.name.empty()) { + attributes.name += std::to_string(sub_nodes.size()); + } attributes.inputs[Pointwise_attributes::input_names::IN_0] = a; + if (a->get_name().empty()) { + a->set_name(attributes.name + "::IN_0"); + }; auto OUT_0 = attributes.outputs[Pointwise_attributes::output_names::OUT_0] = output_tensor(attributes.name + "::OUT_0"); @@ -194,8 +217,17 @@ inline std::shared_ptr INode::pointwise(std::shared_ptr a, std::shared_ptr b, Pointwise_attributes attributes) { + if (attributes.name.empty()) { + attributes.name += std::to_string(sub_nodes.size()); + } attributes.inputs[Pointwise_attributes::input_names::IN_0] = a; attributes.inputs[Pointwise_attributes::input_names::IN_1] = b; + if (a->get_name().empty()) { + a->set_name(attributes.name + "::IN_0"); + }; + if (b->get_name().empty()) { + b->set_name(attributes.name + "::IN_1"); + }; auto OUT_0 = attributes.outputs[Pointwise_attributes::output_names::OUT_0] = output_tensor(attributes.name + "::OUT_0"); @@ -208,9 +240,21 @@ INode::pointwise(std::shared_ptr a, std::shared_ptr b, std::shared_ptr c, Pointwise_attributes attributes) { + if (attributes.name.empty()) { + attributes.name += std::to_string(sub_nodes.size()); + } attributes.inputs[Pointwise_attributes::input_names::IN_0] = a; attributes.inputs[Pointwise_attributes::input_names::IN_1] = b; 
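    // As in the one- and two-input overloads above: a node left unnamed takes its
    // sub-node index as its name, and unnamed input tensors inherit
    // "<node name>::IN_k" so the serialized graph can reference them unambiguously.
    // For example (illustrative only), an anonymous add created as
    //     graph.pointwise(a, b, Pointwise_attributes().set_mode(PointwiseMode_t::ADD));
    // gets its index as its name and renames unnamed inputs to "<index>::IN_0" / "<index>::IN_1".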
attributes.inputs[Pointwise_attributes::input_names::IN_2] = c; + if (a->get_name().empty()) { + a->set_name(attributes.name + "::IN_0"); + }; + if (b->get_name().empty()) { + b->set_name(attributes.name + "::IN_1"); + }; + if (c->get_name().empty()) { + c->set_name(attributes.name + "::IN_2"); + }; auto OUT_0 = attributes.outputs[Pointwise_attributes::output_names::OUT_0] = output_tensor(attributes.name + "::OUT_0"); diff --git a/include/cudnn_frontend/node/reduction.h b/include/cudnn_frontend/node/reduction.h index 0cd7f63..dc40690 100644 --- a/include/cudnn_frontend/node/reduction.h +++ b/include/cudnn_frontend/node/reduction.h @@ -118,11 +118,13 @@ class ReductionNode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { j = attributes; j.update(R"({"tag": "REDUCTION"})"_json); } +#endif }; inline void diff --git a/include/cudnn_frontend/node/resample.h b/include/cudnn_frontend/node/resample.h index 1b687aa..e00e3c7 100644 --- a/include/cudnn_frontend/node/resample.h +++ b/include/cudnn_frontend/node/resample.h @@ -172,11 +172,13 @@ class ResampleNode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { j = attributes; j.update(R"( {"tag": "RESAMPLE"})"_json); } +#endif }; inline std::array, 2> diff --git a/include/cudnn_frontend/node/reshape.h b/include/cudnn_frontend/node/reshape.h index 4a33aae..21983cb 100644 --- a/include/cudnn_frontend/node/reshape.h +++ b/include/cudnn_frontend/node/reshape.h @@ -122,11 +122,13 @@ class ReshapeNode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { j = attributes; j.update(R"( {"tag": "RESHAPE"})"_json); } +#endif }; inline std::shared_ptr diff --git a/include/cudnn_frontend/node/rmsnorm.h b/include/cudnn_frontend/node/rmsnorm.h index 3176bb1..d770cba 100644 --- a/include/cudnn_frontend/node/rmsnorm.h +++ b/include/cudnn_frontend/node/rmsnorm.h @@ -145,11 +145,13 @@ class RMSNormNode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { j = attributes; j.update(R"( {"tag": "RMS_NORM"})"_json); } +#endif }; class DRMSNormNode : public NodeCRTP { @@ -312,11 +314,13 @@ class DRMSNormNode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { j = attributes; j.update(R"( {"tag": "RMS_NORM_BPROP"})"_json); } +#endif }; } // namespace graph diff --git a/include/cudnn_frontend/node/rng.h b/include/cudnn_frontend/node/rng.h index 794d5f0..32c276c 100644 --- a/include/cudnn_frontend/node/rng.h +++ b/include/cudnn_frontend/node/rng.h @@ -136,11 +136,13 @@ class RngNode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { j = attributes; j.update(R"( {"tag": "RNG"})"_json); } +#endif }; inline void diff --git a/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h b/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h index a5630aa..94de2dc 100644 --- a/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h +++ b/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h @@ -75,28 +75,38 @@ class SDPANode : public NodeCRTP { #undef CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE // 
validate backend limitations for the operation + // clang-format off int64_t s_q = attributes.inputs.at(input_names::Q)->get_dim()[2]; + int64_t s_kv = attributes.inputs.at(input_names::K)->get_dim()[2]; int64_t h_q = attributes.inputs.at(input_names::Q)->get_dim()[1]; int64_t h_k = attributes.inputs.at(input_names::K)->get_dim()[1]; int64_t h_v = attributes.inputs.at(input_names::V)->get_dim()[1]; int64_t d_qk = attributes.inputs.at(input_names::Q)->get_dim()[3]; int64_t d_v = attributes.inputs.at(input_names::V)->get_dim()[3]; + + bool const is_ragged = attributes.inputs.at(input_names::Q)->get_ragged_offset() || + attributes.inputs.at(input_names::K)->get_ragged_offset() || + attributes.inputs.at(input_names::V)->get_ragged_offset() || + attributes.outputs.at(output_names::O)->get_ragged_offset(); + + auto const& bias_mask = attributes.inputs.find(input_names::Bias); + bool const is_bias = (bias_mask != attributes.inputs.end() && bias_mask->second != nullptr); + + auto const& dropout_mask = attributes.inputs.find(input_names::Dropout_mask); + bool const is_dropout_custom = (dropout_mask != attributes.inputs.end()) && (dropout_mask->second != nullptr); + bool const is_dropout = attributes.dropout_probability.has_value() || is_dropout_custom; + + // validation TODO: + // - validate stats has valid dims + + // validate basic dimension requirements + RETURN_CUDNN_FRONTEND_ERROR_IF((d_qk > 256) || (d_qk % 8 != 0) || (d_v > 256) || (d_v % 8 != 0), + error_code_t::GRAPH_NOT_SUPPORTED, + "hidden_dim shoud be less than 256 and hidden_dim should be multiple of 8"); + RETURN_CUDNN_FRONTEND_ERROR_IF((h_q % h_k != 0) || (h_q % h_v != 0), error_code_t::GRAPH_NOT_SUPPORTED, - "For group-query attention, number of heads for key and query must be a factor " - "of number of heads for query"); - - if (detail::get_backend_version() >= 90000) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - (d_qk > 256) || (d_qk % 8 != 0) || (d_v > 256) || (d_v % 8 != 0), - error_code_t::GRAPH_NOT_SUPPORTED, - "Num hidden_dim shoud be less than 256 and hidden_dim should be multiple of 8"); - } else { - RETURN_CUDNN_FRONTEND_ERROR_IF( - (d_qk > 128) || (d_qk % 8 != 0) || (d_v > 128) || (d_v % 8 != 0), - error_code_t::GRAPH_NOT_SUPPORTED, - "Num hidden_dim shoud be less than 128 and hidden_dim should be multiple of 8"); - } + "For group-query attention, number of heads for key and query must be a factor of number of heads for query"); // validate options for attn_scale auto const& attn_scale = attributes.inputs.find(input_names::Attn_scale); @@ -106,36 +116,9 @@ class SDPANode : public NodeCRTP { "attn_scale with tensor and value cannot be set at the same time."); // validate options for bias mask - auto bias_mask = attributes.inputs.find(input_names::Bias); - if (bias_mask != attributes.inputs.end() && bias_mask->second != nullptr) { - auto bias_mask_dtype = bias_mask->second->get_data_type(); - RETURN_CUDNN_FRONTEND_ERROR_IF((bias_mask_dtype == DataType_t::BOOLEAN), - error_code_t::GRAPH_NOT_SUPPORTED, - "Bias mask data type cannot be boolean"); - } - - auto const& v_dim = attributes.inputs.at(input_names::V)->get_dim(); - auto s_kv = v_dim[2]; - if ((s_kv % 64 != 0) && (!(attributes.padding_mask)) && (detail::get_backend_version() < 90000)) { - RETURN_CUDNN_FRONTEND_ERROR_IF((detail::get_backend_version() <= 8905), - error_code_t::GRAPH_NOT_SUPPORTED, - "s_kv not a multiple of 64 required cudnn version atleast 8.9.5"); - auto const& dropout_mask = attributes.inputs.find(input_names::Dropout_mask); - bool const has_dropout_mask = - 
(dropout_mask != attributes.inputs.end()) && (dropout_mask->second != nullptr); - bool const has_dropout = attributes.dropout_probability.has_value() || has_dropout_mask; - RETURN_CUDNN_FRONTEND_ERROR_IF( - has_dropout, - error_code_t::GRAPH_NOT_SUPPORTED, - "s_kv not a multiple of 64 with dropout enabled is not supported with cudnn version below 9.0.0"); - } - - if (((s_kv % 64 != 0) || (d_qk % 64 != 0)) && (detail::get_backend_version() <= 8905)) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - true, - error_code_t::GRAPH_NOT_SUPPORTED, - "s_kv not a multiple of 64 or d not a multiple of 64 is not supported with cudnn version below 8.9.6"); - } + RETURN_CUDNN_FRONTEND_ERROR_IF(is_bias && (bias_mask->second->get_data_type() == DataType_t::BOOLEAN), + error_code_t::GRAPH_NOT_SUPPORTED, + "Bias mask data type cannot be boolean"); // validate options for padding mask auto const& seq_len_q = attributes.inputs.find(input_names::SEQ_LEN_Q); @@ -149,31 +132,69 @@ class SDPANode : public NodeCRTP { error_code_t::ATTRIBUTE_NOT_SET, "seq_len_q and seq_len_kv needs to be set only if padding mask is enabled."); + // validate options for bottom right causal mask + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.causal_mask && attributes.causal_mask_bottom_right, + error_code_t::GRAPH_NOT_SUPPORTED, + "Bottom right causal mask and causal mask cannot be both enabled"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.causal_mask_bottom_right && s_q > s_kv, + error_code_t::GRAPH_NOT_SUPPORTED, + "Bottom right causal mask does not support s_q > s_kv. Please virtually slice the Q tensor and pass it as s_q == s_kv"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.causal_mask_bottom_right && (is_bias || attributes.alibi_mask || is_ragged || attributes.padding_mask || is_dropout), + error_code_t::GRAPH_NOT_SUPPORTED, + "Bottom right causal mask is only supported with is_bias=False, is_alibi=False, is_ragged=False, padding_mask=False, is_dropout=False"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.causal_mask_bottom_right && ((s_q % 64 != 0) || (s_kv % 64 != 0)), + error_code_t::GRAPH_NOT_SUPPORTED, + "Bottom right causal mask is only supported with s_q multiple of 64, and s_kv multiple of 64"); + + // validate options for sliding window length + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.sliding_window_length.has_value() && attributes.sliding_window_length.value() < 0, + error_code_t::INVALID_VALUE, + "Sliding window length should be greater than or equals to zero when set."); + + + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.sliding_window_length.has_value() && (attributes.padding_mask || !attributes.causal_mask || is_dropout || is_bias || is_ragged), + error_code_t::GRAPH_NOT_SUPPORTED, + "Sliding window attention is only supported with padding_mask=False, causal_mask=True, is_dropout=False, is_bias=False, is_ragged=False"); + // validate options for dropout mask - auto const& dropout_mask = attributes.inputs.find(input_names::Dropout_mask); - bool const has_dropout_mask = (dropout_mask != attributes.inputs.end()) && (dropout_mask->second != nullptr); - RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.dropout_probability.has_value() && has_dropout_mask, + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.dropout_probability.has_value() && is_dropout_custom, error_code_t::ATTRIBUTE_NOT_SET, - "Using both, custom dropout mask and internal-mask generation using dropout " - "probability, is ill-formed."); + "Using both, custom dropout mask and internal-mask generation using dropout probability, is ill-formed."); + + 
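        // Illustrative configuration (a sketch; only setters declared in SDPA_attributes
        // are assumed) that passes the checks in this block when running cuDNN 9.2.0 or later:
        //
        //     auto sdpa_options = SDPA_attributes()
        //                             .set_name("flash_attention")
        //                             .set_is_inference(true)
        //                             .set_causal_mask(true)             // required for sliding window
        //                             .set_sliding_window_length(128);   // must be >= 0
        //
        // Bottom-right causal masking (set_causal_mask_bottom_right(true)) is an
        // alternative to set_causal_mask(true); the checks above reject combining it
        // with the plain causal mask, bias, alibi, padding, dropout, or ragged layouts.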
RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.dropout_probability.has_value() && attributes.dropout_probability.value() == 1.0, + error_code_t::ATTRIBUTE_NOT_SET, + "Dropout probability cannot be 1 as corresponding scale wont be well formed."); + + // version specific validation + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 8906 && ((s_kv % 64 != 0) || (d_qk % 64 != 0)), + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 8.9.6, s_kv not a multiple of 64 or d not a multiple of 64 is not supported"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 8907 && (s_kv % 64 != 0) && (!(attributes.padding_mask)), + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 8.9.7, s_kv not a multiple of 64 is not supported"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90000 && ((s_q % 64 != 0) || (s_kv % 64 != 0)) && (attributes.padding_mask || is_dropout), + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 9.0.0, s_q/s_kv not a multiple of 64 with padding/dropout mask is not supported"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90000 && ((d_qk > 128) || (d_qk % 8 != 0) || (d_v > 128) || (d_v % 8 != 0)), + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 9.0.0, hidden_dim shoud be less than 128 and hidden_dim should be multiple of 8"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90200 && attributes.sliding_window_length.has_value(), + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 9.2.0, sliding window attention is not supported"); - RETURN_CUDNN_FRONTEND_ERROR_IF( - attributes.dropout_probability.has_value() && attributes.dropout_probability.value() == 1.0, - error_code_t::ATTRIBUTE_NOT_SET, - "Dropout probability cannot be 1 as corresponding scale wont be well formed."); // validate that datatype is set for the graph RETURN_CUDNN_FRONTEND_ERROR_IF(context.get_intermediate_data_type() == DataType_t::NOT_SET, error_code_t::ATTRIBUTE_NOT_SET, "Intermediate tensor data type needs to be set as internal tensors require it."); - - if (((s_q % 64 != 0) || (s_kv % 64 != 0)) && (attributes.padding_mask || has_dropout_mask) && - (detail::get_backend_version() < 90000)) { - RETURN_CUDNN_FRONTEND_ERROR_IF(true, - error_code_t::GRAPH_NOT_SUPPORTED, - "s_q/s_kv not a multiple of 64 with padding/dropout mask is not supported " - "with cudnn version below 9.0.0"); - } + // clang-format on CHECK_CUDNN_FRONTEND_ERROR(attributes.validate_inputs()); return {error_code_t::OK, ""}; @@ -371,7 +392,77 @@ class SDPANode : public NodeCRTP { last_output = padding_mask_output; } - if (attributes.causal_mask) { + if (attributes.causal_mask || attributes.causal_mask_bottom_right) { + std::shared_ptr row_index; + + row_index = pointwise(last_output, + Pointwise_attributes() + .set_name("gen_row_idx_causal") + .set_mode(PointwiseMode_t::GEN_INDEX) + .set_axis(2) + .set_compute_data_type(DataType_t::INT32)); + row_index->set_data_type(DataType_t::INT32); + + if (attributes.causal_mask_bottom_right) { + if (attributes.inputs[input_names::SEQ_LEN_KV]) { + row_index = pointwise(row_index, + attributes.inputs[input_names::SEQ_LEN_KV], + Pointwise_attributes() + .set_name("row_idx_add_skv") + .set_mode(PointwiseMode_t::ADD) + .set_compute_data_type(DataType_t::INT32)); + } else { + row_index = pointwise(row_index, + std::make_shared(static_cast(s_kv)), + Pointwise_attributes() + .set_name("row_idx_add_skv") + .set_mode(PointwiseMode_t::ADD) + 
.set_compute_data_type(DataType_t::INT32)); + } + row_index->set_data_type(DataType_t::INT32); + + if (attributes.inputs[input_names::SEQ_LEN_Q]) { + row_index = pointwise(row_index, + attributes.inputs[input_names::SEQ_LEN_Q], + Pointwise_attributes() + .set_name("row_idx_add_sq_sub_sq") + .set_mode(PointwiseMode_t::SUB) + .set_compute_data_type(DataType_t::INT32)); + } else { + row_index = pointwise(row_index, + std::make_shared(static_cast(s_q)), + Pointwise_attributes() + .set_name("row_idx_add_sq_sub_sq") + .set_mode(PointwiseMode_t::SUB) + .set_compute_data_type(DataType_t::INT32)); + } + row_index->set_data_type(DataType_t::INT32); + } + + auto const& col_index = pointwise(last_output, + Pointwise_attributes() + .set_name("gen_col_idx_causal") + .set_mode(PointwiseMode_t::GEN_INDEX) + .set_axis(3) + .set_compute_data_type(DataType_t::INT32)); + col_index->set_data_type(DataType_t::INT32); + + auto const& bool_mask = pointwise(row_index, + col_index, + Pointwise_attributes() + .set_name("row_greater_than_col") + .set_mode(PointwiseMode_t::CMP_GE) + .set_compute_data_type(DataType_t::BOOLEAN)); + bool_mask->set_data_type(DataType_t::BOOLEAN); + + last_output = + pointwise(last_output, + std::make_shared(std::numeric_limits::lowest()), + bool_mask, + Pointwise_attributes().set_name("binary_select").set_mode(PointwiseMode_t::BINARY_SELECT)); + } + + if (attributes.sliding_window_length.has_value()) { auto row_index_attributes = Pointwise_attributes().set_name("gen_row_index").set_mode(PointwiseMode_t::GEN_INDEX).set_axis(2); auto const& row_index_output = pointwise(last_output, row_index_attributes); @@ -380,22 +471,38 @@ class SDPANode : public NodeCRTP { Pointwise_attributes().set_name("gen_col_index").set_mode(PointwiseMode_t::GEN_INDEX).set_axis(3); auto const& col_index_output = pointwise(last_output, col_index_attributes); + // sliding window length parameter should be of float type + auto const& sliding_window_length = + std::make_shared((float)attributes.sliding_window_length.value()); + + auto add_col_attributes = Pointwise_attributes() + .set_name("add_window_len") + .set_mode(PointwiseMode_t::ADD) + .set_compute_data_type(DataType_t::FLOAT) + .set_axis(3); + + auto const& col_index_lower_output = pointwise(col_index_output, sliding_window_length, add_col_attributes); + auto greater_than_attributes = Pointwise_attributes() - .set_name("row_greater_than_col") - .set_mode(PointwiseMode_t::CMP_GE) + .set_name("greaterthan_rowset_data_type(DataType_t::BOOLEAN); + + auto const& row_lesser_than_col_ws_output = + pointwise(col_index_lower_output, row_index_output, greater_than_attributes); + + row_lesser_than_col_ws_output->set_data_type(DataType_t::BOOLEAN); // Lower attributes to binary select attributes - auto negative_inf_causal = std::make_shared(std::numeric_limits::lowest()); + auto negative_inf_swa = std::make_shared(-1024.0f * 1024.0f * 1024.0f); auto binary_select_attributes = Pointwise_attributes().set_name("binary_select").set_mode(PointwiseMode_t::BINARY_SELECT); - auto const& causal_mask_output = - pointwise(last_output, negative_inf_causal, row_greater_than_col_output, binary_select_attributes); - last_output = causal_mask_output; + + auto const& swa_mask_output = + pointwise(last_output, negative_inf_swa, row_lesser_than_col_ws_output, binary_select_attributes); + + last_output = swa_mask_output; } // Lower attributes to softmax attributes @@ -534,11 +641,13 @@ class SDPANode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB 
virtual void serialize(json& j) const override final { j = attributes; j.update(R"({"tag": "SDPA_FWD"})"_json); } +#endif }; class SDPABackwardNode : public NodeCRTP { @@ -610,28 +719,40 @@ class SDPABackwardNode : public NodeCRTP { #undef CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE // validate backend limitations for the operation - int64_t h_q = attributes.inputs.at(input_names::Q)->get_dim()[1]; + // clang-format off int64_t s_q = attributes.inputs.at(input_names::Q)->get_dim()[2]; + int64_t s_kv = attributes.inputs.at(input_names::V)->get_dim()[2]; + int64_t h_q = attributes.inputs.at(input_names::Q)->get_dim()[1]; int64_t h_k = attributes.inputs.at(input_names::K)->get_dim()[1]; int64_t h_v = attributes.inputs.at(input_names::V)->get_dim()[1]; int64_t d_qk = attributes.inputs.at(input_names::Q)->get_dim()[3]; - int64_t s_kv = attributes.inputs.at(input_names::V)->get_dim()[2]; int64_t d_v = attributes.inputs.at(input_names::V)->get_dim()[3]; - RETURN_CUDNN_FRONTEND_ERROR_IF( - (s_q < 64) && detail::get_backend_version() < 90000, - error_code_t::GRAPH_NOT_SUPPORTED, - "Sequence length must be greater than or equal to 64 for cudnn version prior to v9.0.0"); + bool const is_ragged = attributes.inputs.at(input_names::Q)->get_ragged_offset() || + attributes.inputs.at(input_names::K)->get_ragged_offset() || + attributes.inputs.at(input_names::V)->get_ragged_offset() || + attributes.inputs.at(input_names::O)->get_ragged_offset(); - RETURN_CUDNN_FRONTEND_ERROR_IF((h_q % h_k != 0) || (h_q % h_v != 0), - error_code_t::GRAPH_NOT_SUPPORTED, - "For group-query attention, number of heads for key and query must be a factor " - "of number of heads for query"); + auto const& bias_mask = attributes.inputs.find(input_names::Bias); + bool const is_bias = (bias_mask != attributes.inputs.end() && bias_mask->second != nullptr); + auto const& dropout_mask = attributes.inputs.find(input_names::Dropout_mask); + bool const is_dropout_custom = (dropout_mask != attributes.inputs.end()) && (dropout_mask->second != nullptr); + bool const is_dropout = attributes.dropout_probability.has_value() || is_dropout_custom; + + // validation TODO: + // - validate stats has valid dims + // - validate Q and dQ have the same dims + + // validate basic dimension requirements RETURN_CUDNN_FRONTEND_ERROR_IF((d_qk > 128) || (d_qk % 8 != 0) || (d_v > 128) || (d_v % 8 != 0), error_code_t::GRAPH_NOT_SUPPORTED, "Num hidden_dim shoud be less than 128 and hidden_dim should be multiple of 8"); + RETURN_CUDNN_FRONTEND_ERROR_IF((h_q % h_k != 0) || (h_q % h_v != 0), + error_code_t::GRAPH_NOT_SUPPORTED, + "For group-query attention, number of heads for key and query must be a factor of number of heads for query"); + // validate options for attn_scale auto const& attn_scale = attributes.inputs.find(input_names::Attn_scale); bool const has_attn_scale = (attn_scale != attributes.inputs.end()) && (attn_scale->second != nullptr); @@ -640,14 +761,9 @@ class SDPABackwardNode : public NodeCRTP { "attn_scale with tensor and value cannot be set at the same time."); // validate options for bias mask - auto bias_mask = attributes.inputs.find(input_names::Bias); - bool const has_bias = (bias_mask != attributes.inputs.end() && bias_mask->second != nullptr); - if (has_bias) { - auto bias_mask_dtype = bias_mask->second->get_data_type(); - RETURN_CUDNN_FRONTEND_ERROR_IF((bias_mask_dtype == DataType_t::BOOLEAN), - error_code_t::GRAPH_NOT_SUPPORTED, - "Bias mask data type cannot be boolean"); - } + RETURN_CUDNN_FRONTEND_ERROR_IF(is_bias && 
(bias_mask->second->get_data_type() == DataType_t::BOOLEAN), + error_code_t::GRAPH_NOT_SUPPORTED, + "Bias mask data type cannot be boolean"); // validate options for padding mask auto const& seq_len_q = attributes.inputs.find(input_names::SEQ_LEN_Q); @@ -661,31 +777,68 @@ class SDPABackwardNode : public NodeCRTP { error_code_t::ATTRIBUTE_NOT_SET, "seq_len_q and seq_len_kv needs to be set only if padding mask is enabled."); + // validate options for bottom right causal mask + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.causal_mask && attributes.causal_mask_bottom_right, + error_code_t::GRAPH_NOT_SUPPORTED, + "Bottom right causal mask and causal mask cannot be both enabled"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.causal_mask_bottom_right && s_q > s_kv, + error_code_t::GRAPH_NOT_SUPPORTED, + "Bottom right causal mask does not support s_q > s_kv. Please virtually slice the Q tensor and pass it as s_q == s_kv"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.causal_mask_bottom_right && (is_bias || attributes.alibi_mask || is_ragged || attributes.padding_mask || is_dropout), + error_code_t::GRAPH_NOT_SUPPORTED, + "Bottom right causal mask is only supported with is_bias=False, is_alibi=False, is_ragged=False, padding_mask=False, is_dropout=False"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.causal_mask_bottom_right && ((s_q % 64 != 0) || (s_kv % 64 != 0)), + error_code_t::GRAPH_NOT_SUPPORTED, + "Bottom right causal mask is only supported with s_q multiple of 64, and s_kv multiple of 64"); + + // validate options for sliding window length + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.sliding_window_length.has_value() && attributes.sliding_window_length.value() < 0, + error_code_t::INVALID_VALUE, + "Sliding window length should be greater than or equals to zero when set."); + + + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.sliding_window_length.has_value() && (attributes.padding_mask || !attributes.causal_mask || is_dropout || is_bias || is_ragged), + error_code_t::GRAPH_NOT_SUPPORTED, + "Sliding window attention is only supported with padding_mask=False, causal_mask=True, is_dropout=False, is_bias=False, is_ragged=False"); + // validate options for dropout mask - auto const& dropout_mask = attributes.inputs.find(input_names::Dropout_mask); - bool const has_dropout_mask = (dropout_mask != attributes.inputs.end()) && (dropout_mask->second != nullptr); - RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.dropout_probability.has_value() && has_dropout_mask, + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.dropout_probability.has_value() && is_dropout_custom, error_code_t::ATTRIBUTE_NOT_SET, - "Using both, custom dropout mask and internal-mask generation using dropout " - "probability, is ill-formed."); + "Using both, custom dropout mask and internal-mask generation using dropout probability, is ill-formed."); - RETURN_CUDNN_FRONTEND_ERROR_IF( - attributes.dropout_probability.has_value() && attributes.dropout_probability.value() == 1.0, - error_code_t::ATTRIBUTE_NOT_SET, - "Dropout probability cannot be 1 as corresponding scale wont be well formed."); + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.dropout_probability.has_value() && attributes.dropout_probability.value() == 1.0, + error_code_t::ATTRIBUTE_NOT_SET, + "Dropout probability cannot be 1 as corresponding scale wont be well formed."); + + // version specific validation + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 8906 && ((s_kv % 64 != 0) || (d_qk % 64 != 0)), + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 8.9.6, s_kv 
not a multiple of 64 or d not a multiple of 64 is not supported"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 8907 && (s_kv % 64 != 0) && (!(attributes.padding_mask)), + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 8.9.7, s_kv not a multiple of 64 is not supported"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90000 && ((s_q % 64 != 0) || (s_kv % 64 != 0)) && (attributes.padding_mask || is_dropout), + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 9.0.0, s_q/s_kv not a multiple of 64 with padding/dropout mask is not supported"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90000 && (s_q < 64), + error_code_t::GRAPH_NOT_SUPPORTED, + " Sequence length must be greater than or equal to 64 for cudnn version prior to v9.0.0"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90200 && attributes.sliding_window_length.has_value(), + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 9.2.0, sliding window attention is not supported"); // validate that datatype is set for the graph RETURN_CUDNN_FRONTEND_ERROR_IF(context.get_intermediate_data_type() == DataType_t::NOT_SET, error_code_t::ATTRIBUTE_NOT_SET, "Intermediate tensor data type needs to be set as internal tensors require it."); - - if (((s_q % 64 != 0) || (s_kv % 64 != 0)) && (attributes.padding_mask || has_dropout_mask) && - (detail::get_backend_version() < 90000)) { - RETURN_CUDNN_FRONTEND_ERROR_IF(true, - error_code_t::GRAPH_NOT_SUPPORTED, - "s_q/s_kv not a multiple of 64 with padding/dropout mask is not supported " - "with cudnn version below 9.0.0"); - } + // clang-format on CHECK_CUDNN_FRONTEND_ERROR(attributes.validate_inputs()); return {error_code_t::OK, ""}; @@ -783,53 +936,60 @@ class SDPABackwardNode : public NodeCRTP { // ---------------------input tensor workarounds--------------------------- - // workspace optimization is only supported on - // cudnn verision >= 8.9.5 - // device version >= hopper - // sizeof(dp tensor) <= max_dp_workspace - - // CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT=unset - enable workspace opt. until the default 256MB limit. - // CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT=-1 - always enable workspace opt. - // CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT=0 - always disable workspace opt. - // CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT=n - enable workspace opt. until the n byte limit bool use_workspace_opt = false; - struct cudaDeviceProp prop; - CHECK_CUDA_ERROR(detail::cuda_get_device_properties(&prop, 0)); - if ((detail::get_backend_version() >= 8905 && prop.major >= 9) || (detail::get_backend_version() >= 9000)) { - // default upper limit for workspace 256MB - int64_t max_dp_workspace_bytes = 256 * 1024 * 1024; - - // allow setting the upper limit with envvars - char* env_dp_workspace_limit_char = std::getenv("CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"); - if (env_dp_workspace_limit_char) { - try { - std::string env_dp_workspace_limit_str(env_dp_workspace_limit_char); - max_dp_workspace_bytes = static_cast(std::stoll(env_dp_workspace_limit_str)); - } catch (...) 
{ - RETURN_CUDNN_FRONTEND_ERROR_IF(true, - error_code_t::ATTRIBUTE_NOT_SET, - "Invalid argument for CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT " - "(int64_t; in bytes)"); + if (detail::get_backend_version() >= 8905 && detail::get_backend_version() < 90000) { + // workspace optimization is enabled by default when: + // 8.9.5 <= cudnn version < 9.0.0 + // device >= hopper + // batch * num_heads * seq_len_q * seq_len_kv * 2 <= dP workspace limit + // + // This following environment variable allows you to control the dP workspace limit. + // From cuDNN version 9.0.0, this option is obsolete will be ignored. + // CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT=unset - enable workspace opt. until the default 256MB limit. + // CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT=-1 - always enable workspace opt. + // CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT=0 - always disable workspace opt. + // CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT=n - enable workspace opt. until the n byte limit + struct cudaDeviceProp prop; + CHECK_CUDA_ERROR(detail::cuda_get_device_properties(&prop, 0)); + + // hopper or above + if (prop.major >= 9) { + // default upper limit for workspace 256MB + int64_t max_dp_workspace_bytes = 256 * 1024 * 1024; + + // allow setting the upper limit with envvars + char* env_dp_workspace_limit_char = std::getenv("CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"); + if (env_dp_workspace_limit_char) { + try { + std::string env_dp_workspace_limit_str(env_dp_workspace_limit_char); + max_dp_workspace_bytes = static_cast(std::stoll(env_dp_workspace_limit_str)); + } catch (...) { + RETURN_CUDNN_FRONTEND_ERROR_IF(true, + error_code_t::ATTRIBUTE_NOT_SET, + "Invalid argument for CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT " + "(int64_t; in bytes)"); + } } - } - int64_t workspace_s_q = ((s_q + 64 - 1) / 64) * 64; - int64_t workspace_s_kv = ((s_kv + 64 - 1) / 64) * 64; - int64_t required_dp_workspace_bytes = b * h_q * workspace_s_q * workspace_s_kv * 2; + int64_t workspace_s_q = ((s_q + 64 - 1) / 64) * 64; + int64_t workspace_s_kv = ((s_kv + 64 - 1) / 64) * 64; + int64_t required_dp_workspace_bytes = b * h_q * workspace_s_q * workspace_s_kv * 2; - if (max_dp_workspace_bytes == -1) { - use_workspace_opt = true; - } else if (max_dp_workspace_bytes == 0) { - use_workspace_opt = false; - } else { - use_workspace_opt = (required_dp_workspace_bytes <= max_dp_workspace_bytes); + if (max_dp_workspace_bytes == -1) { + use_workspace_opt = true; + } else if (max_dp_workspace_bytes == 0) { + use_workspace_opt = false; + } else { + use_workspace_opt = (required_dp_workspace_bytes <= max_dp_workspace_bytes); + } } } - // WAR force dP workspace implementation if dBias is enabled - // since dBias only works with workspace implementation - if (attributes.outputs[output_names::dBias]) { + // Force dP workspace implementation if: + // - dBias is enabled (dBias is only supported on workspace implementation) + // - the user force requests deterministic algorithm + if (attributes.outputs[output_names::dBias] || attributes.is_deterministic_algorithm) { use_workspace_opt = true; } @@ -1000,38 +1160,74 @@ class SDPABackwardNode : public NodeCRTP { Pointwise_attributes().set_name("select_padding").set_mode(PointwiseMode_t::BINARY_SELECT)); } - // Causal Mask DAG - if (attributes.causal_mask) { - auto row_idx_output = pointwise(last_output, - Pointwise_attributes() - .set_name("gen_row_idx_causal") - .set_mode(PointwiseMode_t::GEN_INDEX) - .set_axis(2) - .set_compute_data_type(DataType_t::INT32)); - row_idx_output->set_data_type(DataType_t::INT32); - - auto 
col_idx_output = pointwise(last_output, - Pointwise_attributes() - .set_name("gen_col_idx_causal") - .set_mode(PointwiseMode_t::GEN_INDEX) - .set_axis(3) - .set_compute_data_type(DataType_t::INT32)); - col_idx_output->set_data_type(DataType_t::INT32); + if (attributes.causal_mask || attributes.causal_mask_bottom_right) { + std::shared_ptr row_index; + + row_index = pointwise(last_output, + Pointwise_attributes() + .set_name("gen_row_idx_causal") + .set_mode(PointwiseMode_t::GEN_INDEX) + .set_axis(2) + .set_compute_data_type(DataType_t::INT32)); + row_index->set_data_type(DataType_t::INT32); + + if (attributes.causal_mask_bottom_right) { + if (attributes.inputs[input_names::SEQ_LEN_KV]) { + row_index = pointwise(row_index, + attributes.inputs[input_names::SEQ_LEN_KV], + Pointwise_attributes() + .set_name("row_idx_add_skv") + .set_mode(PointwiseMode_t::ADD) + .set_compute_data_type(DataType_t::INT32)); + } else { + row_index = pointwise(row_index, + std::make_shared(static_cast(s_kv)), + Pointwise_attributes() + .set_name("row_idx_add_skv") + .set_mode(PointwiseMode_t::ADD) + .set_compute_data_type(DataType_t::INT32)); + } + row_index->set_data_type(DataType_t::INT32); + + if (attributes.inputs[input_names::SEQ_LEN_Q]) { + row_index = pointwise(row_index, + attributes.inputs[input_names::SEQ_LEN_Q], + Pointwise_attributes() + .set_name("row_idx_add_sq_sub_sq") + .set_mode(PointwiseMode_t::SUB) + .set_compute_data_type(DataType_t::INT32)); + } else { + row_index = pointwise(row_index, + std::make_shared(static_cast(s_q)), + Pointwise_attributes() + .set_name("row_idx_add_sq_sub_sq") + .set_mode(PointwiseMode_t::SUB) + .set_compute_data_type(DataType_t::INT32)); + } + row_index->set_data_type(DataType_t::INT32); + } - auto causal_mask_output = pointwise(row_idx_output, - col_idx_output, - Pointwise_attributes() - .set_name("gt_row_col_causal") - .set_mode(PointwiseMode_t::CMP_GE) - .set_compute_data_type(DataType_t::BOOLEAN)); - causal_mask_output->set_data_type(DataType_t::BOOLEAN); - auto negative_inf_causal = std::make_shared(std::numeric_limits::lowest()); + auto const& col_index = pointwise(last_output, + Pointwise_attributes() + .set_name("gen_col_idx_causal") + .set_mode(PointwiseMode_t::GEN_INDEX) + .set_axis(3) + .set_compute_data_type(DataType_t::INT32)); + col_index->set_data_type(DataType_t::INT32); + + auto const& bool_mask = pointwise(row_index, + col_index, + Pointwise_attributes() + .set_name("row_greater_than_col") + .set_mode(PointwiseMode_t::CMP_GE) + .set_compute_data_type(DataType_t::BOOLEAN)); + bool_mask->set_data_type(DataType_t::BOOLEAN); last_output = pointwise(last_output, - negative_inf_causal, - causal_mask_output, - Pointwise_attributes().set_name("select_causal").set_mode(PointwiseMode_t::BINARY_SELECT)); + std::make_shared(std::numeric_limits::lowest()), + bool_mask, + Pointwise_attributes().set_name("binary_select").set_mode(PointwiseMode_t::BINARY_SELECT)); } // last_output = last_output - stats @@ -1089,8 +1285,52 @@ class SDPABackwardNode : public NodeCRTP { Pointwise_attributes().set_name("select_2nd_padding").set_mode(PointwiseMode_t::BINARY_SELECT)); } + if (attributes.sliding_window_length.has_value()) { + auto row_index_attributes = + Pointwise_attributes().set_name("gen_row_index").set_mode(PointwiseMode_t::GEN_INDEX).set_axis(2); + auto const& row_index_output = pointwise(last_output, row_index_attributes); + + auto col_index_attributes = + Pointwise_attributes().set_name("gen_col_index").set_mode(PointwiseMode_t::GEN_INDEX).set_axis(3); + auto const& 
col_index_output = pointwise(last_output, col_index_attributes); + + // sliding window length parameter should be of float type + auto const& sliding_window_length = + std::make_shared((float)attributes.sliding_window_length.value()); + + auto add_col_attributes = Pointwise_attributes() + .set_name("add_window_len") + .set_mode(PointwiseMode_t::ADD) + .set_compute_data_type(DataType_t::FLOAT) + .set_axis(3); + + auto const& col_index_lower_output = pointwise(col_index_output, sliding_window_length, add_col_attributes); + + auto greater_than_attributes = Pointwise_attributes() + .set_name("greaterthan_rowset_data_type(DataType_t::BOOLEAN); + + // Lower attributes to binary select attributes + auto negative_inf_swa = std::make_shared(std::numeric_limits::lowest()); + + auto binary_select_attributes = + Pointwise_attributes().set_name("binary_select").set_mode(PointwiseMode_t::BINARY_SELECT); + + auto const& swa_mask_output = + pointwise(last_output, negative_inf_swa, row_lesser_than_col_ws_output, binary_select_attributes); + + last_output = swa_mask_output; + } + // last_output = exp(last_output) - last_output = pointwise(last_output, Pointwise_attributes().set_name("exp_s").set_mode(PointwiseMode_t::EXP)); + last_output = pointwise(last_output, Pointwise_attributes().set_name("exp_s").set_mode(PointwiseMode_t::EXP)); + exp_s_output = last_output; // (optional) last_output = last_output * dropout rng_output @@ -1302,11 +1542,13 @@ class SDPABackwardNode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { j = attributes; j.update(R"({"tag": "SDPA_BWD"})"_json); } +#endif }; } // namespace cudnn_frontend::graph diff --git a/include/cudnn_frontend/node/sdpa_fp8.h b/include/cudnn_frontend/node/sdpa_fp8.h index 642e1e4..f289400 100644 --- a/include/cudnn_frontend/node/sdpa_fp8.h +++ b/include/cudnn_frontend/node/sdpa_fp8.h @@ -267,11 +267,13 @@ class SDPAFP8Node : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { j = attributes; j.update(R"({"tag": "SDPA_FP8_FWD"})"_json); } +#endif }; } // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/include/cudnn_frontend/node/sdpa_fp8_bwd.h b/include/cudnn_frontend/node/sdpa_fp8_bwd.h index 206a1d3..9ac88ca 100644 --- a/include/cudnn_frontend/node/sdpa_fp8_bwd.h +++ b/include/cudnn_frontend/node/sdpa_fp8_bwd.h @@ -335,11 +335,13 @@ class SDPAFP8BackwardNode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { j = attributes; j.update(R"({"tag": "SDPA_FP8_BWD"})"_json); } +#endif }; } // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/include/cudnn_frontend/node/softmax.h b/include/cudnn_frontend/node/softmax.h index e7e3898..17f2b03 100644 --- a/include/cudnn_frontend/node/softmax.h +++ b/include/cudnn_frontend/node/softmax.h @@ -122,10 +122,12 @@ class SoftmaxNode : public NodeCRTP { return {error_code_t::OK, ""}; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const override final { j = attributes; } +#endif }; inline void diff --git a/include/cudnn_frontend/node_interface.h b/include/cudnn_frontend/node_interface.h index ddcd823..92b74ae 100644 --- a/include/cudnn_frontend/node_interface.h +++ b/include/cudnn_frontend/node_interface.h @@ -69,7 +69,7 @@ class INode : public ICudnn { } int64_t - 
get_cudnn_workspace_size(int64_t plan_index = -1) const { + get_cudnn_workspace_size(int64_t plan_index) const { int64_t cudnn_workspace_size = 0; auto status = get_cudnn_workspace_size_node(plan_index, cudnn_workspace_size); @@ -423,7 +423,7 @@ class INode : public ICudnn { // There are two workspaces: // - cudnn execution plan workspace // - FE node workspace (example: alibiSlope for fmha) - return get_fe_workspace_size() + get_cudnn_workspace_size(); + return get_fe_workspace_size() + get_cudnn_workspace_size(plans.candidate); } int64_t @@ -571,13 +571,15 @@ class INode : public ICudnn { // this is where cudnn backend can start using workspace for its execution plans void* cudnn_workspace = static_cast(workspace) + get_fe_workspace_size(); - CHECK_CUDNN_FRONTEND_ERROR(execute_cudnn_plan_with_uid(handle, tensor_uid_to_pointer_map, cudnn_workspace)); + CHECK_CUDNN_FRONTEND_ERROR( + execute_cudnn_plan_with_uid(handle, tensor_uid_to_pointer_map, cudnn_workspace, plans.candidate)); return {error_code_t::OK, ""}; } error_t deserialize(cudnnHandle_t handle, std::vector const& data) { +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB json j = json::from_ubjson(data); auto serialized_plan = j["cudnn_backend_data"]; @@ -592,10 +594,16 @@ class INode : public ICudnn { fe_workspace_size = j["fe_workspace_size"]; return {error_code_t::OK, ""}; +#else + CUDNN_FRONTEND_UNUSED(handle); + CUDNN_FRONTEND_UNUSED(data); + return {error_code_t::GRAPH_NOT_SUPPORTED, "unavailable when compiled with CUDNN_FRONTEND_SKIP_JSON_LIB"}; +#endif } error_t serialize(std::vector& data) const { +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB json j; serialize(j); @@ -620,28 +628,40 @@ class INode : public ICudnn { data = json::to_ubjson(j); return {error_code_t::OK, ""}; +#else + CUDNN_FRONTEND_UNUSED(data); + return {error_code_t::GRAPH_NOT_SUPPORTED, "unavailable when compiled with CUDNN_FRONTEND_SKIP_JSON_LIB"}; +#endif } INode(detail::Context const& context) : context(context) {} // Make sure each node implements a public serialize function +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB virtual void serialize(json& j) const = 0; +#endif size_t key() { +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB json j; serialize(j); return std::hash{}(j); +#else + return 1; +#endif } virtual ~INode() = default; }; +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB [[maybe_unused]] static void to_json(json& j, const INode& p) { p.serialize(j); } +#endif template class NodeCRTP : public INode { @@ -739,4 +759,4 @@ class NodeCRTP : public INode { } // namespace graph -} // namespace cudnn_frontend \ No newline at end of file +} // namespace cudnn_frontend diff --git a/include/cudnn_frontend/plans.h b/include/cudnn_frontend/plans.h index 88e07c3..de37ca6 100644 --- a/include/cudnn_frontend/plans.h +++ b/include/cudnn_frontend/plans.h @@ -9,6 +9,7 @@ #include "graph_helpers.h" #include "backend/execution_helpers.h" +#include "backend/plan_helpers.h" namespace cudnn_frontend { @@ -193,11 +194,51 @@ class Execution_plan_list { std::vector> behavior_notes; std::vector barred_indices; - int64_t max_workspace_allowed = std::numeric_limits::max(); + int64_t max_workspace_allowed = std::numeric_limits::max(); + int64_t max_shared_mem_allowed = 1024 * 1024 * 1024; // Crazy high number (2GB) which will never be hit std::vector barred_engine_names = {}; EngineConfigList engine_configs; + error_t + _build_plan_at_index_impl(cudnnHandle_t handle, int64_t index) { + if (execution_plans[index] == nullptr) { + CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_execution_plan( + execution_plans[index], 
engine_configs[index], operation_tag, handle)); + } + + auto is_blocked = [](std::string const& full_name, std::vector const& blocked_names) -> bool { + for (auto const& blocked_name : blocked_names) { + if (full_name.find(blocked_name) != std::string::npos) { + return true; + } + } + return false; + }; + auto const& plan_tag = execution_plans[index]->getTag(); + if (is_blocked(plan_tag, barred_engine_names)) { + barred_indices[index] = true; + + return {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, + "[cudnn_frontend] Error: Deselecting execution plan with name " + plan_tag + " at position " + + std::to_string(index)}; + } + + // workspace check for 9.2+ is already done at engine config level + if (detail::get_backend_version() < 90200) { + if (execution_plans[index]->getWorkspaceSize() > max_workspace_allowed) { + barred_indices[index] = true; + return {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, + "[cudnn_frontend] Error: Workspace size is too large."}; + } + } + + // Sets candidate in case user does not call execute with plan_index later. + candidate = index; + + return {error_code_t::OK, ""}; + } + public: std::vector> execution_plans; // a built plan corresponding to each engine config, irrespective of whether config is @@ -302,7 +343,7 @@ class Execution_plan_list { bool has_barred_note = std::find(numeric_notes[i].begin(), numeric_notes[i].end(), backend_note) != numeric_notes[i].end(); - barred_indices[i] = has_barred_note && valid_note ? !keep : keep; + barred_indices[i] = barred_indices[i] || (has_barred_note && valid_note ? !keep : keep); } } return {error_code_t::OK, ""}; @@ -318,9 +359,9 @@ class Execution_plan_list { for (auto i = 0u; i < engine_configs.size(); i++) { bool has_barred_note = std::find(behavior_notes[i].begin(), behavior_notes[i].end(), backend_note) != - numeric_notes[i].end(); + behavior_notes[i].end(); - barred_indices[i] = has_barred_note && valid_note ? !keep : keep; + barred_indices[i] = barred_indices[i] || (has_barred_note && valid_note ? !keep : keep); } } return {error_code_t::OK, ""}; @@ -331,6 +372,11 @@ class Execution_plan_list { max_workspace_allowed = workspace_allowed; } + void + set_max_shared_mem_allowed(int64_t const smem_allowed) { + max_shared_mem_allowed = smem_allowed; + } + void set_barred_names(std::vector const& engine_names) { barred_engine_names = engine_names; @@ -352,56 +398,81 @@ class Execution_plan_list { } error_t - check_support(cudnnHandle_t handle) { - for (auto i = 0u; i < engine_configs.size(); i++) { - if (barred_indices[i]) { - getLogger() << "[cudnn_frontend] INFO: Deselecting execution plan at position " << i << std::endl; - continue; - } + check_support_at_index(cudnnHandle_t handle, int64_t index) { + // Ignore if the engine config was deselected. + // This usually happens when user deselects by numerical and behavioural notes. - auto is_blocked = [](std::string const& full_name, std::vector const& blocked_names) -> bool { - for (auto const& blocked_name : blocked_names) { - if (full_name.find(blocked_name) != std::string::npos) { - return true; - } + if (barred_indices[index] == true) { + getLogger() << "Deselecting execution plan at position " << index << std::endl; + } + + RETURN_CUDNN_FRONTEND_ERROR_IF(barred_indices[index] == true, + error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, + "Deselecting execution plan"); + + // Ignore if engine name was specified to be ignored by the user. 
+ auto is_blocked = [](std::string const& full_name, std::vector const& blocked_names) -> bool { + for (auto const& blocked_name : blocked_names) { + if (full_name.find(blocked_name) != std::string::npos) { + return true; } - return false; - }; - - auto cfg_tag = detail::get_engine_tag(engine_configs[i]); - if (is_blocked(cfg_tag, barred_engine_names)) { - getLogger() << "[cudnn_frontend] INFO: Deselecting engine_configs " << cfg_tag << std::endl; - barred_indices[i] = true; - execution_plans[i] = nullptr; - continue; } + return false; + }; + auto cfg_tag = detail::get_engine_tag(engine_configs[index]); + if (is_blocked(cfg_tag, barred_engine_names)) { + barred_indices[index] = true; + return {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, + "[cudnn_frontend] Error: Deselecting execution plan with name " + cfg_tag + " at position " + + std::to_string(index)}; + } - auto const& config = engine_configs[i]; - auto fe_status = detail::create_cudnn_execution_plan(execution_plans[i], config, operation_tag, handle); - getLogger() << "[cudnn_frontend] INFO: Building plan at index " << i << " gave " << fe_status.get_code() - << " with message: " << fe_status.get_message() << std::endl; - - // If a plan is built successfully, set it as a candidate - if (fe_status.is_good()) { - // Filter out execution plans with workspace greater than whats available from user - if (execution_plans[i]->getWorkspaceSize() > max_workspace_allowed) { - barred_indices[i] = true; - execution_plans[i] = nullptr; - getLogger() << "[cudnn_frontend] INFO: Deselecting execution plan at position " << i << std::endl; - continue; - } + if (detail::get_backend_version() >= 90200) { + // Ignore kernels that require larger than tolerable shared memory. + int32_t shared_memory_size = INT32_MAX; + auto status = detail::get_shared_memory_size(engine_configs[index], shared_memory_size); + if (status.is_bad()) { + getLogger() << "[cudnn_frontend] WARN: Unknown Shared memory size, so not deselecting plan at position " + << index << std::endl; + } else if (shared_memory_size > max_shared_mem_allowed) { + barred_indices[index] = true; + return {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, + "[cudnn_frontend] Error: Skipping plan since shared memory violation. Requires " + + std::to_string(shared_memory_size)}; + } - candidate = static_cast(i); - getLogger() << "[cudnn_frontend] INFO: Candidate set as " << i << " " << execution_plans[i]->getTag() - << std::endl; + // Filter by workspace can happen at this engine config stage itself. + int64_t workspace_size = INT64_MAX; + CHECK_CUDNN_FRONTEND_ERROR(detail::get_workspace_size(engine_configs[index], workspace_size)); + if (workspace_size > max_workspace_allowed) { + barred_indices[index] = true; + return {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, + "[cudnn_frontend] Error: Skipping plan since workspace violation. Requires " + + std::to_string(workspace_size)}; + } + } + // Else we need to build the config. A successful execution plan build means that check_support succeeded. + else { + CHECK_CUDNN_FRONTEND_ERROR(_build_plan_at_index_impl(handle, index)); + } + + getLogger() << "Check support for index " << index << " passed with cfg " << cfg_tag << std::endl; + // All checks passed for this config, so return success. + return {error_code_t::OK, ""}; + } + error_t + check_support(cudnnHandle_t handle) { + // Go over each engine config and return true when you find the first one that is supported. 
+ for (auto i = 0u; i < engine_configs.size(); i++) { + auto status = check_support_at_index(handle, i); + if (status.is_good()) { return {error_code_t::OK, ""}; } } - // No plans were able to be built. Return error. return {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, - "[cudnn_frontend] Error: No execution plans built successfully."}; + "[cudnn_frontend] Error: No execution plans support the graph."}; } error_t @@ -418,33 +489,10 @@ class Execution_plan_list { error_t build_plan_at_index(cudnnHandle_t handle, int64_t index) { - RETURN_CUDNN_FRONTEND_ERROR_IF(barred_indices[index] == true, - error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, - "Chosen plan index has been deselected."); - - if (execution_plans[index] != nullptr && execution_plans[index]->getWorkspaceSize() <= max_workspace_allowed) { - candidate = index; - return {error_code_t::OK, ""}; - }; - - auto fe_status = - detail::create_cudnn_execution_plan(execution_plans[index], engine_configs[index], operation_tag, handle); - - getLogger() << "[cudnn_frontend] INFO: Building plan at index " << index << " gave " << fe_status.get_code() - << " with message: " << fe_status.get_message() << std::endl; - - // Sets candidate in case user does not call execute with plan_index later. - if (fe_status.is_good()) { - if (execution_plans[index]->getWorkspaceSize() <= max_workspace_allowed) { - candidate = index; - } else { - barred_indices[index] = true; - return {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, - "[cudnn_frontend] Error: Workspace size is too large."}; - } - } + CHECK_CUDNN_FRONTEND_ERROR(check_support_at_index(handle, index)); + CHECK_CUDNN_FRONTEND_ERROR(_build_plan_at_index_impl(handle, index)); - return fe_status; + return {error_code_t::OK, ""}; } error_t @@ -460,58 +508,29 @@ class Execution_plan_list { } for (auto i = 0u; i < engine_configs.size(); i++) { - if (barred_indices[i]) { - getLogger() << "[cudnn_frontend] INFO: Skipping deselected engine plan at index " << i << std::endl; + auto status = build_plan_at_index(handle, i); + if (status.is_bad()) { + getLogger() << "[cudnn_frontend] WARN: Failed to build plan at " << i << std::endl; continue; } - auto fe_status = - detail::create_cudnn_execution_plan(execution_plans[i], engine_configs[i], operation_tag, handle); - getLogger() << "[cudnn_frontend] INFO: Building plan at index " << i << " gave " << fe_status.get_code() - << " with message: " << fe_status.get_message() << std::endl; - - if (fe_status.is_good()) { - if (execution_plans[i]->getWorkspaceSize() > max_workspace_allowed) { - getLogger() << "[cudnn_frontend] INFO: skipping plan since workspace violation. 
Requires " - << execution_plans[i]->getWorkspaceSize() << std::endl; - barred_indices[i] = true; - execution_plans[i] = nullptr; - continue; - } - - auto is_blocked = [](std::string const& full_name, - std::vector const& blocked_names) -> bool { - for (auto const& blocked_name : blocked_names) { - if (full_name.find(blocked_name) != std::string::npos) { - return true; - } - } - return false; - }; - - if (is_blocked(execution_plans[i]->getTag(), barred_engine_names)) { - getLogger() << "[cudnn_frontend] INFO: Deselecting execution plan " << execution_plans[i]->getTag() - << std::endl; - barred_indices[i] = true; - execution_plans[i] = nullptr; - continue; - } - // Only set the candidate the first time, as the order of iteration is from highest to lowest priority - if (candidate == -1) { - candidate = static_cast(i); - getLogger() << "[cudnn_frontend] INFO: Candidate set as " << i << std::endl; - } - - getLogger() << "[cudnn_frontend] INFO: Built plan at " << i << " " << execution_plans[i]->getTag() - << std::endl; + // Only set the candidate the first time, as the order of iteration is from highest to lowest priority + if (candidate == -1) { + candidate = static_cast(i); + getLogger() << "[cudnn_frontend] INFO: Candidate set as " << i << std::endl; + } - // Return from this function as first successfully built plan is found. - if (policy == BuildPlanPolicy_t::HEURISTICS_CHOICE) { - return {error_code_t::OK, ""}; - } + // Return from this function as first successfully built plan is found. + if (policy == BuildPlanPolicy_t::HEURISTICS_CHOICE) { + return {error_code_t::OK, ""}; } } + // Return an error if no execution plans could be built + RETURN_CUDNN_FRONTEND_ERROR_IF(candidate == -1, + error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, + "[cudnn_frontend] Error: No valid execution plans built."); + return {error_code_t::OK, ""}; } @@ -544,7 +563,7 @@ class Execution_plan_list { return a->getExecutionTime() < b->getExecutionTime(); }; - std::set, decltype(plan_cmp)> timed_execution_plans(plan_cmp); + std::multiset, decltype(plan_cmp)> timed_execution_plans(plan_cmp); const int maxIterCount = 100; const float threshhold = 0.95f; @@ -617,6 +636,19 @@ class Execution_plan_list { auto error = autotune_impl(execution_plans, handle, tensor_to_pointer_map, workspace, user_impl); return error; } + + error_t + is_plan_index_executable(int64_t const index) const { + RETURN_CUDNN_FRONTEND_ERROR_IF((index < 0) || (static_cast(execution_plans.size()) <= index), + error_code_t::GRAPH_EXECUTION_FAILED, + "Plan index " + std::to_string(index) + " is invalid."); + + RETURN_CUDNN_FRONTEND_ERROR_IF(execution_plans[index] == nullptr, + error_code_t::GRAPH_EXECUTION_FAILED, + "Plan index " + std::to_string(index) + " did not build."); + + return {error_code_t::OK, ""}; + } }; } // namespace graph diff --git a/include/cudnn_frontend/utils/serialize.h b/include/cudnn_frontend/utils/serialize.h index 22b5716..5e36d43 100644 --- a/include/cudnn_frontend/utils/serialize.h +++ b/include/cudnn_frontend/utils/serialize.h @@ -4,7 +4,7 @@ #include "../graph_helpers.h" namespace cudnn_frontend::graph { - +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB NLOHMANN_JSON_SERIALIZE_ENUM(BN_finalize_attributes::input_names, { {BN_finalize_attributes::input_names::SUM, "SUM"}, @@ -248,6 +248,15 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Reduction_attributes::input_names, NLOHMANN_JSON_SERIALIZE_ENUM(Reduction_attributes::output_names, {{Reduction_attributes::output_names::Y, "Y"}}) +NLOHMANN_JSON_SERIALIZE_ENUM(Resample_attributes::input_names, + { 
+ {Resample_attributes::input_names::X, "X"}, + }) + +NLOHMANN_JSON_SERIALIZE_ENUM(Resample_attributes::output_names, + {{Resample_attributes::output_names::Y, "Y"}, + {Resample_attributes::output_names::Index, "Index"}}) + NLOHMANN_JSON_SERIALIZE_ENUM(Reshape_attributes::input_names, { {Reshape_attributes::input_names::X, "X"}, @@ -392,5 +401,5 @@ from_json(const nlohmann::json& j, Tensor_attributes& ta) { ta.pass_by_value = j.at("pass_by_value"); } } - +#endif } // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/include/cudnn_frontend_ConvDesc.h b/include/cudnn_frontend_ConvDesc.h index d79c729..1ddb605 100644 --- a/include/cudnn_frontend_ConvDesc.h +++ b/include/cudnn_frontend_ConvDesc.h @@ -55,7 +55,11 @@ class ConvDesc_v8 : public BackendDescriptor { describe() const override { std::stringstream ss; char sep = ' '; +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB ss << "CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR :" << " Datatype: " << json{compute_type} +#else + ss << "CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR :" << " Datatype: " << int(compute_type) +#endif << " Mode: " << std::to_string(mode) << " Num Dimensions: " << nDims; ss << " PadLower ["; for (auto i = 0; i < nDims; i++) { diff --git a/include/cudnn_frontend_Errata.h b/include/cudnn_frontend_Errata.h index 38c3d55..f8b591e 100644 --- a/include/cudnn_frontend_Errata.h +++ b/include/cudnn_frontend_Errata.h @@ -32,6 +32,7 @@ namespace cudnn_frontend { // json file is defined by environment variable // CUDNN_ERRATA_JSON_FILE. If the environment variable // is not set the value set in the API is considered. +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB [[maybe_unused]] static bool load_from_config(json &json_handle, const std::string &errata_json) { const char *err_json = get_environment("CUDNN_ERRATA_JSON_FILE"); @@ -48,6 +49,7 @@ load_from_config(json &json_handle, const std::string &errata_json) { ifs >> json_handle; return true; } +#endif /** * @brief Checks the shape of an operation to compare against errata filter height and width for kernel blocking @@ -112,6 +114,7 @@ check_shape(cudnnBackendDescriptor_t &op, return blocked; } +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB template static bool check_rule(const json &json_handle, const std::string &executionPlanTag, cudnnHandle_t handle, T fn) { @@ -248,6 +251,7 @@ check_rule(const json &json_handle, // Takes in an initialzed json handle and checks if it satisfies the // condition for running it. Returns true if the given executionPlanTag // is faulty. + template static bool check_errata(const json &json_handle, const std::string &executionPlanTag, cudnnHandle_t handle, T fn) { @@ -283,5 +287,6 @@ check_errata(const json &json_handle, cudnn_frontend::getLogger() << ". Passed." 
<< std::endl; return false; } +#endif } // namespace cudnn_frontend diff --git a/include/cudnn_frontend_MatMulDesc.h b/include/cudnn_frontend_MatMulDesc.h index 5ac9f79..235782e 100644 --- a/include/cudnn_frontend_MatMulDesc.h +++ b/include/cudnn_frontend_MatMulDesc.h @@ -47,7 +47,11 @@ class MatMulDesc_v8 : public BackendDescriptor { std::string describe() const override { std::stringstream ss; +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB ss << "CUDNN_BACKEND_MATMUL_DESCRIPTOR :" << " Math precision " << json{compute_type}; +#else + ss << "CUDNN_BACKEND_MATMUL_DESCRIPTOR :" << " Math precision " << int(compute_type); +#endif return ss.str(); } diff --git a/include/cudnn_frontend_Operation.h b/include/cudnn_frontend_Operation.h index bf2bf6d..0e742c8 100644 --- a/include/cudnn_frontend_Operation.h +++ b/include/cudnn_frontend_Operation.h @@ -389,8 +389,12 @@ class OperationBuilder_v8 { build_pointwise_op() { auto status = CUDNN_STATUS_SUCCESS; +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB json j = m_operation.pointwise_mode; m_operation.operationTag = j; +#else + m_operation.operationTag = std::to_string((int)m_operation.pointwise_mode); +#endif status = detail::set_attribute(m_operation.pointer->get_backend_descriptor(), CUDNN_ATTR_OPERATION_POINTWISE_PW_DESCRIPTOR, diff --git a/include/cudnn_frontend_PointWiseDesc.h b/include/cudnn_frontend_PointWiseDesc.h index a3bdfd5..b12ba00 100644 --- a/include/cudnn_frontend_PointWiseDesc.h +++ b/include/cudnn_frontend_PointWiseDesc.h @@ -56,8 +56,13 @@ class PointWiseDesc_v8 : public BackendDescriptor { std::string describe() const override { std::stringstream ss; +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB ss << "CUDNN_BACKEND_POINTWISE_DESCRIPTOR :" << " Mode: " << json{mode} << " Math precision " << json{compute_type}; +#else + ss << "CUDNN_BACKEND_POINTWISE_DESCRIPTOR :" << " Mode: " << int(mode) << " Math precision " + << int(compute_type); +#endif return ss.str(); } diff --git a/include/cudnn_frontend_ReductionDesc.h b/include/cudnn_frontend_ReductionDesc.h index 503cf0d..4b02f08 100644 --- a/include/cudnn_frontend_ReductionDesc.h +++ b/include/cudnn_frontend_ReductionDesc.h @@ -48,8 +48,13 @@ class ReductionDesc_v8 : public BackendDescriptor { std::string describe() const override { std::stringstream ss; +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB ss << "CUDNN_BACKEND_REDUCTION_DESCRIPTOR :" << " Math precision " << json{compute_type} << " Reduction mode " << json{reduction_mode}; +#else + ss << "CUDNN_BACKEND_REDUCTION_DESCRIPTOR :" << " Math precision " << (int)compute_type << " Reduction mode " + << int(reduction_mode); +#endif return ss.str(); } diff --git a/include/cudnn_frontend_Resample.h b/include/cudnn_frontend_Resample.h index b6821ed..8efc324 100644 --- a/include/cudnn_frontend_Resample.h +++ b/include/cudnn_frontend_Resample.h @@ -48,9 +48,16 @@ class ResampleDesc_v8 : public BackendDescriptor { describe() const override { std::stringstream ss; char sep = ','; +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB ss << "CUDNN_BACKEND_RESAMPLE_DESCRIPTOR: " << "Compute Type: " << json{computeType} << ", Resample Mode: " << json{resample_mode} << ", Spatial Dimensions: " << spatialDim << ", Nan Propagation: " << std::to_string(nanOpt) << ", Padding Mode: " << json{padding_mode}; +#else + ss << "CUDNN_BACKEND_RESAMPLE_DESCRIPTOR: " << "Compute Type: " << int(computeType) + << ", Resample Mode: " << int(resample_mode) << ", Spatial Dimensions: " << spatialDim + << ", Nan Propagation: " << std::to_string(nanOpt) << ", Padding Mode: " << int(padding_mode); +#endif + ss << ", 
WindowDim: ["; for (auto i = 0; i < spatialDim; i++) { ss << '(' << windowDim[i].numerator << sep << windowDim[i].denominator << ')' << sep; diff --git a/include/cudnn_frontend_Rng.h b/include/cudnn_frontend_Rng.h index 80c3cb2..9856f3e 100644 --- a/include/cudnn_frontend_Rng.h +++ b/include/cudnn_frontend_Rng.h @@ -48,7 +48,11 @@ class RngDesc_v8 : public BackendDescriptor { describe() const override { std::stringstream ss; #if (CUDNN_VERSION >= 8700) +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB ss << "CUDNN_BACKEND_RNG_DESCRIPTOR: " << "Distribution Type: " << json{distribution} +#else + ss << "CUDNN_BACKEND_RNG_DESCRIPTOR: " << "Distribution Type: " << int(distribution) +#endif << ", Normal Distribution Mean: " << normal_dist_mean << ", Normal Distribution Standard Deviation: " << normal_dist_std_dev << ", Uniform Distribution Maximum: " << uniform_dist_max diff --git a/include/cudnn_frontend_Tensor.h b/include/cudnn_frontend_Tensor.h index ff42922..68070ad 100644 --- a/include/cudnn_frontend_Tensor.h +++ b/include/cudnn_frontend_Tensor.h @@ -56,7 +56,11 @@ class Tensor_v8 : public BackendDescriptor { std::string describe() const override { std::stringstream ss; +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB ss << "CUDNN_BACKEND_TENSOR_DESCRIPTOR :" << " Datatype: " << json{data_type} << " Id: " << std::to_string(id) +#else + ss << "CUDNN_BACKEND_TENSOR_DESCRIPTOR :" << " Datatype: " << int(data_type) << " Id: " << std::to_string(id) +#endif << " nDims " << nDims << " VectorCount: " << vectorCount << " vectorDimension " << vectorDimension; ss << " Dim [ "; for (auto i = 0; i < nDims; i++) { @@ -74,7 +78,11 @@ class Tensor_v8 : public BackendDescriptor { } ss << " ]"; ss << " isVirtual: " << isVirtual << " isByValue: " << isByValue << " Alignment: " << alignment; +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB ss << " reorder_type: " << json{reorder_type}; +#else + ss << " reorder_type: " << int(reorder_type); +#endif return ss.str(); } diff --git a/include/cudnn_frontend_shim.h b/include/cudnn_frontend_shim.h index 6d0ba6e..de6dbf7 100644 --- a/include/cudnn_frontend_shim.h +++ b/include/cudnn_frontend_shim.h @@ -277,6 +277,19 @@ get_last_error_string(char *message, size_t size) { #endif } +inline std::string +get_last_error_string_() { + const size_t size = 65535; + + std::string message; + + message.reserve(size); + + get_last_error_string(message.data(), size); + + return message; +} + inline cudnnStatus_t set_stream(cudnnHandle_t handle, cudaStream_t stream) { NV_FE_CALL_TO_BACKEND(set_stream, cudnnSetStream, handle, stream); diff --git a/include/cudnn_frontend_utils.h b/include/cudnn_frontend_utils.h index ee790c7..e63472b 100644 --- a/include/cudnn_frontend_utils.h +++ b/include/cudnn_frontend_utils.h @@ -26,19 +26,21 @@ #include #include #include +#include +#include +#include #include #include -#ifndef CUDNN_FRONTEND_SKIP_NLOHMANN_JSON +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB #include "cudnn_frontend/thirdparty/nlohmann/json.hpp" #endif using json = nlohmann::json; -#include -#include - template <> struct nlohmann::adl_serializer { static void @@ -174,6 +176,11 @@ struct nlohmann::adl_serializer { } }; +#else +#define NLOHMANN_JSON_SERIALIZE_ENUM(ENUM_TYPE, ...) +#define NLOHMANN_DEFINE_TYPE_INTRUSIVE(Type, ...) 
+#endif + #include "cudnn_frontend_shim.h" #include "cudnn_backend_base.h" #include "cudnn_frontend_Logging.h" diff --git a/pyproject.toml b/pyproject.toml index 0f10f34..9015e1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,12 +14,16 @@ classifiers = [ ] [tool.setuptools] -packages = ["cudnn"] -package-dir = {"" = "python"} +packages = ["cudnn", "include"] +package-dir = {"" = "python", "include" = "include"} +include-package-data = true [project.urls] "Homepage" = "https://github.com/nvidia/cudnn-frontend" "Bug Tracker" = "https://github.com/nvidia/cudnn-frontend/issues" [tool.setuptools.dynamic] -version = {attr = "cudnn.__version__"} \ No newline at end of file +version = {attr = "cudnn.__version__"} + +[tool.setuptools.package-data] +include = ["**/*"] \ No newline at end of file diff --git a/python/cudnn/__init__.py b/python/cudnn/__init__.py index 1bd21b3..cb7fdf6 100644 --- a/python/cudnn/__init__.py +++ b/python/cudnn/__init__.py @@ -6,6 +6,7 @@ from ._compiled_module import ( backend_version, backend_version_string, + get_last_error_string, destroy_handle, norm_forward_phase, reduction_mode, @@ -24,7 +25,7 @@ from .datatypes import _library_type, _is_torch_tensor -__version__ = "1.4.0" +__version__ = "1.5.0" def _tensor( diff --git a/python/properties.cpp b/python/properties.cpp index 8a05f70..89e3c10 100644 --- a/python/properties.cpp +++ b/python/properties.cpp @@ -21,7 +21,9 @@ class HandleManagement { static std::intptr_t create_handle() { cudnnHandle_t handle; - detail::create_handle(&handle); + auto status = detail::create_handle(&handle); + throw_if( + status != CUDNN_STATUS_SUCCESS, cudnn_frontend::error_code_t::HANDLE_ERROR, "cudnnHandle Create failed"); return reinterpret_cast(handle); } @@ -48,6 +50,11 @@ class HandleManagement { } }; +static std::string +get_last_error_string() { + return detail::get_last_error_string_(); +} + void init_properties(py::module_& m) { py::enum_(m, "data_type") @@ -100,6 +107,8 @@ init_properties(py::module_& m) { return out.str(); }); + m.def("get_last_error_string", &get_last_error_string); + m.def("create_handle", &HandleManagement::create_handle); m.def("destroy_handle", &HandleManagement::destroy_handle); m.def("get_stream", &HandleManagement::get_stream); diff --git a/python/pycudnn.cpp b/python/pycudnn.cpp index 9e56b84..eb14267 100644 --- a/python/pycudnn.cpp +++ b/python/pycudnn.cpp @@ -16,7 +16,7 @@ void *cudnn_dlhandle = nullptr; namespace python_bindings { // Raise C++ exceptions corresponding to C++ FE error codes. -// Pybinds will automatically convert C++ exceptions to pythpn exceptions. +// Pybinds will automatically convert C++ exceptions to python exceptions. 
void throw_if(bool const cond, cudnn_frontend::error_code_t const error_code, std::string const &error_msg) { if (cond == false) return; @@ -37,7 +37,7 @@ throw_if(bool const cond, cudnn_frontend::error_code_t const error_code, std::st case cudnn_frontend::error_code_t::GRAPH_EXECUTION_FAILED: throw std::runtime_error(error_msg); case cudnn_frontend::error_code_t::HEURISTIC_QUERY_FAILED: - throw std::runtime_error(error_msg); + throw cudnn_frontend::cudnnGraphNotSupportedException(error_msg.c_str()); case cudnn_frontend::error_code_t::CUDNN_BACKEND_API_FAILED: throw std::runtime_error(error_msg); case cudnn_frontend::error_code_t::CUDA_API_FAILED: @@ -50,6 +50,8 @@ throw_if(bool const cond, cudnn_frontend::error_code_t const error_code, std::st throw cudnn_frontend::cudnnGraphNotSupportedException(error_msg.c_str()); case cudnn_frontend::error_code_t::HANDLE_ERROR: throw std::runtime_error(error_msg); + case cudnn_frontend::error_code_t::INVALID_VALUE: + throw std::runtime_error(error_msg); } } diff --git a/python/pygraph/pointwise.cpp b/python/pygraph/pointwise.cpp index 5dabc23..a69ea90 100644 --- a/python/pygraph/pointwise.cpp +++ b/python/pygraph/pointwise.cpp @@ -53,15 +53,28 @@ PyGraph::pointwise_unary(std::shared_ptr PyGraph::relu(std::shared_ptr& input, - float const negative_slope, + std::optional const& negative_slope, + std::optional const& lower_clip, + std::optional const& upper_clip, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name) { auto attributes = cudnn_frontend::graph::Pointwise_attributes() .set_compute_data_type(compute_data_type) .set_mode(cudnn_frontend::PointwiseMode_t::RELU_FWD) - .set_relu_lower_clip_slope(negative_slope) .set_name(name); + if (negative_slope.has_value()) { + attributes.set_relu_lower_clip_slope(negative_slope.value()); + } + + if (lower_clip.has_value()) { + attributes.set_relu_lower_clip(lower_clip.value()); + } + + if (upper_clip.has_value()) { + attributes.set_relu_upper_clip(upper_clip.value()); + } + auto OUT_0 = graph.pointwise(input, attributes); return OUT_0; } @@ -84,15 +97,28 @@ PyGraph::gen_index(std::shared_ptr& in std::shared_ptr PyGraph::relu_backward(std::shared_ptr& loss, std::shared_ptr& input, - float const negative_slope, + std::optional const& negative_slope, + std::optional const& lower_clip, + std::optional const& upper_clip, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name) { auto attributes = cudnn_frontend::graph::Pointwise_attributes() .set_compute_data_type(compute_data_type) .set_mode(cudnn_frontend::PointwiseMode_t::RELU_BWD) - .set_relu_lower_clip_slope(negative_slope) .set_name(name); + if (negative_slope.has_value()) { + attributes.set_relu_lower_clip_slope(negative_slope.value()); + } + + if (lower_clip.has_value()) { + attributes.set_relu_lower_clip(lower_clip.value()); + } + + if (upper_clip.has_value()) { + attributes.set_relu_upper_clip(upper_clip.value()); + } + auto OUT_0 = graph.pointwise(loss, input, attributes); return OUT_0; } @@ -103,7 +129,7 @@ PyGraph::leaky_relu_backward(std::shared_ptr @@ -111,7 +137,7 @@ PyGraph::leaky_relu(std::shared_ptr& i float const negative_slope, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name) { - return relu(input, negative_slope, compute_data_type, name); + return relu(input, negative_slope, std::nullopt, std::nullopt, compute_data_type, name); } void @@ -246,7 +272,7 @@ init_pygraph_pointwise_submodule(py::class_& m) { m.def("gen_index", &PyGraph::gen_index, py::arg("input"), - 
py::arg_v("axis", 0), + py::arg("axis"), py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), py::arg_v("name", ""), R"pbdoc( @@ -254,7 +280,7 @@ init_pygraph_pointwise_submodule(py::class_& m) { Args: input (cudnn_tensor): The input tensor. - negative_slope (Optional[float]): The slope of the activation for negative inputs. + axis (int): The axis to generate index for. compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. name (Optional[str]): A name for the operation to be performed. @@ -266,7 +292,9 @@ init_pygraph_pointwise_submodule(py::class_& m) { m.def("relu", &PyGraph::relu, py::arg("input"), - py::arg_v("negative_slope", 0.0), + py::arg_v("negative_slope", py::none()), + py::arg_v("lower_clip", py::none()), + py::arg_v("upper_clip", py::none()), py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), py::arg_v("name", ""), R"pbdoc( @@ -274,7 +302,9 @@ init_pygraph_pointwise_submodule(py::class_& m) { Args: input (cudnn_tensor): The input tensor. - negative_slope (Optional[float]): The slope of the activation for negative inputs. + negative_slope (Optional[float]): Sets the lower clip slope value for ReLU. + lower_clip (Optional[float]): Sets the lower clip value for ReLU. + upper_clip (Optional[float]): Sets the upper clip value for ReLU. compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. name (Optional[str]): A name for the operation to be performed. @@ -418,7 +448,9 @@ init_pygraph_pointwise_submodule(py::class_& m) { &PyGraph::relu_backward, py::arg("loss"), py::arg("input"), - py::arg_v("negative_slope", 0.0), + py::arg_v("negative_slope", py::none()), + py::arg_v("lower_clip", py::none()), + py::arg_v("upper_clip", py::none()), py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), py::arg_v("name", ""), R"pbdoc( @@ -427,7 +459,9 @@ init_pygraph_pointwise_submodule(py::class_& m) { Args: loss (cudnn_tensor): The loss tensor. input (cudnn_tensor): The input tensor. - negative_slope (Optional[float]): The slope of the activation for negative inputs. + negative_slope (Optional[float]): Sets the lower clip slope value for ReLU. + lower_clip (Optional[float]): Sets the lower clip value for ReLU. + upper_clip (Optional[float]): Sets the upper clip value for ReLU. compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. name (Optional[str]): A name for the operation to be performed. 
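For reference, the new optional clipping arguments on `relu` / `relu_backward` bound above can be used from Python roughly as follows. This is an illustrative sketch only: the `cudnn.pygraph` and `graph.tensor` boilerplate (names, dims, strides, data types) is assumed from the existing Python frontend and is not part of this diff; only the `input`, `negative_slope`, `lower_clip`, `upper_clip` and `name` keyword arguments come from the bindings above.

```
import cudnn

# Assumed boilerplate: build a graph and an input tensor (shapes are illustrative).
graph = cudnn.pygraph(
    io_data_type=cudnn.data_type.HALF,
    intermediate_data_type=cudnn.data_type.FLOAT,
    compute_data_type=cudnn.data_type.FLOAT,
)
X = graph.tensor(
    name="X",
    dim=[8, 64, 56, 56],
    stride=[64 * 56 * 56, 1, 64 * 56, 64],
    data_type=cudnn.data_type.HALF,
)

# ReLU6-style activation: clamp outputs to [0, 6] with the new clip arguments.
Y = graph.relu(input=X, lower_clip=0.0, upper_clip=6.0, name="relu6")

# Leaky ReLU is still expressed through negative_slope.
Y_leaky = graph.relu(input=X, negative_slope=0.01, name="leaky_relu")
```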
diff --git a/python/pygraph/pygraph.cpp b/python/pygraph/pygraph.cpp index 1823b5d..ef17ce7 100644 --- a/python/pygraph/pygraph.cpp +++ b/python/pygraph/pygraph.cpp @@ -270,6 +270,14 @@ PyGraph::reduction(std::shared_ptr& in return OUT_0; } +std::shared_ptr +PyGraph::reshape(std::shared_ptr& input, std::string const& name) { + auto attributes = cudnn_frontend::graph::Reshape_attributes().set_name(name); + + auto OUT = graph.reshape(input, attributes); + return OUT; +} + void PyGraph::validate() { auto status = graph.validate(); @@ -335,9 +343,20 @@ PyGraph::serialize() const { } void -PyGraph::deserialize(std::vector const& data) { - auto status = graph.deserialize(handle, data); - throw_if(status.is_bad(), status.get_code(), status.get_message()); +PyGraph::deserialize(py::object const& pyobj) { + if (py::isinstance(pyobj)) { + json j = json::parse(pyobj.cast()); + + auto status = graph.deserialize(j); + + throw_if(status.is_bad(), status.get_code(), status.get_message()); + + } else { + std::vector data = pyobj.cast>(); + auto status = graph.deserialize(handle, data); + + throw_if(status.is_bad(), status.get_code(), status.get_message()); + } } void @@ -589,6 +608,21 @@ init_pygraph_submodule(py::module_& m) { Returns: cudnn_tensor: The result of reduction operation. )pbdoc") + .def("reshape", + &PyGraph::reshape, + py::arg("input"), + py::arg_v("name", ""), + R"pbdoc( + Reshape an input tensor to other dimensions without changing the actual memory layout. + These dimensions to reshape to are inferred from output tensor shape. + + Args: + input (cudnn_tensor): The input tensor. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of reshape operation. Please set the dims for the output tensor. + )pbdoc") .def("deselect_numeric_notes", &PyGraph::deselect_numeric_notes) .def("deselect_behavior_notes", &PyGraph::deselect_behavior_notes) .def("select_numeric_notes", &PyGraph::select_numeric_notes) diff --git a/python/pygraph/pygraph.h b/python/pygraph/pygraph.h index 52f3828..fec7e9c 100644 --- a/python/pygraph/pygraph.h +++ b/python/pygraph/pygraph.h @@ -169,7 +169,9 @@ class PyGraph { std::shared_ptr relu(std::shared_ptr& input, - float const negative_slope, + std::optional const& negative_slope, + std::optional const& lower_clip, + std::optional const& upper_clip, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); @@ -182,7 +184,10 @@ class PyGraph { std::shared_ptr relu_backward(std::shared_ptr& loss, std::shared_ptr& input, - float const negative_slope, + + std::optional const& negative_slope, + std::optional const& lower_clip, + std::optional const& upper_clip, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); @@ -210,6 +215,9 @@ class PyGraph { cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); + std::shared_ptr + reshape(std::shared_ptr& input, std::string const& name); + std::vector> rmsnorm(cudnn_frontend::NormFwdPhase_t const forward_phase, std::shared_ptr& x, @@ -259,6 +267,8 @@ class PyGraph { std::shared_ptr& seq_len_q, std::shared_ptr& seq_len_kv, bool const use_causal_mask, + bool const use_causal_mask_bottom_right, + py::object const& sliding_window_length, py::object const& dropout, std::shared_ptr& rng_dump, cudnn_frontend::DataType_t const& compute_data_type, @@ -280,8 +290,11 @@ class PyGraph { std::shared_ptr& seq_len_q, std::shared_ptr& seq_len_kv, bool const use_causal_mask, + bool const use_causal_mask_bottom_right, + py::object 
const& sliding_window_length, py::object const& dropout, std::shared_ptr& rng_dump, + bool const use_deterministic_algorithm, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); @@ -397,7 +410,7 @@ class PyGraph { serialize() const; void - deserialize(std::vector const& data); + deserialize(py::object const& pyobj); int64_t get_execution_plan_count() const { diff --git a/python/pygraph/sdpa.cpp b/python/pygraph/sdpa.cpp index 7527f4d..5a28e26 100644 --- a/python/pygraph/sdpa.cpp +++ b/python/pygraph/sdpa.cpp @@ -24,6 +24,8 @@ PyGraph::sdpa(std::shared_ptr& q, std::shared_ptr& seq_len_q, std::shared_ptr& seq_len_kv, bool const use_causal_mask, + bool const use_causal_mask_bottom_right, + py::object const& sliding_window_length, py::object const& dropout, std::shared_ptr& rng_dump, cudnn_frontend::DataType_t const& compute_data_type, @@ -36,6 +38,7 @@ PyGraph::sdpa(std::shared_ptr& q, .set_seq_len_q(seq_len_q) .set_seq_len_kv(seq_len_kv) .set_causal_mask(use_causal_mask) + .set_causal_mask_bottom_right(use_causal_mask_bottom_right) .set_compute_data_type(compute_data_type) .set_name(name); @@ -52,6 +55,11 @@ PyGraph::sdpa(std::shared_ptr& q, } } + if (!sliding_window_length.is_none()) { + int const sliding_window_value = sliding_window_length.cast(); + attributes.set_sliding_window_length(sliding_window_value); + } + if (!dropout.is_none()) { py::tuple dropout_tuple = dropout.cast(); if ((!dropout_tuple) || (dropout_tuple.size() != 3 && dropout_tuple.size() != 2)) { @@ -109,8 +117,11 @@ PyGraph::sdpa_backward(std::shared_ptr std::shared_ptr& seq_len_q, std::shared_ptr& seq_len_kv, bool const use_causal_mask, + bool const use_causal_mask_bottom_right, + py::object const& sliding_window_length, py::object const& dropout, std::shared_ptr& rng_dump, + bool const use_deterministic_algorithm, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name) { auto attributes = cudnn_frontend::graph::SDPA_backward_attributes() @@ -121,6 +132,8 @@ PyGraph::sdpa_backward(std::shared_ptr .set_seq_len_q(seq_len_q) .set_seq_len_kv(seq_len_kv) .set_causal_mask(use_causal_mask) + .set_causal_mask_bottom_right(use_causal_mask_bottom_right) + .set_deterministic_algorithm(use_deterministic_algorithm) .set_compute_data_type(compute_data_type) .set_name(name); @@ -139,6 +152,11 @@ PyGraph::sdpa_backward(std::shared_ptr } } + if (!sliding_window_length.is_none()) { + int const sliding_window_value = sliding_window_length.cast(); + attributes.set_sliding_window_length(sliding_window_value); + } + if (!dropout.is_none()) { if (!py::isinstance(dropout)) { throw std::runtime_error( @@ -296,6 +314,8 @@ init_pygraph_sdpa_submodule(py::class_& m) { py::arg_v("seq_len_q", nullptr), py::arg_v("seq_len_kv", nullptr), py::arg_v("use_causal_mask", false), + py::arg_v("use_causal_mask_bottom_right", false), + py::arg_v("sliding_window_length", py::none()), py::arg_v("dropout", py::none()), py::arg_v("rng_dump", nullptr), py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), @@ -315,7 +335,10 @@ init_pygraph_sdpa_submodule(py::class_& m) { seq_len_q (Optional[cudnn_tensor]): The sequence length of the query. seq_len_kv (Optional[cudnn_tensor]): The sequence length of the key. use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False. + use_causal_mask_bottom_right (Optional[bool]): Whether to use bottom right aligned causal mask. Default is False. + sliding_window_length (Optional[int]): The length of sliding window. Default is None. 
dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None. + rng_dump (Optional[cudnn_tensor]): Debug tensor to dump the Philox RNG dropout mask. Default is None. compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. name (Optional[str]): The name of the operation. @@ -339,8 +362,11 @@ init_pygraph_sdpa_submodule(py::class_& m) { py::arg_v("seq_len_q", nullptr), py::arg_v("seq_len_kv", nullptr), py::arg_v("use_causal_mask", false), + py::arg_v("use_causal_mask_bottom_right", false), + py::arg_v("sliding_window_length", py::none()), py::arg_v("dropout", py::none()), py::arg_v("rng_dump", nullptr), + py::arg_v("use_deterministic_algorithm", false), py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), py::arg_v("name", ""), R"pbdoc( @@ -361,7 +387,11 @@ init_pygraph_sdpa_submodule(py::class_& m) { seq_len_q (Optional[cudnn_tensor]): The sequence length of the query. seq_len_kv (Optional[cudnn_tensor]): The sequence length of the key. use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False. + use_causal_mask_bottom_right (Optional[bool]): Whether to use bottom right aligned causal mask. Default is False. + sliding_window_length (Optional[int]): The length of sliding window. Default is None. dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None. + rng_dump (Optional[cudnn_tensor]): Debug tensor to dump the Philox RNG dropout mask. Default is None. + use_deterministic_algorithm (Optional[bool]): Whether to always use deterministic algorithm. Default is False. compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. name (Optional[str]): The name of the operation. 
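For reference, the new sdpa keyword arguments bound above can be exercised from Python roughly as follows. This is a sketch under assumptions: the q/k/v/o/stats/dO tensors, `is_inference`, `attn_scale`, and the kwarg names for the backward inputs are taken from the existing frontend API and are not shown in this hunk; only `use_causal_mask`, `use_causal_mask_bottom_right`, `sliding_window_length` and `use_deterministic_algorithm` come from this diff.

```
# Sliding window attention: per the backward-node validation in this patch it is
# combined with use_causal_mask=True and without padding/dropout/bias.
o, stats = graph.sdpa(
    q=q, k=k, v=v,                    # assumed pre-built cudnn tensors
    is_inference=False,
    attn_scale=attn_scale,
    use_causal_mask=True,
    sliding_window_length=128,        # new in this release
    name="sdpa_sliding_window",
)

# Bottom right aligned causal mask: mutually exclusive with use_causal_mask
# and only supported when s_q <= s_kv.
o_br, stats_br = graph.sdpa(
    q=q, k=k, v=v,
    is_inference=False,
    attn_scale=attn_scale,
    use_causal_mask_bottom_right=True,
    name="sdpa_bottom_right",
)

# The backward node can be forced onto the deterministic (workspace) path.
dQ, dK, dV = graph.sdpa_backward(
    q=q, k=k, v=v, o=o, dO=dO, stats=stats,   # assumed kwarg names
    attn_scale=attn_scale,
    use_causal_mask=True,
    use_deterministic_algorithm=True,         # new in this release
    name="sdpa_bwd",
)
```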
diff --git a/samples/CMakeLists.txt b/samples/CMakeLists.txt index dfb316b..ca7962a 100644 --- a/samples/CMakeLists.txt +++ b/samples/CMakeLists.txt @@ -2,6 +2,8 @@ cmake_minimum_required(VERSION 3.18) find_package(Catch2 QUIET) +find_package(Threads) + if(NOT Catch2_FOUND) Include(FetchContent) @@ -20,18 +22,32 @@ include(${CMAKE_SOURCE_DIR}/cmake/cuDNN.cmake) add_executable( samples - cpp/mha.cpp - cpp/convolutions.cpp - cpp/dgrads.cpp - cpp/matmuls.cpp - cpp/batchnorm.cpp - cpp/layernorm.cpp - cpp/rmsnorm.cpp - cpp/wgrads.cpp - cpp/serialization.cpp - cpp/autotuning.cpp - cpp/pointwise.cpp - cpp/resample.cpp + cpp/sdpa/fp16_fwd.cpp + cpp/sdpa/fp16_bwd.cpp + cpp/sdpa/fp16_cached.cpp + cpp/sdpa/fp8_fwd.cpp + cpp/sdpa/fp8_bwd.cpp + + cpp/convolution/fprop.cpp + cpp/convolution/fp8_fprop.cpp + cpp/convolution/int8_fprop.cpp + cpp/convolution/dgrads.cpp + cpp/convolution/wgrads.cpp + + cpp/matmul/matmuls.cpp + cpp/matmul/fp8_matmul.cpp + cpp/matmul/int8_matmul.cpp + cpp/matmul/mixed_matmul.cpp + + cpp/norm/batchnorm.cpp + cpp/norm/layernorm.cpp + cpp/norm/rmsnorm.cpp + + cpp/misc/serialization.cpp + cpp/misc/autotuning.cpp + cpp/misc/parallel_compilation.cpp + cpp/misc/pointwise.cpp + cpp/misc/resample.cpp legacy_samples/conv_sample.cpp legacy_samples/resnet_test_list.cpp @@ -75,10 +91,13 @@ endif() target_link_libraries( samples + PRIVATE Threads::Threads + cudnn_frontend _cudnn_frontend_pch Catch2::Catch2WithMain + CUDNN::cudnn_all ) diff --git a/samples/README.md b/samples/README.md index 75edd0a..29b76f8 100644 --- a/samples/README.md +++ b/samples/README.md @@ -20,5 +20,117 @@ Samples leveraging FE's Python interface are located in [samples/python](python/ ## C++ Interface Samples Samples leveraging FE's C++ interface are located in [samples/cpp](cpp/). +### Building the samples + +``` +mkdir build +cd build +cmake -DCUDNN_PATH=/path/to/cudnn -DCUDAToolkit_ROOT=/path/to/cuda ../ +cmake --build . -j16 +bin/samples +``` + +To run a single sample, e.g. `TEST_CASE("Cached sdpa", "[graph][sdpa][flash]")`: + +``` +./bin/samples "Cached sdpa" +``` + +### Scaled dot product attention (SDPA) examples + +##### [samples/cpp/sdpa](cpp/sdpa) shows how to use cudnn's sdpa operation. + +- [Cached SDPA](cpp/sdpa/fp16_cached.cpp) + +Users are expected to build a graph once and then execute it multiple times. This example shows how to cache cudnn sdpa graph building. + +- [Fwd SDPA](cpp/sdpa/fp16_fwd.cpp) and [Bwd SDPA](cpp/sdpa/fp16_bwd.cpp) + +cudnn's sdpa operation supports various customizations. These examples show how to build a graph with the sdpa operation for your own custom sdpa needs. + +- [Fwd FP8 SDPA](cpp/sdpa/fp8_fwd.cpp) and [Bwd FP8 SDPA](cpp/sdpa/fp8_bwd.cpp) + +Extends the sdpa samples to fp8 precision. + +### Convolution fusion examples + +##### [samples/cpp/convolution](cpp/convolution/) shows how to use cudnn fprop, dgrad, and wgrad operations and some fusions with them. + +- [Fprop](cpp/convolution/fprop.cpp) + +Showcases a simple fprop, fprop with a pointwise fusion of scale, bias, and relu, fprop with bias and relu for channels-first layout, and fusions before the convolution in the form of scale+bias+relu+conv with gen stats. + +- [Fp8 fprop](cpp/convolution/fp8_fprop.cpp) + +Showcases fp8 convolution with scaling and amax reduction. + +- [Int8 fprop](cpp/convolution/int8_fprop.cpp) + +Showcases Int8 convolution. + +- [Dgrad](cpp/convolution/dgrads.cpp) + +Has samples for a simple dgrad, a dgrad + drelu fusion, and a Dgrad + Drelu + DBNweight fused operation.
+ +- [Wgrad](cpp/convolution/wgrads.cpp) + +Similar to dgrad, has a simple wgrad and a scale+bias+relu+wgrad fused operation. + +### Matmul fusion examples + +##### [Matmul](cpp/matmul/) showcases different matmul samples. + +- [Matmul fusion](cpp/matmul/matmuls.cpp) + +Has samples for a simple matmul and matmul fusions such as matmul+abs, matmul+bias, and matmul+scale+bias+relu. + +- [Fp8 Matmul](cpp/matmul/fp8_matmul.cpp) + +Showcases fp8 matmul with scaling and amax reduction. + +- [Int8 Matmul](cpp/matmul/int8_matmul.cpp) + +Showcases Int8 matmul. + +- [Mixed precision matmul](cpp/matmul/mixed_matmul.cpp) + +Mixed precision matmul between int8 and bf16 data types, with the int8 operand upcast to bf16. + +### Normalization examples + +##### [Norm](cpp/norm/) showcases different normalization samples. + +- [LayerNorm](cpp/norm/layernorm.cpp) + +Examples of layernorm training, inference, and backpropagation. + +- [RMSNorm](cpp/norm/rmsnorm.cpp) + +Examples of rmsnorm training, inference, and backpropagation. + +- [BatchNorm](cpp/norm/batchnorm.cpp) + +Shows different fusions in batch norm fprop and bprop, as well as split batch norm fusions. + +### Miscellaneous examples + +##### [Misc](cpp/misc/) showcases miscellaneous samples. + +- [Pointwise fusions](cpp/misc/pointwise.cpp) + +Pointwise fusions with scalars are shown in this sample. + +- [Resample](cpp/misc/resample.cpp) + +Resample fprop operation with different resampling modes. + +- [Serialization](cpp/misc/serialization.cpp) + +Shows how to serialize a graph into a file and read it back in another thread/process. + +- [Autotuning](cpp/misc/autotuning.cpp) + +Shows how to choose the best-performing plan among the multiple plans suggested by the heuristics. + ## [Deprecated] C++ v0.x Interface Samples Samples leveraging FE's C++ 0.x interface are located in [samples/legacy_samples](legacy_samples/). diff --git a/samples/cpp/dgrads.cpp b/samples/cpp/convolution/dgrads.cpp similarity index 99% rename from samples/cpp/dgrads.cpp rename to samples/cpp/convolution/dgrads.cpp index 36a3654..c1f2379 100644 --- a/samples/cpp/dgrads.cpp +++ b/samples/cpp/convolution/dgrads.cpp @@ -21,7 +21,7 @@ */ #include -#include "../utils/helpers.h" +#include "../../utils/helpers.h" #include diff --git a/samples/cpp/convolution/fp8_fprop.cpp b/samples/cpp/convolution/fp8_fprop.cpp new file mode 100644 index 0000000..e978582 --- /dev/null +++ b/samples/cpp/convolution/fp8_fprop.cpp @@ -0,0 +1,131 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include +#include "../../utils/helpers.h" + +#include + +TEST_CASE("Convolution fp8 precision", "[conv][graph]") { + if (cudnnGetCudartVersion() < 12000) { + SKIP("Test requires cuda toolkit 12.0 or above"); + } + if (cudnnGetVersion() < 8600) { + SKIP("TEST REQUIRES minimum cudnn version 8.6.0"); + } + if (check_device_arch_newer_than("hopper") == false) { + SKIP("TEST REQUIRES device hopper arch or newer"); + } + + namespace fe = cudnn_frontend; + // conv problem size + int64_t n = 16, c = 128, h = 64, w = 64, k = 256, r = 1, s = 1; + + // Initialize input tensors with int8_t as proxy for fp8 + auto graph = std::make_shared(); + graph->set_io_data_type(fe::DataType_t::HALF) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + auto X = graph->tensor(fe::graph::Tensor_attributes() + .set_name("image") + .set_dim({n, c, h, w}) + .set_stride({c * h * w, 1, c * w, c}) + .set_data_type(fe::DataType_t::FP8_E4M3)); + + auto W = graph->tensor(fe::graph::Tensor_attributes() + .set_name("filter") + .set_dim({k, c, r, s}) + .set_stride({c * r * s, 1, c * s, c}) + .set_data_type(fe::DataType_t::FP8_E4M3)); + + auto conv_options = fe::graph::Conv_fprop_attributes().set_padding({0, 0}).set_stride({1, 1}).set_dilation({1, 1}); + auto conv_output_fp8 = graph->conv_fprop(X, W, conv_options); + + auto descale_x = graph->tensor(fe::graph::Tensor_attributes() + .set_name("descale_x") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + + auto descale_w = graph->tensor(fe::graph::Tensor_attributes() + .set_name("descale_w") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + + auto scale_y = graph->tensor(fe::graph::Tensor_attributes() + .set_name("scale_y") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + + auto scale_options = fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::MUL); + auto after_descale_x = graph->pointwise(conv_output_fp8, descale_x, scale_options); + auto after_descale_w = graph->pointwise(after_descale_x, descale_w, scale_options); + auto Y = graph->pointwise(after_descale_w, scale_y, scale_options); + + Y->set_output(true).set_data_type(fe::DataType_t::FP8_E4M3); + + auto amax = graph->reduction(after_descale_w, + fe::graph::Reduction_attributes() + .set_mode(fe::ReductionMode_t::AMAX) + .set_compute_data_type(fe::DataType_t::FLOAT)); + + amax->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({1, 1, 1, 1}); + + REQUIRE(graph->validate().is_good()); + + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + REQUIRE(graph->build_operation_graph(handle).is_good()); + REQUIRE(graph->create_execution_plans({fe::HeurMode_t::A}).is_good()); + + REQUIRE(graph->check_support(handle).is_good()); + + REQUIRE(graph->build_plans(handle, fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); + + // Use int8_t as proxy for fp8 + Surface X_gpu(n * c * h * w, false); + Surface W_gpu(k * c * r * s, false); + Surface Y_gpu(n * k * h * w, false); + + Surface X_descale_gpu(1, false); + Surface W_descale_gpu(1, false); + Surface Y_scale_gpu(1, false); + Surface amax_gpu(1, false); + + Surface 
workspace(graph->get_workspace_size(), false); + std::unordered_map, void*> variant_pack = { + {X, X_gpu.devPtr}, + {W, W_gpu.devPtr}, + {Y, Y_gpu.devPtr}, + {descale_x, X_descale_gpu.devPtr}, + {descale_w, W_descale_gpu.devPtr}, + {scale_y, Y_scale_gpu.devPtr}, + {amax, amax_gpu.devPtr}}; + + std::cout << graph->print() << std::endl; + REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); + checkCudnnErr(cudnnDestroy(handle)); +} diff --git a/samples/cpp/convolutions.cpp b/samples/cpp/convolution/fprop.cpp similarity index 69% rename from samples/cpp/convolutions.cpp rename to samples/cpp/convolution/fprop.cpp index 98d8b8b..46c8ae8 100644 --- a/samples/cpp/convolutions.cpp +++ b/samples/cpp/convolution/fprop.cpp @@ -21,7 +21,7 @@ */ #include -#include "../utils/helpers.h" +#include "../../utils/helpers.h" #include @@ -293,43 +293,73 @@ TEST_CASE("SBRCS", "[conv][genstats][graph]") { cudnnDestroy(handle); } -TEST_CASE("Conv with Int8 datatypes", "[conv][graph][caching]") { +TEST_CASE("CBR Graph NCHW", "[conv][graph][caching]") { namespace fe = cudnn_frontend; - int64_t n = 1, c = 64, h = 32, w = 32, k = 4, r = 3, s = 3; + int64_t n = 8, c = 32, h = 16, w = 16, k = 64, r = 3, s = 3; - bool const include_identity = true; + bool cache_hit = true; - auto build_new_graph = [=](cudnnHandle_t handle) { + using graph_and_tensors = std::tuple, + std::shared_ptr, // X + std::shared_ptr, // W + std::shared_ptr, // Z + std::shared_ptr, // B + std::shared_ptr // Y + >; + + std::unordered_map user_maintained_cache; + + auto lookup_cache_or_build_graph = [n, c, h, w, k, r, s, &cache_hit, &user_maintained_cache](cudnnHandle_t handle) { auto graph = std::make_shared(); - graph->set_io_data_type(fe::DataType_t::INT8) - .set_intermediate_data_type(fe::DataType_t::INT32) - .set_compute_data_type(fe::DataType_t::INT32); + graph->set_io_data_type(fe::DataType_t::HALF) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); auto X = graph->tensor(fe::graph::Tensor_attributes() .set_name("image") .set_dim({n, c, h, w}) - .set_stride({c * h * w, 1, c * w, c})); + .set_stride({c * h * w, h * w, w, 1})); auto W = graph->tensor(fe::graph::Tensor_attributes() .set_name("filter") .set_dim({k, c, r, s}) - .set_stride({c * r * s, 1, c * s, c})); + .set_stride({c * r * s, r * s, s, 1})); auto conv_options = fe::graph::Conv_fprop_attributes().set_padding({1, 1}).set_stride({1, 1}).set_dilation({1, 1}); auto conv_output = graph->conv_fprop(X, W, conv_options); - auto Y = conv_output; - if (include_identity) { - auto identity = fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::IDENTITY); - Y = graph->pointwise(conv_output, conv_output, identity); - } + auto Z = graph->tensor(fe::graph::Tensor_attributes() + .set_name("image") + .set_dim({n, k, h, w}) + .set_stride({k * h * w, h * w, w, 1})); // Should be p,q - Y->set_output(true).set_data_type(fe::DataType_t::INT32); + auto add_options = fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::ADD); + auto add_output = graph->pointwise(conv_output, Z, add_options); + + auto B = graph->tensor( + fe::graph::Tensor_attributes().set_name("bias").set_dim({1, k, 1, 1}).set_stride({k, 1, 1, 1})); + auto bias_options = fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::ADD); + auto bias_output = graph->pointwise(add_output, B, bias_options); + + auto relu_options = fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::RELU_FWD); + auto Y = graph->pointwise(bias_output, 
relu_options); + Y->set_output(true).set_stride({k * h * w, h * w, w, 1}); REQUIRE(graph->validate().is_good()); + auto key = graph->key(); + + auto it = user_maintained_cache.find(key); + + if (it != user_maintained_cache.end()) { + cache_hit = true; + return it->second; + } + + cache_hit = false; + REQUIRE(graph->build_operation_graph(handle).is_good()); REQUIRE(graph->create_execution_plans({fe::HeurMode_t::A}).is_good()); @@ -338,136 +368,41 @@ TEST_CASE("Conv with Int8 datatypes", "[conv][graph][caching]") { REQUIRE(graph->build_plans(handle).is_good()); - return std::make_tuple(graph, X, W, Y); + user_maintained_cache.insert({key, std::make_tuple(graph, X, W, Z, B, Y)}); + + return std::make_tuple(graph, X, W, Z, B, Y); }; cudnnHandle_t handle; - -#if (CUDNN_VERSION < 8600) - SKIP("Conv Int8 requires cudnn 8.6 and up"); -#endif - - if (check_device_arch_newer_than("ampere") == false) { - SKIP("Int8 datatype convolutions require Ampere and later architectures"); - } - checkCudnnErr(cudnnCreate(&handle)); - auto [graph, X, W, Y] = build_new_graph(handle); + auto [graph, X, W, Z, B, Y] = lookup_cache_or_build_graph(handle); - Surface x_tensor(n * c * h * w, false); - Surface w_tensor(k * c * r * s, false); - Surface y_tensor(n * k * h * w, false); // Should be p, q. + REQUIRE(cache_hit == false); - std::unordered_map, void*> variant_pack = { - {X, x_tensor.devPtr}, {W, w_tensor.devPtr}, {Y, y_tensor.devPtr}}; + Surface x_tensor(n * c * h * w, false); + Surface w_tensor(k * c * r * s, false); + Surface b_tensor(k, false); + Surface y_tensor(n * k * h * w, false); // Should be p, q. + Surface z_tensor(n * k * h * w, false); // Should be p, q. Surface workspace(graph->get_workspace_size(), false); - REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); - cudnnDestroy(handle); -} - -TEST_CASE("Convolution fp8 precision", "[matmul][graph]") { - if (cudnnGetCudartVersion() < 12000) { - SKIP("Test requires cuda toolkit 12.0 or above"); - } - if (cudnnGetVersion() < 8600) { - SKIP("TEST REQUIRES minimum cudnn version 8.6.0"); - } - if (check_device_arch_newer_than("hopper") == false) { - SKIP("TEST REQUIRES device hopper arch or newer"); - } - - namespace fe = cudnn_frontend; - // conv problem size - int64_t n = 16, c = 128, h = 64, w = 64, k = 256, r = 1, s = 1; - - // Initialize input tensors with int8_t as proxy for fp8 - auto graph = std::make_shared(); - graph->set_io_data_type(fe::DataType_t::HALF) - .set_intermediate_data_type(fe::DataType_t::FLOAT) - .set_compute_data_type(fe::DataType_t::FLOAT); - - auto X = graph->tensor(fe::graph::Tensor_attributes() - .set_name("image") - .set_dim({n, c, h, w}) - .set_stride({c * h * w, 1, c * w, c}) - .set_data_type(fe::DataType_t::FP8_E4M3)); - - auto W = graph->tensor(fe::graph::Tensor_attributes() - .set_name("filter") - .set_dim({k, c, r, s}) - .set_stride({c * r * s, 1, c * s, c}) - .set_data_type(fe::DataType_t::FP8_E4M3)); - - auto conv_options = fe::graph::Conv_fprop_attributes().set_padding({0, 0}).set_stride({1, 1}).set_dilation({1, 1}); - auto conv_output_fp8 = graph->conv_fprop(X, W, conv_options); - - auto descale_x = graph->tensor(fe::graph::Tensor_attributes() - .set_name("descale_x") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - - auto descale_w = graph->tensor(fe::graph::Tensor_attributes() - .set_name("descale_w") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - - auto scale_y = 
graph->tensor(fe::graph::Tensor_attributes() - .set_name("scale_y") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - - auto scale_options = fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::MUL); - auto after_descale_x = graph->pointwise(conv_output_fp8, descale_x, scale_options); - auto after_descale_w = graph->pointwise(after_descale_x, descale_w, scale_options); - auto Y = graph->pointwise(after_descale_w, scale_y, scale_options); - - Y->set_output(true).set_data_type(fe::DataType_t::FP8_E4M3); - - auto amax = graph->reduction(after_descale_w, - fe::graph::Reduction_attributes() - .set_mode(fe::ReductionMode_t::AMAX) - .set_compute_data_type(fe::DataType_t::FLOAT)); - - amax->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({1, 1, 1, 1}); - - REQUIRE(graph->validate().is_good()); - - cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + std::unordered_map, void*> variant_pack = { + {X, x_tensor.devPtr}, {W, w_tensor.devPtr}, {B, b_tensor.devPtr}, {Z, z_tensor.devPtr}, {Y, y_tensor.devPtr}}; - REQUIRE(graph->build_operation_graph(handle).is_good()); - REQUIRE(graph->create_execution_plans({fe::HeurMode_t::A}).is_good()); + REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); - REQUIRE(graph->check_support(handle).is_good()); + auto [graph_, X_, W_, Z_, B_, Y_] = lookup_cache_or_build_graph(handle); - REQUIRE(graph->build_plans(handle, fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); + std::unordered_map, void*> variant_pack_ = {{X_, x_tensor.devPtr}, + {W_, w_tensor.devPtr}, + {B_, b_tensor.devPtr}, + {Z_, z_tensor.devPtr}, + {Y_, y_tensor.devPtr}}; - // Use int8_t as proxy for fp8 - Surface X_gpu(n * c * h * w, false); - Surface W_gpu(k * c * r * s, false); - Surface Y_gpu(n * k * h * w, false); + REQUIRE(graph_->execute(handle, variant_pack_, workspace.devPtr).is_good()); - Surface X_descale_gpu(1, false); - Surface W_descale_gpu(1, false); - Surface Y_scale_gpu(1, false); - Surface amax_gpu(1, false); + REQUIRE(cache_hit == true); - Surface workspace(graph->get_workspace_size(), false); - std::unordered_map, void*> variant_pack = { - {X, X_gpu.devPtr}, - {W, W_gpu.devPtr}, - {Y, Y_gpu.devPtr}, - {descale_x, X_descale_gpu.devPtr}, - {descale_w, W_descale_gpu.devPtr}, - {scale_y, Y_scale_gpu.devPtr}, - {amax, amax_gpu.devPtr}}; - - std::cout << graph->print() << std::endl; - REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudnnErr(cudnnDestroy(handle)); -} \ No newline at end of file + cudnnDestroy(handle); +} diff --git a/samples/cpp/convolution/int8_fprop.cpp b/samples/cpp/convolution/int8_fprop.cpp new file mode 100644 index 0000000..7586d2f --- /dev/null +++ b/samples/cpp/convolution/int8_fprop.cpp @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include +#include "../../utils/helpers.h" + +#include + +TEST_CASE("Conv with Int8 datatypes", "[conv][graph][caching]") { + namespace fe = cudnn_frontend; + + int64_t n = 1, c = 64, h = 32, w = 32, k = 4, r = 3, s = 3; + + bool const include_identity = true; + + auto build_new_graph = [=](cudnnHandle_t handle) { + auto graph = std::make_shared(); + graph->set_io_data_type(fe::DataType_t::INT8) + .set_intermediate_data_type(fe::DataType_t::INT32) + .set_compute_data_type(fe::DataType_t::INT32); + + auto X = graph->tensor(fe::graph::Tensor_attributes() + .set_name("image") + .set_dim({n, c, h, w}) + .set_stride({c * h * w, 1, c * w, c})); + + auto W = graph->tensor(fe::graph::Tensor_attributes() + .set_name("filter") + .set_dim({k, c, r, s}) + .set_stride({c * r * s, 1, c * s, c})); + + auto conv_options = + fe::graph::Conv_fprop_attributes().set_padding({1, 1}).set_stride({1, 1}).set_dilation({1, 1}); + auto conv_output = graph->conv_fprop(X, W, conv_options); + auto Y = conv_output; + + if (include_identity) { + auto identity = fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::IDENTITY); + Y = graph->pointwise(conv_output, conv_output, identity); + } + + Y->set_output(true).set_data_type(fe::DataType_t::INT32); + + REQUIRE(graph->validate().is_good()); + + REQUIRE(graph->build_operation_graph(handle).is_good()); + + REQUIRE(graph->create_execution_plans({fe::HeurMode_t::A}).is_good()); + + REQUIRE(graph->check_support(handle).is_good()); + + REQUIRE(graph->build_plans(handle).is_good()); + + return std::make_tuple(graph, X, W, Y); + }; + + cudnnHandle_t handle; + +#if (CUDNN_VERSION < 8600) + SKIP("Conv Int8 requires cudnn 8.6 and up"); +#endif + + if (check_device_arch_newer_than("ampere") == false) { + SKIP("Int8 datatype convolutions require Ampere and later architectures"); + } + + checkCudnnErr(cudnnCreate(&handle)); + + auto [graph, X, W, Y] = build_new_graph(handle); + + Surface x_tensor(n * c * h * w, false); + Surface w_tensor(k * c * r * s, false); + Surface y_tensor(n * k * h * w, false); // Should be p, q. + + std::unordered_map, void*> variant_pack = { + {X, x_tensor.devPtr}, {W, w_tensor.devPtr}, {Y, y_tensor.devPtr}}; + + Surface workspace(graph->get_workspace_size(), false); + REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); + cudnnDestroy(handle); +} diff --git a/samples/cpp/wgrads.cpp b/samples/cpp/convolution/wgrads.cpp similarity index 99% rename from samples/cpp/wgrads.cpp rename to samples/cpp/convolution/wgrads.cpp index dfcec45..7aace2b 100644 --- a/samples/cpp/wgrads.cpp +++ b/samples/cpp/convolution/wgrads.cpp @@ -21,7 +21,7 @@ */ #include -#include "../utils/helpers.h" +#include "../../utils/helpers.h" #include diff --git a/samples/cpp/matmul/fp8_matmul.cpp b/samples/cpp/matmul/fp8_matmul.cpp new file mode 100644 index 0000000..9b334c3 --- /dev/null +++ b/samples/cpp/matmul/fp8_matmul.cpp @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include + +#include + +#include "../../utils/helpers.h" + +#include + +TEST_CASE("Matmul fp8 precision", "[matmul][graph]") { + if (cudnnGetCudartVersion() < 12000) { + SKIP("Test requires cuda toolkit 12.0 or above"); + } + + if ((is_hopper_arch() && cudnnGetVersion() >= 90000) == false) { + SKIP("FP8 gemm not supported pre-Hopper or pre-cudnn-9.0.0"); + } + + namespace fe = cudnn_frontend; + // matmul problem size + int64_t const b = 16; + int64_t const m = 32; + int64_t const n = 64; + int64_t const k = 128; + + // Initialize input tensors with int8_t as proxy for fp8 + Surface A_gpu(b * m * k, false); + Surface B_gpu(b * k * n, false); + + Surface A_descale_gpu(1, false); + Surface B_descale_gpu(1, false); + + fe::graph::Graph graph{}; + + // Create the two non-virtual input tensors A and B. + // There are read from global memory. 
+ auto A_attributes = fe::graph::Tensor_attributes() + .set_name("A") + .set_dim({b, m, k}) + .set_stride({m * k, k, 1}) + .set_data_type(fe::DataType_t::FP8_E4M3); + auto A = graph.tensor(A_attributes); + + auto B_attributes = fe::graph::Tensor_attributes() + .set_name("B") + .set_dim({b, k, n}) + .set_stride({k * n, 1, k}) + .set_data_type(fe::DataType_t::FP8_E4M3); + auto B = graph.tensor(B_attributes); + + auto A_descale_attributes = + fe::graph::Tensor_attributes().set_name("A").set_dim({1, 1, 1}).set_stride({1, 1, 1}).set_data_type( + fe::DataType_t::FLOAT); + auto B_descale_attributes = + fe::graph::Tensor_attributes().set_name("B").set_dim({1, 1, 1}).set_stride({1, 1, 1}).set_data_type( + fe::DataType_t::FLOAT); + + auto A_descale = graph.tensor(A_descale_attributes); + auto B_descale = graph.tensor(B_descale_attributes); + + auto matmul_attributes = + // fe::graph::Matmul_attributes().set_name("GEMM").set_compute_data_type(fe::DataType_t::FLOAT); + fe::graph::Matmul_attributes().set_name("GEMM").set_compute_data_type(fe::DataType_t::FLOAT); + auto C = graph.matmul(A, B, matmul_attributes); + C->set_data_type(fe::DataType_t::FLOAT); + + // Add scale_A operation + auto pw_0_attributes = fe::graph::Pointwise_attributes() + // .set_name("pw0_Mul") + .set_mode(fe::PointwiseMode_t::MUL) + .set_compute_data_type(fe::DataType_t::FLOAT); + auto C_after_pw_0 = graph.pointwise(C, A_descale, pw_0_attributes); + C_after_pw_0->set_data_type(fe::DataType_t::FLOAT); + + // Add descale_B operation + auto pw_1_attributes = fe::graph::Pointwise_attributes() + // .set_name("pw1_Mul") + .set_mode(fe::PointwiseMode_t::MUL) + .set_compute_data_type(fe::DataType_t::FLOAT); + auto C_after_pw_1 = graph.pointwise(C_after_pw_0, B_descale, pw_1_attributes); + C_after_pw_1->set_output(true).set_data_type(fe::DataType_t::BFLOAT16); + + std::cout << graph << std::endl; + REQUIRE(graph.validate().is_good()); + + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + REQUIRE(graph.build_operation_graph(handle).is_good()); + REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + + REQUIRE(graph.check_support(handle).is_good()); + + REQUIRE(graph.build_plans(handle, fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); + + Surface C_gpu(b * m * n, false); + Surface workspace(graph.get_workspace_size(), false); + std::unordered_map, void*> variant_pack = { + {A, A_gpu.devPtr}, + {B, B_gpu.devPtr}, + {C_after_pw_1, C_gpu.devPtr}, + {A_descale, A_descale_gpu.devPtr}, + {B_descale, B_descale_gpu.devPtr}}; + + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); + checkCudnnErr(cudnnDestroy(handle)); +} diff --git a/samples/cpp/matmul/int8_matmul.cpp b/samples/cpp/matmul/int8_matmul.cpp new file mode 100644 index 0000000..4c55142 --- /dev/null +++ b/samples/cpp/matmul/int8_matmul.cpp @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include + +#include + +#include "../../utils/helpers.h" + +#include + +TEST_CASE("Int8 Matmul", "[matmul][graph]") { + if (cudnnGetCudartVersion() < 12000) { + SKIP("Test requires cuda toolkit 12.0 or above"); + } + namespace fe = cudnn_frontend; + + // matmul problem size + int64_t const b = 16; + int64_t const m = 32; + int64_t const n = 64; + int64_t const k = 128; + + // Initialize input tensors + Surface A_gpu(b * m * k, false); + // note this is a bf16 tensor, but half is used just for memory allocation + Surface B_gpu(b * k * n, false); + + // Make cudnn graph + fe::graph::Graph graph{}; + + // Create the two non-virtual input tensors A and B. + // There are read from global memory. + auto A_attributes = fe::graph::Tensor_attributes() + .set_name("A") + .set_dim({b, m, k}) + .set_stride({m * k, k, 1}) + .set_data_type(fe::DataType_t::INT8); + auto A = graph.tensor(A_attributes); + auto B_attributes = fe::graph::Tensor_attributes() + .set_name("B") + .set_dim({b, k, n}) + .set_stride({k * n, 1, n}) + .set_data_type(fe::DataType_t::INT8); + auto B = graph.tensor(B_attributes); + + auto Bias_attributes = cudnn_frontend::graph::Tensor_attributes() + .set_name("Bias") + .set_dim({b, m, n}) + .set_data_type(cudnn_frontend::DataType_t::FLOAT) + .set_stride({m * n, n, 1}); + auto Bias = graph.tensor(Bias_attributes); + + // Add MATMUL operation + auto matmul_attributes = cudnn_frontend::graph::Matmul_attributes() + .set_compute_data_type(cudnn_frontend::DataType_t::INT32) + .set_name("GEMM"); + auto C = graph.matmul(A, B, matmul_attributes); + C->set_data_type(cudnn_frontend::DataType_t::FLOAT); + + // Add ADD operation + auto add_attributes = cudnn_frontend::graph::Pointwise_attributes() + .set_name("pw1_add") + .set_mode(cudnn_frontend::PointwiseMode_t::ADD) + .set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); + auto C_after_add = graph.pointwise(C, Bias, add_attributes); + C_after_add->set_output(true).set_data_type(cudnn_frontend::DataType_t::FLOAT); + REQUIRE(graph.validate().is_good()); + + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + REQUIRE(graph.build_operation_graph(handle).is_good()); + REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + + if (check_device_arch_newer_than("ampere") && cudnnGetVersion() >= 8906) { + REQUIRE(graph.check_support(handle).is_good()); + } else { + SKIP("int8 gemm not supported pre-Ampere or pre-cudnn-8.9.6"); + } + + REQUIRE(graph.build_plans(handle, fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); + + // Run cudnn graph + // note this is a bf16 tensor, but half is used just for memory allocation + Surface C_gpu(b * m * n, false); + Surface Bias_gpu(b * m * n, false); + Surface workspace(graph.get_workspace_size(), false); + std::unordered_map, void*> variant_pack = { + {A, A_gpu.devPtr}, {B, B_gpu.devPtr}, {C_after_add, C_gpu.devPtr}, {Bias, Bias_gpu.devPtr}}; + + std::cout << graph.print() << std::endl; + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); + 
checkCudnnErr(cudnnDestroy(handle)); +} \ No newline at end of file diff --git a/samples/cpp/matmuls.cpp b/samples/cpp/matmul/matmuls.cpp similarity index 64% rename from samples/cpp/matmuls.cpp rename to samples/cpp/matmul/matmuls.cpp index ed63c7d..33d4af4 100644 --- a/samples/cpp/matmuls.cpp +++ b/samples/cpp/matmul/matmuls.cpp @@ -24,7 +24,7 @@ #include -#include "../utils/helpers.h" +#include "../../utils/helpers.h" #include @@ -62,11 +62,11 @@ TEST_CASE("Matmul", "[matmul][graph]") { .set_data_type(fe::DataType_t::BFLOAT16); auto B = graph.tensor(B_attributes); - auto matmul_attributes = - fe::graph::Matmul_attributes().set_name("GEMM").set_compute_data_type(fe::DataType_t::FLOAT); - auto C = graph.matmul(A, B, matmul_attributes); + auto matmul_attributes = fe::graph::Matmul_attributes().set_compute_data_type(fe::DataType_t::FLOAT); + auto C = graph.matmul(A, B, matmul_attributes); C->set_output(true).set_data_type(fe::DataType_t::FLOAT); + std::cout << graph << std::endl; REQUIRE(graph.validate().is_good()); cudnnHandle_t handle; @@ -76,7 +76,6 @@ TEST_CASE("Matmul", "[matmul][graph]") { REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); graph.deselect_engines({"eng4_"}); - REQUIRE(graph.check_support(handle).is_good()); REQUIRE(graph.build_plans(handle, fe::BuildPlanPolicy_t::ALL).is_good()); @@ -90,266 +89,6 @@ TEST_CASE("Matmul", "[matmul][graph]") { checkCudnnErr(cudnnDestroy(handle)); } -TEST_CASE("Matmul fp8 precision", "[matmul][graph]") { - if (cudnnGetCudartVersion() < 12000) { - SKIP("Test requires cuda toolkit 12.0 or above"); - } - - if ((is_hopper_arch() && cudnnGetVersion() >= 90000) == false) { - SKIP("FP8 gemm not supported pre-Hopper or pre-cudnn-9.0.0"); - } - - namespace fe = cudnn_frontend; - // matmul problem size - int64_t const b = 16; - int64_t const m = 32; - int64_t const n = 64; - int64_t const k = 128; - - // Initialize input tensors with int8_t as proxy for fp8 - Surface A_gpu(b * m * k, false); - Surface B_gpu(b * k * n, false); - - Surface A_descale_gpu(1, false); - Surface B_descale_gpu(1, false); - - fe::graph::Graph graph{}; - - // Create the two non-virtual input tensors A and B. - // There are read from global memory. 
- auto A_attributes = fe::graph::Tensor_attributes() - .set_name("A") - .set_dim({b, m, k}) - .set_stride({m * k, k, 1}) - .set_data_type(fe::DataType_t::FP8_E4M3); - auto A = graph.tensor(A_attributes); - - auto B_attributes = fe::graph::Tensor_attributes() - .set_name("B") - .set_dim({b, k, n}) - .set_stride({k * n, 1, k}) - .set_data_type(fe::DataType_t::FP8_E4M3); - auto B = graph.tensor(B_attributes); - - auto A_descale_attributes = - fe::graph::Tensor_attributes().set_name("A").set_dim({1, 1, 1}).set_stride({1, 1, 1}).set_data_type( - fe::DataType_t::FLOAT); - auto B_descale_attributes = - fe::graph::Tensor_attributes().set_name("B").set_dim({1, 1, 1}).set_stride({1, 1, 1}).set_data_type( - fe::DataType_t::FLOAT); - - auto A_descale = graph.tensor(A_descale_attributes); - auto B_descale = graph.tensor(B_descale_attributes); - - auto matmul_attributes = - fe::graph::Matmul_attributes().set_name("GEMM").set_compute_data_type(fe::DataType_t::FLOAT); - auto C = graph.matmul(A, B, matmul_attributes); - C->set_data_type(fe::DataType_t::FLOAT); - - // Add scale_A operation - auto pw_0_attributes = fe::graph::Pointwise_attributes() - .set_name("pw0_Mul") - .set_mode(fe::PointwiseMode_t::MUL) - .set_compute_data_type(fe::DataType_t::FLOAT); - auto C_after_pw_0 = graph.pointwise(C, A_descale, pw_0_attributes); - C_after_pw_0->set_data_type(fe::DataType_t::FLOAT); - - // Add descale_B operation - auto pw_1_attributes = fe::graph::Pointwise_attributes() - .set_name("pw1_Mul") - .set_mode(fe::PointwiseMode_t::MUL) - .set_compute_data_type(fe::DataType_t::FLOAT); - auto C_after_pw_1 = graph.pointwise(C_after_pw_0, B_descale, pw_1_attributes); - C_after_pw_1->set_output(true).set_data_type(fe::DataType_t::BFLOAT16); - - REQUIRE(graph.validate().is_good()); - - cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); - - REQUIRE(graph.build_operation_graph(handle).is_good()); - REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); - - REQUIRE(graph.check_support(handle).is_good()); - - REQUIRE(graph.build_plans(handle, fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); - - Surface C_gpu(b * m * n, false); - Surface workspace(graph.get_workspace_size(), false); - std::unordered_map, void*> variant_pack = { - {A, A_gpu.devPtr}, - {B, B_gpu.devPtr}, - {C_after_pw_1, C_gpu.devPtr}, - {A_descale, A_descale_gpu.devPtr}, - {B_descale, B_descale_gpu.devPtr}}; - - std::cout << graph.print() << std::endl; - REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudnnErr(cudnnDestroy(handle)); -} - -TEST_CASE("Mixed Precision Matmul", "[matmul][graph]") { - if (cudnnGetCudartVersion() < 12000) { - SKIP("Test requires cuda toolkit 12.0 or above"); - } - namespace fe = cudnn_frontend; - - // matmul problem size - int64_t const b = 16; - int64_t const m = 32; - int64_t const n = 64; - int64_t const k = 128; - - // Initialize input tensors - Surface A_gpu(b * m * k, false); - // note this is a bf16 tensor, but half is used just for memory allocation - Surface B_gpu(b * k * n, false); - - // Make cudnn graph - fe::graph::Graph graph{}; - - // Create the two non-virtual input tensors A and B. - // There are read from global memory. 
- auto A_attributes = fe::graph::Tensor_attributes() - .set_name("A") - .set_dim({b, m, k}) - .set_stride({m * k, k, 1}) - .set_data_type(fe::DataType_t::INT8); - auto A = graph.tensor(A_attributes); - auto B_attributes = fe::graph::Tensor_attributes() - .set_name("B") - .set_dim({b, k, n}) - .set_stride({k * n, n, 1}) - .set_data_type(fe::DataType_t::BFLOAT16); - auto B = graph.tensor(B_attributes); - - // Cast the input tensors to required mma precision - auto identity_attributes = fe::graph::Pointwise_attributes() - .set_name("Cast_A") - .set_mode(fe::PointwiseMode_t::IDENTITY) - // INT8->FLOAT->BF16 to maintain precision - .set_compute_data_type(fe::DataType_t::FLOAT); - auto A_casted = graph.pointwise(A, identity_attributes); - A_casted->set_data_type(fe::DataType_t::BFLOAT16); - - auto matmul_attributes = - fe::graph::Matmul_attributes().set_name("GEMM").set_compute_data_type(fe::DataType_t::FLOAT); - auto C = graph.matmul(A_casted, B, matmul_attributes); - C->set_output(true).set_data_type(fe::DataType_t::BFLOAT16); - - REQUIRE(graph.validate().is_good()); - - cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); - - REQUIRE(graph.build_operation_graph(handle).is_good()); - REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); - - if (is_hopper_arch() && cudnnGetVersion() >= 8906) { - REQUIRE(graph.check_support(handle).is_good()); - } else { - SKIP("int8_bf16 mixed precision gemm not supported pre-Hopper or pre-cudnn-8.9.6"); - } - - REQUIRE(graph.build_plans(handle, fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); - - //// Run cudnn graph - // note this is a bf16 tensor, but half is used just for memory allocation - Surface C_gpu(b * m * n, false); - Surface workspace(graph.get_workspace_size(), false); - std::unordered_map, void*> variant_pack = { - {A, A_gpu.devPtr}, {B, B_gpu.devPtr}, {C, C_gpu.devPtr}}; - - std::cout << graph.print() << std::endl; - REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudnnErr(cudnnDestroy(handle)); -} - -TEST_CASE("Int8 Matmul", "[matmul][graph]") { - if (cudnnGetCudartVersion() < 12000) { - SKIP("Test requires cuda toolkit 12.0 or above"); - } - namespace fe = cudnn_frontend; - - // matmul problem size - int64_t const b = 16; - int64_t const m = 32; - int64_t const n = 64; - int64_t const k = 128; - - // Initialize input tensors - Surface A_gpu(b * m * k, false); - // note this is a bf16 tensor, but half is used just for memory allocation - Surface B_gpu(b * k * n, false); - - // Make cudnn graph - fe::graph::Graph graph{}; - - // Create the two non-virtual input tensors A and B. - // There are read from global memory. 
- auto A_attributes = fe::graph::Tensor_attributes() - .set_name("A") - .set_dim({b, m, k}) - .set_stride({m * k, k, 1}) - .set_data_type(fe::DataType_t::INT8); - auto A = graph.tensor(A_attributes); - auto B_attributes = fe::graph::Tensor_attributes() - .set_name("B") - .set_dim({b, k, n}) - .set_stride({k * n, 1, n}) - .set_data_type(fe::DataType_t::INT8); - auto B = graph.tensor(B_attributes); - - auto Bias_attributes = cudnn_frontend::graph::Tensor_attributes() - .set_name("Bias") - .set_dim({b, m, n}) - .set_data_type(cudnn_frontend::DataType_t::FLOAT) - .set_stride({m * n, n, 1}); - auto Bias = graph.tensor(Bias_attributes); - - // Add MATMUL operation - auto matmul_attributes = cudnn_frontend::graph::Matmul_attributes() - .set_compute_data_type(cudnn_frontend::DataType_t::INT32) - .set_name("GEMM"); - auto C = graph.matmul(A, B, matmul_attributes); - C->set_data_type(cudnn_frontend::DataType_t::FLOAT); - - // Add ADD operation - auto add_attributes = cudnn_frontend::graph::Pointwise_attributes() - .set_name("pw1_add") - .set_mode(cudnn_frontend::PointwiseMode_t::ADD) - .set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); - auto C_after_add = graph.pointwise(C, Bias, add_attributes); - C_after_add->set_output(true).set_data_type(cudnn_frontend::DataType_t::FLOAT); - REQUIRE(graph.validate().is_good()); - - cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); - - REQUIRE(graph.build_operation_graph(handle).is_good()); - REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); - - if (check_device_arch_newer_than("ampere") && cudnnGetVersion() >= 8906) { - REQUIRE(graph.check_support(handle).is_good()); - } else { - SKIP("int8 gemm not supported pre-Ampere or pre-cudnn-8.9.6"); - } - - REQUIRE(graph.build_plans(handle, fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); - - // Run cudnn graph - // note this is a bf16 tensor, but half is used just for memory allocation - Surface C_gpu(b * m * n, false); - Surface Bias_gpu(b * m * n, false); - Surface workspace(graph.get_workspace_size(), false); - std::unordered_map, void*> variant_pack = { - {A, A_gpu.devPtr}, {B, B_gpu.devPtr}, {C_after_add, C_gpu.devPtr}, {Bias, Bias_gpu.devPtr}}; - - std::cout << graph.print() << std::endl; - REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudnnErr(cudnnDestroy(handle)); -} - TEST_CASE("Abs + Matmul", "[matmul][graph]") { namespace fe = cudnn_frontend; @@ -635,3 +374,64 @@ TEST_CASE("Matmul SBR Graph", "[matmul][graph]") { REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); cudnnDestroy(handle); } + +TEST_CASE("Matmul with restricted shared memory", "[matmul][graph]") { + if (is_arch_supported_by_cudnn() == false) { + SKIP("Architecture is not supported by currend cudnn version"); + } + namespace fe = cudnn_frontend; + + // matmul problem size + int64_t const b = 1; + int64_t const m = 32; + int64_t const n = 64; + int64_t const k = 32; + + // Initialize input tensors + Surface A_gpu(b * m * k, false); + Surface B_gpu(b * k * n, false); + + // Make cudnn graph + fe::graph::Graph graph{}; + + // Create the two non-virtual input tensors A and B. + // There are read from global memory. 
+ auto A_attributes = fe::graph::Tensor_attributes() + .set_name("A") + .set_dim({b, m, k}) + .set_stride({m * k, k, 1}) + .set_data_type(fe::DataType_t::BFLOAT16); + auto A = graph.tensor(A_attributes); + auto B_attributes = fe::graph::Tensor_attributes() + .set_name("B") + .set_dim({b, k, n}) + .set_stride({k * n, n, 1}) + .set_data_type(fe::DataType_t::BFLOAT16); + auto B = graph.tensor(B_attributes); + + auto matmul_attributes = fe::graph::Matmul_attributes().set_compute_data_type(fe::DataType_t::FLOAT); + auto C = graph.matmul(A, B, matmul_attributes); + C->set_output(true).set_data_type(fe::DataType_t::FLOAT); + + std::cout << graph << std::endl; + REQUIRE(graph.validate().is_good()); + + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + REQUIRE(graph.build_operation_graph(handle).is_good()); + REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + + graph.deselect_shared_mem_greater_than(256 * 1024); + REQUIRE(graph.check_support(handle).is_good()); + + REQUIRE(graph.build_plans(handle, fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); + + // Run cudnn graph + Surface C_gpu(b * m * n, false); + Surface workspace(graph.get_workspace_size(), false); + std::unordered_map, void*> variant_pack = { + {A, A_gpu.devPtr}, {B, B_gpu.devPtr}, {C, C_gpu.devPtr}}; + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); + checkCudnnErr(cudnnDestroy(handle)); +} \ No newline at end of file diff --git a/samples/cpp/matmul/mixed_matmul.cpp b/samples/cpp/matmul/mixed_matmul.cpp new file mode 100644 index 0000000..6a72b67 --- /dev/null +++ b/samples/cpp/matmul/mixed_matmul.cpp @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include + +#include + +#include "../../utils/helpers.h" + +#include + +TEST_CASE("Mixed Precision Matmul", "[matmul][graph]") { + if (cudnnGetCudartVersion() < 12000) { + SKIP("Test requires cuda toolkit 12.0 or above"); + } + namespace fe = cudnn_frontend; + + // matmul problem size + int64_t const b = 16; + int64_t const m = 32; + int64_t const n = 64; + int64_t const k = 128; + + // Initialize input tensors + Surface A_gpu(b * m * k, false); + // note this is a bf16 tensor, but half is used just for memory allocation + Surface B_gpu(b * k * n, false); + + // Make cudnn graph + fe::graph::Graph graph{}; + + // Create the two non-virtual input tensors A and B. + // There are read from global memory. 
+ auto A_attributes = fe::graph::Tensor_attributes() + .set_name("A") + .set_dim({b, m, k}) + .set_stride({m * k, k, 1}) + .set_data_type(fe::DataType_t::INT8); + auto A = graph.tensor(A_attributes); + auto B_attributes = fe::graph::Tensor_attributes() + .set_name("B") + .set_dim({b, k, n}) + .set_stride({k * n, n, 1}) + .set_data_type(fe::DataType_t::BFLOAT16); + auto B = graph.tensor(B_attributes); + + // Cast the input tensors to required mma precision + auto identity_attributes = fe::graph::Pointwise_attributes() + .set_name("Cast_A") + .set_mode(fe::PointwiseMode_t::IDENTITY) + // INT8->FLOAT->BF16 to maintain precision + .set_compute_data_type(fe::DataType_t::FLOAT); + auto A_casted = graph.pointwise(A, identity_attributes); + A_casted->set_data_type(fe::DataType_t::BFLOAT16); + + auto matmul_attributes = + fe::graph::Matmul_attributes().set_name("GEMM").set_compute_data_type(fe::DataType_t::FLOAT); + auto C = graph.matmul(A_casted, B, matmul_attributes); + C->set_output(true).set_data_type(fe::DataType_t::BFLOAT16); + + REQUIRE(graph.validate().is_good()); + + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + REQUIRE(graph.build_operation_graph(handle).is_good()); + REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + + if (is_hopper_arch() && cudnnGetVersion() >= 8906) { + REQUIRE(graph.check_support(handle).is_good()); + } else { + SKIP("int8_bf16 mixed precision gemm not supported pre-Hopper or pre-cudnn-8.9.6"); + } + + REQUIRE(graph.build_plans(handle, fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); + + //// Run cudnn graph + // note this is a bf16 tensor, but half is used just for memory allocation + Surface C_gpu(b * m * n, false); + Surface workspace(graph.get_workspace_size(), false); + std::unordered_map, void*> variant_pack = { + {A, A_gpu.devPtr}, {B, B_gpu.devPtr}, {C, C_gpu.devPtr}}; + + std::cout << graph.print() << std::endl; + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); + checkCudnnErr(cudnnDestroy(handle)); +} diff --git a/samples/cpp/mha.cpp b/samples/cpp/mha.cpp deleted file mode 100644 index fd8cc33..0000000 --- a/samples/cpp/mha.cpp +++ /dev/null @@ -1,1057 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a - * copy of this software and associated documentation files (the "Software"), - * to deal in the Software without restriction, including without limitation - * the rights to use, copy, modify, merge, publish, distribute, sublicense, - * and/or sell copies of the Software, and to permit persons to whom the - * Software is furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL - * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER - * DEALINGS IN THE SOFTWARE. 
- */ - -#include -#include "../utils/helpers.h" - -#include -#include - -namespace fe = cudnn_frontend; - -using graph_and_tensors = std::tuple, - std::shared_ptr, // Q, - std::shared_ptr, // K, - std::shared_ptr, // V, - std::shared_ptr, // Attn_scale, - std::shared_ptr, // Bias, - std::shared_ptr, // SEQ_LEN_Q, - std::shared_ptr, // SEQ_LEN_KV, - std::shared_ptr, // Seed, - std::shared_ptr, // Offset, - std::shared_ptr, // Dropout_mask, - std::shared_ptr, // Dropout_scale - std::shared_ptr, // O - std::shared_ptr // Stats - >; - -using cache_type = std::unordered_map; - -template -auto -lookup_cache_or_build_graph(cudnnHandle_t handle, cache_type& user_maintained_cache, Args... args) { - auto [b, - h, - s_q, - s_kv, - d, - is_inference, - is_attn_scale, - causal_mask, - padding_mask, - alibi_mask, - has_bias, - use_dropout_with_rng, - dropout_probability, - seq_len_override, - use_dropout_mask] = std::make_tuple(args...); - - (void)use_dropout_mask; - - auto graph = std::make_shared(); - graph->set_io_data_type(fe::DataType_t::HALF) - .set_intermediate_data_type(fe::DataType_t::FLOAT) - .set_compute_data_type(fe::DataType_t::FLOAT); - - auto Q = graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h, s_q, d}) - .set_stride({3 * h * d, 3 * d, 3 * b * h * d, 1})); - auto K = graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, h, s_kv, d}) - .set_stride({3 * h * d, 3 * d, 3 * b * h * d, 1})); - auto V = graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, h, s_kv, d}) - .set_stride({3 * h * d, 3 * d, 3 * b * h * d, 1})); - - auto attn_scale = is_attn_scale ? graph->tensor(fe::graph::Tensor_attributes() - .set_name("attn_scale") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_is_pass_by_value(true) - .set_data_type(fe::DataType_t::FLOAT)) - : nullptr; - - auto sdpa_options = fe::graph::SDPA_attributes().set_name("flash_attention").set_is_inference(is_inference); - - if (is_attn_scale) { - sdpa_options.set_attn_scale(attn_scale); - }; - - sdpa_options.set_alibi_mask(alibi_mask); - sdpa_options.set_causal_mask(causal_mask); - - auto seed = use_dropout_with_rng ? graph->tensor(fe::graph::Tensor_attributes() - .set_name("Seed") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)) - : nullptr; - - auto offset = use_dropout_with_rng ? graph->tensor(fe::graph::Tensor_attributes() - .set_name("Offset") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)) - : nullptr; - - if (use_dropout_with_rng) { - sdpa_options.set_dropout(dropout_probability, seed, offset); - } - - auto bias = graph->tensor(fe::graph::Tensor_attributes() - .set_name("bias") - .set_dim({b, 1, s_q, s_kv}) - .set_stride({s_q * s_kv, s_q * s_kv, s_kv, 1})); - - if (has_bias) { - sdpa_options.set_bias(bias); - } - - auto seq_q = seq_len_override ? graph->tensor(fe::graph::Tensor_attributes() - .set_name("seq_q") - .set_dim({b, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)) - : nullptr; - auto seq_kv = seq_len_override ? 
graph->tensor(fe::graph::Tensor_attributes() - .set_name("seq_kv") - .set_dim({b, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)) - : nullptr; - - if (padding_mask) { - sdpa_options.set_padding_mask(true); - } - if (seq_len_override) { - sdpa_options.set_seq_len_q(seq_q).set_seq_len_kv(seq_kv); - } - - auto [O, stats] = graph->sdpa(Q, K, V, sdpa_options); - - O->set_output(true).set_dim({b, h, s_q, d}).set_stride({h * d, d, b * h * d, 1}); - - // Check that Stats tensor is real, which is only when its training step - if (is_inference) { - REQUIRE(stats == nullptr); - } else { - stats->set_output(true).set_data_type(fe::DataType_t::FLOAT); - } - - REQUIRE(graph->validate().is_good()); - - auto key = graph->key(); - - auto it = user_maintained_cache.find(key); - - if (it != user_maintained_cache.end()) { - return it->second; - } - - REQUIRE(graph->build_operation_graph(handle).is_good()); - - auto plans = graph->create_execution_plans({fe::HeurMode_t::A}); - - REQUIRE(graph->check_support(handle).is_good()); - - REQUIRE(graph->build_plans(handle).is_good()); - - std::shared_ptr dropout_mask = nullptr; - std::shared_ptr dropout_scale = nullptr; - - user_maintained_cache.insert( - {key, - std::make_tuple( - graph, Q, K, V, attn_scale, bias, seq_q, seq_kv, seed, offset, dropout_mask, dropout_scale, O, stats)}); - - return std::make_tuple( - graph, Q, K, V, attn_scale, bias, seq_q, seq_kv, seed, offset, dropout_mask, dropout_scale, O, stats); -} - -TEST_CASE("Flash with rng dropout", "[graph][mha][flash][forward]") { - if (cudnnGetCudartVersion() < 12000) { - SKIP("Test requires cuda toolkit 12.0 or above"); - return; - } - - if (cudnnGetVersion() < 8901) { - SKIP("Test requires cuDNN version 8.9.1 or above"); - return; - } - - if (check_device_arch_newer_than("ampere") == false) { - SKIP("Test requires Hopper or above arch."); - return; - } - - int64_t b = 3; // batch size - int64_t h = 4; // head dim - int64_t s_q = 1024; // q tensor is padded to this seq length - int64_t s_kv = 1024; // k and v tensor is padded to this seq length - int64_t d = 128; // hidden dim - bool is_inference = false; - float dropout_probability = 0.1f; - - namespace fe = cudnn_frontend; - fe::graph::Graph mha_graph; - - bool is_attn_scale = true; - bool causal_mask = true; - bool padding_mask = (cudnnGetVersion() >= 8903); - bool alibi_mask = (cudnnGetVersion() >= 8904); - bool use_dropout_with_rng = true; - bool has_bias = (cudnnGetVersion() >= 8903); - bool seq_len_override = padding_mask; - - bool use_dropout_mask = false; - cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); - - cache_type user_maintained_cache; - auto [graph, Q, K, V, attn_scale, bias, seq_q, seq_kv, seed, offset, dropout_mask, dropout_scale, O, stats] = - lookup_cache_or_build_graph(handle, - user_maintained_cache, - b, - h, - s_q, - s_kv, - d, - is_inference, - is_attn_scale, - causal_mask, - padding_mask, - alibi_mask, - has_bias, - use_dropout_with_rng, - dropout_probability, - seq_len_override, - use_dropout_mask); - - (void)dropout_mask; - (void)dropout_scale; - - //// Build variant pack - Surface qkvTensor(b * s_q * 3 * h * d, false); - Surface oTensor(b * s_q * h * d, false); - void* devPtrQ = qkvTensor.devPtr; - void* devPtrK = (qkvTensor.devPtr + d); - void* devPtrV = (qkvTensor.devPtr + 2 * d); - void* devPtrO = oTensor.devPtr; - - float attn_scale_cpu = 0.5f; - - Surface bTensor(b * 1 * s_q * s_kv, false); - - int32_t scaleSize = 1; - int32_t seed_value = 123456; - Surface dropoutSeed(scaleSize, 
false, seed_value); - Surface dropoutOffset(scaleSize, false, (int32_t)1); - - std::unordered_map, void*> variant_pack = { - {Q, devPtrQ}, - {K, devPtrK}, - {V, devPtrV}, - {attn_scale, &attn_scale_cpu}, - {bias, bTensor.devPtr}, - {seed, dropoutSeed.devPtr}, - {offset, dropoutOffset.devPtr}, - {O, devPtrO}}; - - if (seq_len_override) { - Surface devActualSeqlenQ(b, false); - Surface devActualSeqlenKV(b, false); - std::vector hostActualSeqlenQ(b, 20); - std::vector hostActualSeqlenKV(b, 20); - - checkCudaErr(cudaMemcpy(devActualSeqlenQ.devPtr, - hostActualSeqlenQ.data(), - sizeof(hostActualSeqlenQ[0]) * b, - cudaMemcpyHostToDevice)); - checkCudaErr(cudaMemcpy(devActualSeqlenKV.devPtr, - hostActualSeqlenKV.data(), - sizeof(hostActualSeqlenKV[0]) * b, - cudaMemcpyHostToDevice)); - checkCudaErr(cudaDeviceSynchronize()); - - variant_pack[seq_q] = devActualSeqlenQ.devPtr; - variant_pack[seq_kv] = devActualSeqlenKV.devPtr; - } - - Surface statsTensor(b * h * s_q * 1, false); - if (is_inference == false) { - variant_pack[stats] = statsTensor.devPtr; - } - - Surface workspace(graph->get_workspace_size(), false); - REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); - - checkCudaErr(cudaDeviceSynchronize()); - - cudnnDestroy(handle); -} - -TEST_CASE("Flash with no dropout", "[graph][mha][flash][forward]") { - if (cudnnGetCudartVersion() < 12000) { - SKIP("Test requires cuda toolkit 12.0 or above"); - return; - } - - if (cudnnGetVersion() < 8903) { - SKIP("Test requires cuDNN version 8.9.3 or above"); - return; - } - - if (check_device_arch_newer_than("ampere") == false) { - SKIP("Test requires Hopper or above arch."); - return; - } - - int64_t b = 3; // batch size - int64_t h = 4; // head dim - int64_t s_q = 1024; // q tensor is padded to this seq length - int64_t s_kv = 1024; // k and v tensor is padded to this seq length - int64_t d = 128; // hidden dim - bool is_inference = false; - - bool is_attn_scale = true; - bool causal_mask = true; - bool padding_mask = false; - bool alibi_mask = (cudnnGetVersion() >= 8904); - bool use_dropout_with_rng = false; - bool has_bias = (cudnnGetVersion() >= 8903); - bool seq_len_override = false; - - bool use_dropout_mask = false; - cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); - - cache_type user_maintained_cache; - auto [graph, Q, K, V, attn_scale, bias, seq_q, seq_kv, seed, offset, dropout_mask, dropout_scale, O, stats] = - lookup_cache_or_build_graph(handle, - user_maintained_cache, - b, - h, - s_q, - s_kv, - d, - is_inference, - is_attn_scale, - causal_mask, - padding_mask, - alibi_mask, - has_bias, - use_dropout_with_rng, - 0.0f, - seq_len_override, - use_dropout_mask); - - (void)seq_q; - (void)seq_kv; - (void)seed; - (void)offset; - (void)dropout_mask; - (void)dropout_scale; - - //// Build variant pack - Surface qkvTensor(b * s_q * 3 * h * d, false); - Surface oTensor(b * s_q * h * d, false); - void* devPtrQ = qkvTensor.devPtr; - void* devPtrK = (qkvTensor.devPtr + d); - void* devPtrV = (qkvTensor.devPtr + 2 * d); - void* devPtrO = oTensor.devPtr; - - float attn_scale_cpu = 0.5f; - - Surface bTensor(b * 1 * s_q * s_kv, false); - - std::unordered_map, void*> variant_pack = { - {Q, devPtrQ}, {K, devPtrK}, {V, devPtrV}, {attn_scale, &attn_scale_cpu}, {bias, bTensor.devPtr}, {O, devPtrO}}; - - Surface statsTensor(b * h * s_q * 1, false); - if (is_inference == false) { - variant_pack[stats] = statsTensor.devPtr; - } - - Surface workspace(graph->get_workspace_size(), false); - REQUIRE(graph->execute(handle, variant_pack, 
workspace.devPtr).is_good()); - - checkCudaErr(cudaDeviceSynchronize()); - - cudnnDestroy(handle); -} - -TEST_CASE("Flash backward", "[graph][mha][flash][backward]") { - if (cudnnGetCudartVersion() < 12000) { - SKIP("Test requires cuda toolkit 12.0 or above"); - return; - } - if (cudnnGetVersion() < 8903) { - SKIP("Test requires cuDNN version 8.9.3 or above"); - return; - } - - if (check_device_arch_newer_than("ampere") == false) { - SKIP("Test requires Hopper or above arch."); - return; - } - - int64_t b = 3; // batch size - int64_t h = 4; // head dim - int64_t s_q = 1024; // q tensor is padded to this seq length - int64_t s_kv = 1024; // k and v tensor is padded to this seq length - int64_t d = 128; // hidden dim - - bool is_bias = true; - float dropout_probability = 0.2f; - - namespace fe = cudnn_frontend; - fe::graph::Graph mha_graph; - mha_graph.set_io_data_type(fe::DataType_t::HALF) - .set_intermediate_data_type(fe::DataType_t::FLOAT) - .set_compute_data_type(fe::DataType_t::FLOAT); - - // used for bias, and dropout != 0.0f - std::shared_ptr bias, dropout_seed, dropout_offset; - - auto q = mha_graph.tensor( - fe::graph::Tensor_attributes().set_name("Q").set_dim({b, h, s_q, d}).set_stride({h * s_q * d, s_q * d, d, 1})); - auto k = mha_graph.tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, h, s_kv, d}) - .set_stride({h * s_kv * d, s_kv * d, d, 1})); - auto v = mha_graph.tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, h, s_kv, d}) - .set_stride({h * s_kv * d, s_kv * d, d, 1})); - auto o = mha_graph.tensor( - fe::graph::Tensor_attributes().set_name("O").set_dim({b, h, s_q, d}).set_stride({h * s_q * d, s_q * d, d, 1})); - auto dO = mha_graph.tensor( - fe::graph::Tensor_attributes().set_name("dO").set_dim({b, h, s_q, d}).set_stride({h * s_q * d, s_q * d, d, 1})); - auto stats = mha_graph.tensor(fe::graph::Tensor_attributes() - .set_name("stats") - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, s_q, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - - auto attn_scale = mha_graph.tensor(fe::graph::Tensor_attributes() - .set_name("attn_scale") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_is_pass_by_value(true) - .set_data_type(fe::DataType_t::FLOAT)); - - if (is_bias) { - bias = mha_graph.tensor(fe::graph::Tensor_attributes() - .set_name("bias") - .set_dim({b, 1, s_q, s_kv}) - .set_stride({s_q * s_kv, s_q * s_kv, s_kv, 1})); - } - - if (dropout_probability != 0.0f) { - dropout_seed = mha_graph.tensor(fe::graph::Tensor_attributes() - .set_name("Seed") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - dropout_offset = mha_graph.tensor(fe::graph::Tensor_attributes() - .set_name("Offset") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - } - - auto sdpa_backward_options = fe::graph::SDPA_backward_attributes() - .set_name("flash_attention_backward") - .set_causal_mask(true) - .set_attn_scale(attn_scale); - - if (is_bias) { - sdpa_backward_options.set_bias(bias); - } - - if (dropout_probability != 0.0f) { - sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); - } - - auto [dQ, dK, dV] = mha_graph.sdpa_backward(q, k, v, o, dO, stats, sdpa_backward_options); - - dQ->set_output(true).set_dim({b, h, s_q, d}).set_stride({h * s_q * d, s_q * d, d, 1}); - dK->set_output(true).set_dim({b, h, s_kv, d}).set_stride({h * s_kv * d, s_kv * d, d, 1}); - dV->set_output(true).set_dim({b, h, s_kv, d}).set_stride({h * s_kv * d, s_kv * 
d, d, 1}); - - cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); - - REQUIRE(mha_graph.validate().is_good()); - - REQUIRE(mha_graph.build_operation_graph(handle).is_good()); - - auto plans = mha_graph.create_execution_plans({fe::HeurMode_t::A}); - - REQUIRE(mha_graph.check_support(handle).is_good()); - - REQUIRE(mha_graph.build_plans(handle).is_good()); - - // build variant pack - // inputs - Surface q_tensor(b * h * s_q * d, false); - Surface k_tensor(b * h * d * s_kv, false); - Surface v_tensor(b * h * d * s_kv, false); - Surface o_tensor(b * h * s_q * d, false); - Surface dO_tensor(b * h * s_q * d, false); - Surface stats_tensor(b * h * s_q * 1, false); - // outputs - Surface dQ_tensor(b * h * s_q * d, false); - Surface dK_tensor(b * h * s_kv * d, false); - Surface dV_tensor(b * h * s_kv * d, false); - - float attn_scale_cpu = 0.5f; - - Surface bias_tensor(b * 1 * s_q * s_kv, false); - - int32_t seed_value = 123456; - int32_t offset_value = 789; - Surface dropout_seed_tensor(1, false, seed_value); - Surface dropout_offset_tensor(1, false, offset_value); - - std::unordered_map, void*> variant_pack = { - // inputs - {q, q_tensor.devPtr}, - {k, k_tensor.devPtr}, - {v, v_tensor.devPtr}, - {o, o_tensor.devPtr}, - {dO, dO_tensor.devPtr}, - {stats, stats_tensor.devPtr}, - // outputs - {dQ, dQ_tensor.devPtr}, - {dK, dK_tensor.devPtr}, - {dV, dV_tensor.devPtr}, - // pass by value - {attn_scale, &attn_scale_cpu}}; - - if (is_bias) { - variant_pack[bias] = bias_tensor.devPtr; - } - - if (dropout_probability != 0.0f) { - variant_pack[dropout_seed] = dropout_seed_tensor.devPtr; - variant_pack[dropout_offset] = dropout_offset_tensor.devPtr; - } - - Surface workspace(mha_graph.get_workspace_size(), false); - REQUIRE(mha_graph.execute(handle, variant_pack, workspace.devPtr).is_good()); - - checkCudaErr(cudaDeviceSynchronize()); - - cudnnDestroy(handle); -} - -TEST_CASE("sdpa_fp8_fprop", "[graph][mha][fp8][forward]") { - namespace fe = cudnn_frontend; - -#if CUDART_VERSION < 12000 - SKIP("Test requires cuda toolkit 12.0 or above"); - return; -#endif - - int64_t b = 2; // batch size - int64_t h = 2; // head dim - int64_t s = 512; // q,k,v tensor is padded to this seq length - int64_t d = 128; // hidden dim - - bool is_inference = false; - - fe::graph::Graph mha_graph; - mha_graph.set_io_data_type(fe::DataType_t::FP8_E4M3) - .set_intermediate_data_type(fe::DataType_t::FLOAT) - .set_compute_data_type(fe::DataType_t::FLOAT); - - auto Q_dQ_O_dO_dims = std::vector({b, h, s, d}); - - auto QKV_strides = std::vector({s * 3 * h * d, d, 3 * h * d, 1}); // bs3hd - auto O_dO_strides = std::vector({s * h * d, d, h * d, 1}); // bhsd - - auto Q = - mha_graph.tensor(fe::graph::Tensor_attributes().set_name("Q").set_dim(Q_dQ_O_dO_dims).set_stride(QKV_strides)); - auto K = - mha_graph.tensor(fe::graph::Tensor_attributes().set_name("K").set_dim(Q_dQ_O_dO_dims).set_stride(QKV_strides)); - auto V = - mha_graph.tensor(fe::graph::Tensor_attributes().set_name("V").set_dim(Q_dQ_O_dO_dims).set_stride(QKV_strides)); - - float attn_scale = 0.123f; - - auto descale_q = mha_graph.tensor(fe::graph::Tensor_attributes() - .set_name("Descale_Q") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - auto descale_k = mha_graph.tensor_like(descale_q, "Descale_K"); - auto descale_v = mha_graph.tensor_like(descale_q, "Descale_V"); - auto descale_s = mha_graph.tensor_like(descale_q, "Descale_S"); - auto scale_s = mha_graph.tensor_like(descale_q, "Scale_S"); - auto scale_o = 
mha_graph.tensor_like(descale_q, "Scale_O"); - - auto sdpa_fp8_options = fe::graph::SDPA_fp8_attributes() - .set_name("sdpa_fp8") - .set_is_inference(is_inference) - .set_causal_mask(true) - .set_attn_scale(attn_scale); - - auto [O, Stats, Amax_S, Amax_O] = - mha_graph.sdpa_fp8(Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, sdpa_fp8_options); - - O->set_output(true).set_dim(Q_dQ_O_dO_dims).set_stride(O_dO_strides); - Amax_O->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - Amax_S->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - - // Check that Stats tensor is real, which is only when its training step - if (is_inference) { - REQUIRE(Stats == nullptr); - } else { - Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT); - } - - cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); - - auto status = mha_graph.validate(); - if ((cudnnGetVersion() >= 90100) && check_device_arch_newer_than("hopper")) { - REQUIRE(status.is_good()); - } else { - REQUIRE(status.get_code() == fe::error_code_t::GRAPH_NOT_SUPPORTED); - cudnnDestroy(handle); - return; - } - - REQUIRE(mha_graph.build_operation_graph(handle).is_good()); - auto plans = mha_graph.create_execution_plans({fe::HeurMode_t::A}); - REQUIRE(mha_graph.check_support(handle).is_good()); - REQUIRE(mha_graph.build_plans(handle).is_good()); - - //// Build variant pack - Surface qkvTensor(b * s * 3 * h * d, false); - Surface oTensor(b * s * h * d, false); - void* devPtrQ = qkvTensor.devPtr; - void* devPtrK = (qkvTensor.devPtr + h * d); - void* devPtrV = (qkvTensor.devPtr + 2 * h * d); - void* devPtrO = oTensor.devPtr; - - Surface descale_Q_Tensor(1, false); - Surface descale_K_Tensor(1, false); - Surface descale_V_Tensor(1, false); - Surface descale_S_Tensor(1, false); - Surface scale_S_Tensor(1, false); - Surface scale_O_Tensor(1, false); - Surface Amax_S_Tensor(1, false); - Surface Amax_O_Tensor(1, false); - - std::unordered_map, void*> variant_pack = { - {Q, devPtrQ}, - {K, devPtrK}, - {V, devPtrV}, - {O, devPtrO}, - {descale_q, descale_Q_Tensor.devPtr}, - {descale_k, descale_K_Tensor.devPtr}, - {descale_v, descale_V_Tensor.devPtr}, - {descale_s, descale_S_Tensor.devPtr}, - {scale_s, scale_S_Tensor.devPtr}, - {scale_o, scale_O_Tensor.devPtr}, - {Amax_S, Amax_S_Tensor.devPtr}, - {Amax_O, Amax_O_Tensor.devPtr}}; - - Surface stats_tensor(b * h * s * 1, false); - if (is_inference == false) { - variant_pack[Stats] = stats_tensor.devPtr; - } - - Surface workspace(mha_graph.get_workspace_size(), false); - REQUIRE(mha_graph.execute(handle, variant_pack, workspace.devPtr).is_good()); - - checkCudaErr(cudaDeviceSynchronize()); - - cudnnDestroy(handle); -} - -TEST_CASE("sdpa_fp8_bprop", "[graph][mha][fp8][backward]") { - namespace fe = cudnn_frontend; - -#if CUDART_VERSION < 12000 - SKIP("Test requires cuda toolkit 12.0 or above"); - return; -#endif - - int64_t b = 2; // batch size - int64_t h = 2; // head dim - int64_t s = 512; // q,k,v tensor is padded to this seq length - int64_t d = 128; // hidden dim - - // bs3hd - auto Q_dQ_O_dO_dims = std::vector({b, h, s, d}); - // QKV_strides - auto Q_dQ_strides = std::vector({s * 3 * h * d, d, 3 * h * d, 1}); // bs3hd - - auto Q_K_V_dQ_dK_dV_bulk_strides = std::vector({s * 3 * h * d, 3 * h * d, h * d, d, 1}); - - auto O_dO_strides = std::vector({s * h * d, d, h * d, 1}); // bshd - - auto K_V_dK_dV_dims{Q_dQ_O_dO_dims}; - auto K_V_dK_dV_strides{Q_dQ_strides}; - - auto MZ_OdO_dims = std::vector({b, h, s, 1}); - auto 
MZ_OdO_strides = std::vector({h * s, s, 1, 1}); - - fe::graph::Graph mha_graph; - mha_graph.set_io_data_type(fe::DataType_t::FP8_E4M3) - .set_intermediate_data_type(fe::DataType_t::FLOAT) - .set_compute_data_type(fe::DataType_t::FLOAT); - - auto Q = mha_graph.tensor( - fe::graph::Tensor_attributes().set_name("Q").set_dim(K_V_dK_dV_dims).set_stride(K_V_dK_dV_strides)); - auto K = mha_graph.tensor( - fe::graph::Tensor_attributes().set_name("K").set_dim(K_V_dK_dV_dims).set_stride(K_V_dK_dV_strides)); - auto V = mha_graph.tensor( - fe::graph::Tensor_attributes().set_name("V").set_dim(K_V_dK_dV_dims).set_stride(K_V_dK_dV_strides)); - auto O = - mha_graph.tensor(fe::graph::Tensor_attributes().set_name("O").set_dim(Q_dQ_O_dO_dims).set_stride(O_dO_strides)); - auto dO = mha_graph.tensor( - fe::graph::Tensor_attributes().set_name("dO").set_dim(Q_dQ_O_dO_dims).set_stride(O_dO_strides)); - auto Stats = mha_graph.tensor(fe::graph::Tensor_attributes() - .set_name("Stats") - .set_dim(MZ_OdO_dims) - .set_stride(MZ_OdO_strides) - .set_data_type(fe::DataType_t::FLOAT)); - - float attn_scale = 0.123f; - - auto descale_q = mha_graph.tensor(fe::graph::Tensor_attributes() - .set_name("Descale_Q") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - auto descale_k = mha_graph.tensor_like(descale_q, "Descale_K"); - auto descale_v = mha_graph.tensor_like(descale_q, "Descale_V"); - auto descale_s = mha_graph.tensor_like(descale_q, "Descale_S"); - auto descale_o = mha_graph.tensor_like(descale_q, "Descale_O"); - auto descale_dO = mha_graph.tensor_like(descale_q, "Descale_dO"); - auto descale_dP = mha_graph.tensor_like(descale_q, "Descale_dP"); - - auto scale_s = mha_graph.tensor_like(descale_q, "Scale_S"); - auto scale_dP = mha_graph.tensor_like(descale_q, "Scale_dP"); - auto scale_dQ = mha_graph.tensor_like(descale_q, "Scale_dQ"); - auto scale_dK = mha_graph.tensor_like(descale_q, "Scale_dK"); - auto scale_dV = mha_graph.tensor_like(descale_q, "Scale_dV"); - - // options/attributes - auto sdpa_fp8_backwards_options = fe::graph::SDPA_fp8_backward_attributes() - .set_name("sdpa_fp8_backward") - .set_causal_mask(true) - .set_attn_scale(attn_scale); - - // output - auto [dQ, dK, dV, Amax_dQ, Amax_dK, Amax_dV, Amax_dP] = mha_graph.sdpa_fp8_backward(Q, - K, - V, - O, - dO, - Stats, - descale_q, - descale_k, - descale_v, - descale_o, - descale_dO, - descale_s, - descale_dP, - scale_s, - scale_dQ, - scale_dK, - scale_dV, - scale_dP, - sdpa_fp8_backwards_options); - - dQ->set_output(true).set_dim(Q_dQ_O_dO_dims).set_stride(O_dO_strides); - dK->set_output(true).set_dim(Q_dQ_O_dO_dims).set_stride(O_dO_strides); - dV->set_output(true).set_dim(Q_dQ_O_dO_dims).set_stride(O_dO_strides); - Amax_dQ->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - Amax_dK->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - Amax_dV->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - Amax_dP->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - - cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); - - auto status = mha_graph.validate(); - if ((cudnnGetVersion() >= 90100) && check_device_arch_newer_than("hopper")) { - REQUIRE(status.is_good()); - } else { - REQUIRE(status.get_code() == fe::error_code_t::GRAPH_NOT_SUPPORTED); - cudnnDestroy(handle); - return; - } - - REQUIRE(mha_graph.build_operation_graph(handle).is_good()); - 
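// The explicit pipeline used in this test (validate, build_operation_graph, create_execution_plans,
// check_support, build_plans) is collapsed by the newer fp16 samples in this patch into a single
// composite call. A minimal sketch of that short form, assuming the graph is held in a
// std::shared_ptr<fe::graph::Graph> named graph:
//
//     REQUIRE(graph->build(handle, {fe::HeurMode_t::A}).is_good());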
REQUIRE(mha_graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); - REQUIRE(mha_graph.check_support(handle).is_good()); - REQUIRE(mha_graph.build_plans(handle).is_good()); - - // Surfaces - auto Q_K_V_dQ_dK_dV_bulk_dims{b * s * 3 * h * d}; - auto dO_O_dims{b * s * h * d}; - Surface qkvTensor{Q_K_V_dQ_dK_dV_bulk_dims, false}; - void* devPtrQ{qkvTensor.devPtr}; - void* devPtrK{qkvTensor.devPtr + h * d}; - void* devPtrV{qkvTensor.devPtr + 2 * h * d}; - - Surface dQdKdVTensor{Q_K_V_dQ_dK_dV_bulk_dims, false}; - void* devPtrdQ{dQdKdVTensor.devPtr}; - void* devPtrdK{dQdKdVTensor.devPtr + h * d}; - void* devPtrdV{dQdKdVTensor.devPtr + 2 * h * d}; - - Surface dOTensor{dO_O_dims, false}; - Surface OTensor{dO_O_dims, false}; - - Surface descale_Q_Tensor{1, false}; - Surface descale_K_Tensor{1, false}; - Surface descale_V_Tensor{1, false}; - Surface descale_S_Tensor{1, false}; - Surface descale_dP_Tensor{1, false}; - Surface descale_dO_Tensor{1, false}; - Surface descale_O_Tensor{1, false}; - - Surface scale_S_Tensor{1, false}; - Surface scale_dQ_Tensor{1, false}; - Surface scale_dK_Tensor{1, false}; - Surface scale_dV_Tensor{1, false}; - Surface scale_dP_Tensor{1, false}; - - Surface AMax_dQ_Tensor{1, false}; - Surface AMax_dK_Tensor{1, false}; - Surface AMax_dV_Tensor{1, false}; - Surface AMax_dP_Tensor{1, false}; - - Surface StatsTensor(b * h * s * 1, false); - - // Variant pack - std::unordered_map, void*> variant_pack{ - {Q, devPtrQ}, - {K, devPtrK}, - {V, devPtrV}, - {O, OTensor.devPtr}, - {dO, dOTensor.devPtr}, - {dQ, devPtrdQ}, - {dK, devPtrdK}, - {dV, devPtrdV}, - {descale_q, descale_Q_Tensor.devPtr}, - {descale_k, descale_K_Tensor.devPtr}, - {descale_v, descale_V_Tensor.devPtr}, - {descale_o, descale_O_Tensor.devPtr}, - {descale_dO, descale_dO_Tensor.devPtr}, - {descale_s, descale_S_Tensor.devPtr}, - {descale_dP, descale_dP_Tensor.devPtr}, - {scale_s, scale_S_Tensor.devPtr}, - {scale_dQ, scale_dQ_Tensor.devPtr}, - {scale_dK, scale_dK_Tensor.devPtr}, - {scale_dV, scale_dV_Tensor.devPtr}, - {scale_dP, scale_dP_Tensor.devPtr}, - {Stats, StatsTensor.devPtr}, - {Amax_dQ, AMax_dQ_Tensor.devPtr}, - {Amax_dK, AMax_dK_Tensor.devPtr}, - {Amax_dV, AMax_dV_Tensor.devPtr}, - {Amax_dP, AMax_dP_Tensor.devPtr}}; - - Surface workspace(mha_graph.get_workspace_size(), false); - REQUIRE(mha_graph.execute(handle, variant_pack, workspace.devPtr).is_good()); - - checkCudaErr(cudaDeviceSynchronize()); - - cudnnDestroy(handle); -} - -TEST_CASE("sdpa_fp8_gqa_bprop", "[graph][mha][fp8][backward]") { - namespace fe = cudnn_frontend; - -#if CUDART_VERSION < 12000 - SKIP("Test requires cuda toolkit 12.0 or above"); - return; -#endif - - int64_t b = 2; // batch size - int64_t h_qo = 12; // query/output head dim - int64_t h_kv = 4; // key/value head dim - int64_t s = 512; // q,k,v tensor is padded to this seq length - int64_t d = 128; // hidden dim - - // construct graph - std::vector qo_dim = {b, h_qo, s, d}; - std::vector kv_dim = {b, h_kv, s, d}; - std::vector qo_stride = {s * h_qo * d, d, h_qo * d, 1}; // bshd - std::vector kv_stride = {s * h_kv * d, d, h_kv * d, 1}; // bshd - - std::vector stats_dim = {b, h_qo, s, 1}; - std::vector stats_stride = {h_qo * s, s, 1, 1}; - - fe::graph::Graph mha_graph; - mha_graph.set_io_data_type(fe::DataType_t::FP8_E4M3) - .set_intermediate_data_type(fe::DataType_t::FLOAT) - .set_compute_data_type(fe::DataType_t::FLOAT); - - auto q = mha_graph.tensor(fe::graph::Tensor_attributes().set_name("Q").set_dim(qo_dim).set_stride(qo_stride)); - auto k = 
mha_graph.tensor(fe::graph::Tensor_attributes().set_name("K").set_dim(kv_dim).set_stride(kv_stride)); - auto v = mha_graph.tensor(fe::graph::Tensor_attributes().set_name("V").set_dim(kv_dim).set_stride(kv_stride)); - auto o = mha_graph.tensor(fe::graph::Tensor_attributes().set_name("O").set_dim(qo_dim).set_stride(qo_stride)); - auto dO = mha_graph.tensor(fe::graph::Tensor_attributes().set_name("dO").set_dim(qo_dim).set_stride(qo_stride)); - auto stats = mha_graph.tensor(fe::graph::Tensor_attributes() - .set_name("Stats") - .set_dim(stats_dim) - .set_stride(stats_stride) - .set_data_type(fe::DataType_t::FLOAT)); - - float attn_scale = 0.125f; - - auto descale_q = mha_graph.tensor(fe::graph::Tensor_attributes() - .set_name("Descale_Q") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - auto descale_k = mha_graph.tensor_like(descale_q, "Descale_K"); - auto descale_v = mha_graph.tensor_like(descale_q, "Descale_V"); - auto descale_s = mha_graph.tensor_like(descale_q, "Descale_S"); - auto descale_o = mha_graph.tensor_like(descale_q, "Descale_O"); - auto descale_dO = mha_graph.tensor_like(descale_q, "Descale_dO"); - auto descale_dP = mha_graph.tensor_like(descale_q, "Descale_dP"); - - auto scale_s = mha_graph.tensor_like(descale_q, "Scale_S"); - auto scale_dP = mha_graph.tensor_like(descale_q, "Scale_dP"); - auto scale_dQ = mha_graph.tensor_like(descale_q, "Scale_dQ"); - auto scale_dK = mha_graph.tensor_like(descale_q, "Scale_dK"); - auto scale_dV = mha_graph.tensor_like(descale_q, "Scale_dV"); - - // clang-format off - auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph.sdpa_fp8_backward( - q, k, v, o, dO, stats, - descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, - scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, - fe::graph::SDPA_fp8_backward_attributes().set_name("sdpa_fp8_backward") - .set_causal_mask(true) - .set_attn_scale(attn_scale) - ); - // clang-format on - - dQ->set_output(true).set_dim(qo_dim).set_stride(qo_stride); - dK->set_output(true).set_dim(kv_dim).set_stride(kv_stride); - dV->set_output(true).set_dim(kv_dim).set_stride(kv_stride); - amax_dQ->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - amax_dK->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - amax_dV->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - amax_dP->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - - cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); - - auto status = mha_graph.validate(); - if ((cudnnGetVersion() >= 90100) && check_device_arch_newer_than("hopper")) { - REQUIRE(status.is_good()); - } else { - REQUIRE(status.get_code() == fe::error_code_t::GRAPH_NOT_SUPPORTED); - cudnnDestroy(handle); - return; - } - - REQUIRE(mha_graph.build_operation_graph(handle).is_good()); - REQUIRE(mha_graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); - REQUIRE(mha_graph.check_support(handle).is_good()); - REQUIRE(mha_graph.build_plans(handle).is_good()); - - // Surfaces that alllocate GPU memory - Surface q_gpu(b * s * h_qo * d, false); - Surface k_gpu(b * s * h_kv * d, false); - Surface v_gpu(b * s * h_kv * d, false); - Surface o_gpu(b * s * h_qo * d, false); - - Surface stats_gpu(b * h_qo * s * 1, false); - - Surface dQ_gpu(b * s * h_qo * d, false); - Surface dK_gpu(b * s * h_kv * d, false); - Surface dV_gpu(b * s * h_kv * d, false); - Surface dO_gpu(b * s * h_qo * d, false); - - Surface 
descale_q_gpu(1, false); - Surface descale_k_gpu(1, false); - Surface descale_v_gpu(1, false); - Surface descale_o_gpu(1, false); - Surface descale_s_gpu(1, false); - Surface descale_dP_gpu(1, false); - Surface descale_dO_gpu(1, false); - - Surface scale_s_gpu(1, false); - Surface scale_dQ_gpu(1, false); - Surface scale_dK_gpu(1, false); - Surface scale_dV_gpu(1, false); - Surface scale_dP_gpu(1, false); - - Surface amax_dQ_gpu(1, false); - Surface amax_dK_gpu(1, false); - Surface amax_dV_gpu(1, false); - Surface amax_dP_gpu(1, false); - - // Variant pack - std::unordered_map, void*> variant_pack{ - {q, q_gpu.devPtr}, - {k, k_gpu.devPtr}, - {v, v_gpu.devPtr}, - {o, o_gpu.devPtr}, - - {dQ, dQ_gpu.devPtr}, - {dK, dK_gpu.devPtr}, - {dV, dV_gpu.devPtr}, - {dO, dO_gpu.devPtr}, - - {stats, stats_gpu.devPtr}, - - {descale_q, descale_q_gpu.devPtr}, - {descale_k, descale_k_gpu.devPtr}, - {descale_v, descale_v_gpu.devPtr}, - {descale_o, descale_o_gpu.devPtr}, - {descale_s, descale_s_gpu.devPtr}, - {descale_dP, descale_dP_gpu.devPtr}, - {descale_dO, descale_dO_gpu.devPtr}, - - {scale_s, scale_s_gpu.devPtr}, - {scale_dQ, scale_dQ_gpu.devPtr}, - {scale_dK, scale_dK_gpu.devPtr}, - {scale_dV, scale_dV_gpu.devPtr}, - {scale_dP, scale_dP_gpu.devPtr}, - - {amax_dQ, amax_dQ_gpu.devPtr}, - {amax_dK, amax_dK_gpu.devPtr}, - {amax_dV, amax_dV_gpu.devPtr}, - {amax_dP, amax_dP_gpu.devPtr}}; - - Surface workspace(mha_graph.get_workspace_size(), false); - REQUIRE(mha_graph.execute(handle, variant_pack, workspace.devPtr).is_good()); - - checkCudaErr(cudaDeviceSynchronize()); - - cudnnDestroy(handle); -} \ No newline at end of file diff --git a/samples/cpp/autotuning.cpp b/samples/cpp/misc/autotuning.cpp similarity index 95% rename from samples/cpp/autotuning.cpp rename to samples/cpp/misc/autotuning.cpp index 32a91f1..4e52e11 100644 --- a/samples/cpp/autotuning.cpp +++ b/samples/cpp/misc/autotuning.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -21,7 +21,7 @@ */ #include -#include "../utils/helpers.h" +#include "../../utils/helpers.h" #include @@ -77,8 +77,12 @@ TEST_CASE("Matmul autotuning", "[matmul][graph][autotuning]") { REQUIRE(graph.build_operation_graph(handle).is_good()); + graph.deselect_workspace_greater_than(0); + REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + graph.deselect_workspace_greater_than(1024 * 1024); + REQUIRE(graph.check_support(handle).is_good()); return graph; @@ -86,13 +90,12 @@ TEST_CASE("Matmul autotuning", "[matmul][graph][autotuning]") { auto graph = create_graph(); - graph.deselect_workspace_greater_than(0); auto plan_count = graph.get_execution_plan_count(); std::cout << "Graph has " << plan_count << " plan candidates." << std::endl; REQUIRE(graph.build_plans(handle, fe::BuildPlanPolicy_t::ALL).is_good()); - std::unordered_map variant_pack = { + std::unordered_map variant_pack = { {a_uid, A_gpu.devPtr}, {b_uid, B_gpu.devPtr}, {c_uid, C_gpu.devPtr}}; auto autotune = [&]() -> int64_t { diff --git a/samples/cpp/misc/parallel_compilation.cpp b/samples/cpp/misc/parallel_compilation.cpp new file mode 100644 index 0000000..99dc410 --- /dev/null +++ b/samples/cpp/misc/parallel_compilation.cpp @@ -0,0 +1,152 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include +#include +#include +#include + +#include "../../utils/helpers.h" + +#include + +#include +#include +#include +#include + +TEST_CASE("Parallel build", "[matmul][graph][parallel]") { + SKIP( + "Very long test turned off by default. Run /bin/samples --benchmark-samples 1 \"Parallel build\" after " + "uncommenting this line."); + if (is_arch_supported_by_cudnn() == false) { + SKIP("Architecture is not supported by current cudnn version"); + } + namespace fe = cudnn_frontend; + + // matmul problem size + int64_t const b = 16; + int64_t const m = 32; + int64_t const n = 64; + int64_t const k = 128; + + // Initialize input tensors + Surface A_gpu(b * m * k, false); + Surface B_gpu(b * k * n, false); + Surface C_gpu(b * m * n, false); + + int64_t a_uid = 0, b_uid = 1, c_uid = 2; + + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + auto create_graph = [&]() -> fe::graph::Graph { + // Make cudnn graph + fe::graph::Graph graph{}; + + // Create the two non-virtual input tensors A and B. + // These are read from global memory.
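// Non-virtual tensors are the graph's externally visible inputs and outputs: whenever the graph is
// executed, the caller supplies a device pointer for each of them in the variant pack, keyed either
// by the returned tensor handle or, thanks to the UIDs set below, by plain int64_t UID, e.g.
// {{a_uid, A_gpu.devPtr}, {b_uid, B_gpu.devPtr}, {c_uid, C_gpu.devPtr}}.
// Virtual (intermediate) tensors never appear in the variant pack; cudnn manages their storage.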
+ auto A_attributes = fe::graph::Tensor_attributes() + .set_name("A") + .set_dim({b, m, k}) + .set_stride({m * k, k, 1}) + .set_uid(a_uid) + .set_data_type(fe::DataType_t::BFLOAT16); + auto A = graph.tensor(A_attributes); + auto B_attributes = fe::graph::Tensor_attributes() + .set_name("B") + .set_dim({b, k, n}) + .set_stride({k * n, n, 1}) + .set_uid(b_uid) + .set_data_type(fe::DataType_t::BFLOAT16); + auto B = graph.tensor(B_attributes); + + auto matmul_attributes = + fe::graph::Matmul_attributes().set_name("GEMM").set_compute_data_type(fe::DataType_t::FLOAT); + auto C = graph.matmul(A, B, matmul_attributes); + C->set_output(true).set_uid(c_uid).set_data_type(fe::DataType_t::BFLOAT16); + + REQUIRE(graph.validate().is_good()); + + REQUIRE(graph.build_operation_graph(handle).is_good()); + + REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + + graph.select_behavior_notes({fe::BehaviorNote_t::RUNTIME_COMPILATION}); + + REQUIRE(graph.check_support(handle).is_good()); + + return graph; + }; + + auto build = [](fe::graph::Graph &graph, cudnnHandle_t handle, int index) { + auto status = graph.build_plan_at_index(handle, index); + }; + + BENCHMARK("BuildPlanPolicy_t::HEURISTICS_CHOICE") { + fe::graph::Graph graph = create_graph(); + return graph.build_plans(handle, fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good(); + }; + + BENCHMARK("BuildPlanPolicy_t::ALL") { + fe::graph::Graph graph = create_graph(); + return graph.build_plans(handle, fe::BuildPlanPolicy_t::ALL).is_good(); + }; + + BENCHMARK("build_plan_at_index::ALL::serial") { + fe::graph::Graph graph = create_graph(); + auto plan_count = graph.get_execution_plan_count(); + for (auto i = 0; i < plan_count; i++) { + build(graph, handle, i); + } + }; + + BENCHMARK("build_plan_at_index::ALL::parallel") { + fe::graph::Graph graph = create_graph(); + auto plan_count = graph.get_execution_plan_count(); + std::vector builders; + for (auto i = 0; i < plan_count; i++) { + builders.emplace_back(std::thread{build, std::reference_wrapper(graph), handle, i}); + } + for (auto &builder : builders) { + builder.join(); + } + }; + + { + auto input = GENERATE(range(2, 11)); + + BENCHMARK("build_plan_at_index::ALL::parallel_" + std::to_string(input)) { + fe::graph::Graph graph = create_graph(); + auto plan_count = input < graph.get_execution_plan_count() ? 
input : graph.get_execution_plan_count(); + std::vector builders; + for (auto i = 0; i < plan_count; i++) { + builders.emplace_back(std::thread{build, std::reference_wrapper(graph), handle, i}); + } + for (auto &builder : builders) { + builder.join(); + } + }; + } + + checkCudnnErr(cudnnDestroy(handle)); +} \ No newline at end of file diff --git a/samples/cpp/pointwise.cpp b/samples/cpp/misc/pointwise.cpp similarity index 99% rename from samples/cpp/pointwise.cpp rename to samples/cpp/misc/pointwise.cpp index 7ecbac1..b3b1e05 100644 --- a/samples/cpp/pointwise.cpp +++ b/samples/cpp/misc/pointwise.cpp @@ -21,7 +21,7 @@ */ #include -#include "../utils/helpers.h" +#include "../../utils/helpers.h" #include diff --git a/samples/cpp/resample.cpp b/samples/cpp/misc/resample.cpp similarity index 99% rename from samples/cpp/resample.cpp rename to samples/cpp/misc/resample.cpp index 538b2d6..a13f065 100644 --- a/samples/cpp/resample.cpp +++ b/samples/cpp/misc/resample.cpp @@ -21,7 +21,7 @@ */ #include -#include "../utils/helpers.h" +#include "../../utils/helpers.h" #include diff --git a/samples/cpp/serialization.cpp b/samples/cpp/misc/serialization.cpp similarity index 99% rename from samples/cpp/serialization.cpp rename to samples/cpp/misc/serialization.cpp index 3265138..267a6ed 100644 --- a/samples/cpp/serialization.cpp +++ b/samples/cpp/misc/serialization.cpp @@ -21,7 +21,7 @@ */ #include -#include "../utils/helpers.h" +#include "../../utils/helpers.h" #include diff --git a/samples/cpp/batchnorm.cpp b/samples/cpp/norm/batchnorm.cpp similarity index 99% rename from samples/cpp/batchnorm.cpp rename to samples/cpp/norm/batchnorm.cpp index 9bb7e3d..e0fcc1c 100644 --- a/samples/cpp/batchnorm.cpp +++ b/samples/cpp/norm/batchnorm.cpp @@ -21,7 +21,7 @@ */ #include -#include "../utils/helpers.h" +#include "../../utils/helpers.h" #include diff --git a/samples/cpp/layernorm.cpp b/samples/cpp/norm/layernorm.cpp similarity index 99% rename from samples/cpp/layernorm.cpp rename to samples/cpp/norm/layernorm.cpp index 3087aaa..2cd5adf 100644 --- a/samples/cpp/layernorm.cpp +++ b/samples/cpp/norm/layernorm.cpp @@ -21,7 +21,7 @@ */ #include -#include "../utils/helpers.h" +#include "../../utils/helpers.h" #include diff --git a/samples/cpp/rmsnorm.cpp b/samples/cpp/norm/rmsnorm.cpp similarity index 99% rename from samples/cpp/rmsnorm.cpp rename to samples/cpp/norm/rmsnorm.cpp index 68d8001..3b8ef52 100644 --- a/samples/cpp/rmsnorm.cpp +++ b/samples/cpp/norm/rmsnorm.cpp @@ -21,7 +21,7 @@ */ #include -#include "../utils/helpers.h" +#include "../../utils/helpers.h" #include diff --git a/samples/cpp/sdpa/fp16_bwd.cpp b/samples/cpp/sdpa/fp16_bwd.cpp new file mode 100644 index 0000000..857595d --- /dev/null +++ b/samples/cpp/sdpa/fp16_bwd.cpp @@ -0,0 +1,274 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include +#include "../../utils/helpers.h" + +#include + +#include +namespace fe = cudnn_frontend; + +/* +Run this example by using command: +bin/samples "Toy sdpa backward" + +This example shows how to construct a sdpa backward graph-> +*/ + +// Tensors in backward pass +#define Q_UID 1 +#define K_UID 2 +#define V_UID 3 +#define O_UID 4 +#define STATS_UID 5 +#define BIAS_UID 6 +#define SEQ_LEN_Q_UID 7 +#define SEQ_LEN_KV_UID 8 + +#define DO_UID 101 +#define DQ_UID 102 +#define DK_UID 103 +#define DV_UID 104 + +// Function to create the SDPA (Scaled Dot-Product Attention) backward graph +std::shared_ptr +create_sdpa_backward_graph(int64_t const b, + int64_t const h_q, + int64_t const h_k, + int64_t const h_v, + int64_t const s_q, + int64_t const s_kv, + int64_t const d_qk, + int64_t const d_v, + float const attn_scale = 1.0f, + [[maybe_unused]] bool const is_inference = false, + bool const causal_mask = false, + bool const alibi_mask = false, + bool const padding_mask = false, + bool has_attn_bias = false) { + // Create a graph and set common global properties + auto graph = std::make_shared(); + graph->set_io_data_type(fe::DataType_t::BFLOAT16) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + // Define input tensors Q, K, V + auto Q = graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_uid(Q_UID) + .set_dim({b, h_q, s_q, d_qk}) + .set_stride({h_q * s_q * d_qk, s_q * d_qk, d_qk, 1})); + + auto K = graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_uid(K_UID) + .set_dim({b, h_k, s_kv, d_qk}) + .set_stride({h_k * s_kv * d_qk, s_kv * d_qk, d_qk, 1})); + + auto V = graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_uid(V_UID) + .set_dim({b, h_v, s_kv, d_v}) + .set_stride({h_v * s_kv * d_v, s_kv * d_v, d_v, 1})); + + // Define output tensor O + auto O = graph->tensor(fe::graph::Tensor_attributes() + .set_name("O") + .set_uid(O_UID) + .set_dim({b, h_q, s_q, d_v}) + .set_stride({h_q * s_q * d_v, s_q * d_v, d_v, 1})); + + // Define gradient tensor dO + auto dO = graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO") + .set_uid(DO_UID) + .set_dim({b, h_q, s_q, d_v}) + .set_stride({h_q * s_q * d_v, s_q * d_v, d_v, 1})); + + // Define stats tensor + auto Stats = graph->tensor(fe::graph::Tensor_attributes() + .set_name("Stats") + .set_uid(STATS_UID) + .set_dim({b, h_q, s_q, 1}) + .set_stride({h_q * s_q, s_q, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + + // Set SDPA backward options + auto sdpa_options = fe::graph::SDPA_backward_attributes() + .set_name("flash_attention_backward") + .set_alibi_mask(alibi_mask) + .set_causal_mask(causal_mask) + .set_attn_scale(attn_scale); + + // If attention bias is provided, set it + if (has_attn_bias) { + auto bias = graph->tensor(fe::graph::Tensor_attributes() + .set_name("bias") + .set_uid(BIAS_UID) + .set_dim({b, 1, s_q, s_kv}) + .set_stride({s_q * s_kv, s_q * s_kv, s_kv, 1})); + sdpa_options.set_bias(bias); + } + + // If padding mask is enabled, set 
sequence lengths + if (padding_mask) { + auto seq_q = graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_q") + .set_uid(SEQ_LEN_Q_UID) + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto seq_kv = graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_kv") + .set_uid(SEQ_LEN_KV_UID) + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + sdpa_options.set_padding_mask(padding_mask).set_seq_len_q(seq_q).set_seq_len_kv(seq_kv); + } + + // Compute SDPA backward and get gradients dQ, dK, dV + auto [dQ, dK, dV] = graph->sdpa_backward(Q, K, V, O, dO, Stats, sdpa_options); + + // Set output tensors dQ, dK, dV + dQ->set_output(true) + .set_uid(DQ_UID) + .set_dim({b, h_q, s_q, d_qk}) + .set_stride({h_q * s_q * d_qk, s_q * d_qk, d_qk, 1}); + dK->set_output(true) + .set_uid(DK_UID) + .set_dim({b, h_k, s_kv, d_qk}) + .set_stride({h_k * s_kv * d_qk, s_kv * d_qk, d_qk, 1}); + dV->set_output(true) + .set_uid(DV_UID) + .set_dim({b, h_v, s_kv, d_v}) + .set_stride({h_v * s_kv * d_v, s_kv * d_v, d_v, 1}); + + return graph; +} + +// Test case for the SDPA backward graph +TEST_CASE("Toy sdpa backward", "[graph][sdpa][flash][backward]") { + int64_t b = 3; // batch size + int64_t h_q = 4; // head dim + int64_t h_k = 4; // head dim + int64_t h_v = 4; // head dim + int64_t s_q = 1024; // q tensor is padded to this seq length + int64_t s_kv = 1024; // k and v tensor is padded to this seq length + int64_t d_qk = 128; // hidden dim + int64_t d_v = 128; // hidden dim + bool is_inference = false; + float attn_scale = 0.123f; + bool causal_mask = true; + bool padding_mask = (cudnnGetVersion() >= 8903); + bool alibi_mask = (cudnnGetVersion() >= 8904); + bool has_attn_bias = (cudnnGetVersion() >= 8903); + + if (cudnnGetVersion() < 8903) { + SKIP("Test requires cudnn 8.9.3 or above"); + return; + } + + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + // Create the SDPA backward graph + auto graph = create_sdpa_backward_graph(b, + h_q, + h_k, + h_v, + s_q, + s_kv, + d_qk, + d_v, + attn_scale, + is_inference, + causal_mask, + alibi_mask, + padding_mask, + has_attn_bias); + + REQUIRE(graph->build(handle, {fe::HeurMode_t::A}).is_good()); + + //// Build variant pack + // inputs + Surface q_tensor(b * h_q * s_q * d_qk, false); + Surface k_tensor(b * h_k * d_qk * s_kv, false); + Surface v_tensor(b * h_v * d_v * s_kv, false); + Surface o_tensor(b * h_q * s_q * d_v, false); + Surface dO_tensor(b * h_q * s_q * d_v, false); + Surface stats_tensor(b * h_q * s_q * 1, false); + // outputs + Surface dQ_tensor(b * h_q * s_q * d_qk, false); + Surface dK_tensor(b * h_k * s_kv * d_qk, false); + Surface dV_tensor(b * h_v * s_kv * d_v, false); + + Surface bias_tensor(b * 1 * s_q * s_kv, false); + + // Create variant pack with input and output tensors + std::unordered_map variant_pack = {// inputs + {Q_UID, q_tensor.devPtr}, + {K_UID, k_tensor.devPtr}, + {V_UID, v_tensor.devPtr}, + {O_UID, o_tensor.devPtr}, + {DO_UID, dO_tensor.devPtr}, + {STATS_UID, stats_tensor.devPtr}, + // outputs + {DQ_UID, dQ_tensor.devPtr}, + {DK_UID, dK_tensor.devPtr}, + {DV_UID, dV_tensor.devPtr}}; + + // If attention bias is provided, add it to the variant pack + if (has_attn_bias) { + variant_pack[BIAS_UID] = bias_tensor.devPtr; + } + + // If padding mask is enabled, add sequence lengths to the variant pack + Surface devActualSeqlenQ(b, false); + Surface devActualSeqlenKV(b, false); + if (padding_mask) { + std::vector hostActualSeqlenQ(b, 20); + 
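// Every sequence in the batch is assigned an actual length of 20 here, while Q/K/V stay padded to
// s_q / s_kv (1024). With the padding mask enabled, positions beyond the per-sequence actual length
// are excluded from the attention computation.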
std::vector hostActualSeqlenKV(b, 20); + + checkCudaErr(cudaMemcpy(devActualSeqlenQ.devPtr, + hostActualSeqlenQ.data(), + sizeof(hostActualSeqlenQ[0]) * b, + cudaMemcpyHostToDevice)); + checkCudaErr(cudaMemcpy(devActualSeqlenKV.devPtr, + hostActualSeqlenKV.data(), + sizeof(hostActualSeqlenKV[0]) * b, + cudaMemcpyHostToDevice)); + checkCudaErr(cudaDeviceSynchronize()); + + variant_pack[SEQ_LEN_Q_UID] = devActualSeqlenQ.devPtr; + variant_pack[SEQ_LEN_KV_UID] = devActualSeqlenKV.devPtr; + } + + // Allocate workspace + Surface workspace(graph->get_workspace_size(), false); + REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); + + checkCudaErr(cudaDeviceSynchronize()); + + cudnnDestroy(handle); +} diff --git a/samples/cpp/sdpa/fp16_cached.cpp b/samples/cpp/sdpa/fp16_cached.cpp new file mode 100644 index 0000000..570dcc6 --- /dev/null +++ b/samples/cpp/sdpa/fp16_cached.cpp @@ -0,0 +1,175 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include +#include "../../utils/helpers.h" + +#include + +#include +namespace fe = cudnn_frontend; + +/* +Run this example by using command: +bin/samples "Cached sdpa" + +This example is supposed to be used when executing full models and/or doing multiple iterations. 
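Each graph is built once, stored in a user-maintained cache keyed by Graph::key(), and later
lookups of an identical graph reuse the already-built object instead of paying the build cost again.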
+*/ + +// Directly use the forward graph builder from the toy example +std::shared_ptr +create_sdpa_forward_graph(int64_t const b, + int64_t const h_q, + int64_t const h_k, + int64_t const h_v, + int64_t const s_q, + int64_t const s_kv, + int64_t const d_qk, + int64_t const d_v, + float const attn_scale = 1.0f, + bool const is_inference = false, + bool const causal_mask = false, + bool const alibi_mask = false, + bool const padding_mask = false, + bool has_attn_bias = false); + +// Directly use the backward graph builder from the toy example +std::shared_ptr +create_sdpa_backward_graph(int64_t const b, + int64_t const h_q, + int64_t const h_k, + int64_t const h_v, + int64_t const s_q, + int64_t const s_kv, + int64_t const d_qk, + int64_t const d_v, + float const attn_scale = 1.0f, + bool const is_inference = false, + bool const causal_mask = false, + bool const alibi_mask = false, + bool const padding_mask = false, + bool has_attn_bias = false); + +#define Q_UID 1 +#define K_UID 2 +#define V_UID 3 +#define O_UID 4 +#define STATS_UID 5 +#define BIAS_UID 6 +#define SEQ_LEN_Q_UID 7 +#define SEQ_LEN_KV_UID 8 + +#define DO_UID 101 +#define DQ_UID 102 +#define DK_UID 103 +#define DV_UID 104 + +using cache_t = std::unordered_map>; +cache_t user_maintained_cache; + +bool +cache_lookup_pre_built_graph(std::shared_ptr& graph, cudnnHandle_t handle) { + auto cache_key = graph->key(); + if (auto it = user_maintained_cache.find(cache_key); it != user_maintained_cache.end()) { + graph = it->second; + return true; + } + + REQUIRE(graph->build(handle, {fe::HeurMode_t::A}).is_good()); + user_maintained_cache.emplace(cache_key, graph); + return false; +} + +TEST_CASE("Cached sdpa", "[graph][sdpa][flash]") { + int64_t b = 3; // batch size + int64_t h_q = 4; // head dim + int64_t h_k = 4; // head dim + int64_t h_v = 4; // head dim + int64_t s_q = 1024; // q tensor is padded to this seq length + int64_t s_kv = 1024; // k and v tensor is padded to this seq length + int64_t d_qk = 128; // hidden dim + int64_t d_v = 128; // hidden dim + + if (cudnnGetVersion() < 8903) { + SKIP("Test requires cudnn 8.9.3 or above"); + return; + } + + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + auto fwd_graph = create_sdpa_forward_graph(b, h_q, h_k, h_v, s_q, s_kv, d_qk, d_v); + auto bwd_graph = create_sdpa_backward_graph(b, h_q, h_k, h_v, s_q, s_kv, d_qk, d_v); + + // Wont get a cache hit the first time + REQUIRE(cache_lookup_pre_built_graph(fwd_graph, handle) == false); + REQUIRE(cache_lookup_pre_built_graph(bwd_graph, handle) == false); + + auto fwd_graph2 = create_sdpa_forward_graph(b, h_q, h_k, h_v, s_q, s_kv, d_qk, d_v); + auto bwd_graph2 = create_sdpa_backward_graph(b, h_q, h_k, h_v, s_q, s_kv, d_qk, d_v); + + REQUIRE(cache_lookup_pre_built_graph(fwd_graph2, handle) == true); + REQUIRE(cache_lookup_pre_built_graph(bwd_graph2, handle) == true); + + //// Build variant pack + std::unordered_map variant_pack; + // inputs + Surface q_tensor(b * h_q * s_q * d_qk, false); + Surface k_tensor(b * h_k * d_qk * s_kv, false); + Surface v_tensor(b * h_v * d_v * s_kv, false); + + Surface o_tensor(b * h_q * s_q * d_qk, false); + Surface stats_tensor(b * h_q * s_q * 1, false); + + variant_pack = {{Q_UID, q_tensor.devPtr}, + {K_UID, k_tensor.devPtr}, + {V_UID, v_tensor.devPtr}, + {O_UID, o_tensor.devPtr}, + {STATS_UID, stats_tensor.devPtr}}; + + Surface fwd_workspace(fwd_graph2->get_workspace_size(), false); + REQUIRE(fwd_graph2->execute(handle, variant_pack, fwd_workspace.devPtr).is_good()); + 
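// In a real workload the cached graph is executed many times with fresh data; only the device
// pointers in the variant pack change between iterations. A minimal sketch (the iteration count of
// 10 is arbitrary and not part of the original sample):
for (int iter = 0; iter < 10; ++iter) {
    REQUIRE(fwd_graph2->execute(handle, variant_pack, fwd_workspace.devPtr).is_good());
}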
checkCudaErr(cudaDeviceSynchronize()); + + Surface dO_tensor(b * h_q * s_q * d_qk, false); + Surface dQ_tensor(b * h_q * s_q * d_qk, false); + Surface dK_tensor(b * h_k * s_kv * d_qk, false); + Surface dV_tensor(b * h_v * s_kv * d_v, false); + + variant_pack = {// inputs + {Q_UID, q_tensor.devPtr}, + {K_UID, k_tensor.devPtr}, + {V_UID, v_tensor.devPtr}, + {O_UID, o_tensor.devPtr}, + {DO_UID, dO_tensor.devPtr}, + {STATS_UID, stats_tensor.devPtr}, + // outputs + {DQ_UID, dQ_tensor.devPtr}, + {DK_UID, dK_tensor.devPtr}, + {DV_UID, dV_tensor.devPtr}}; + Surface bwd_workspace(bwd_graph2->get_workspace_size(), false); + REQUIRE(bwd_graph2->execute(handle, variant_pack, bwd_workspace.devPtr).is_good()); + + checkCudaErr(cudaDeviceSynchronize()); + + cudnnDestroy(handle); +} diff --git a/samples/cpp/sdpa/fp16_fwd.cpp b/samples/cpp/sdpa/fp16_fwd.cpp new file mode 100644 index 0000000..d8b6f24 --- /dev/null +++ b/samples/cpp/sdpa/fp16_fwd.cpp @@ -0,0 +1,219 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include +#include "../../utils/helpers.h" + +#include + +#include +namespace fe = cudnn_frontend; + +/* +Run this example by using command: +bin/samples "Toy sdpa forward" + +This example shows how to construct a sdpa forward graph. +*/ + +// Tensors in forward pass +#define Q_UID 1 +#define K_UID 2 +#define V_UID 3 +#define O_UID 4 +#define STATS_UID 5 +#define BIAS_UID 6 +#define SEQ_LEN_Q_UID 7 +#define SEQ_LEN_KV_UID 8 + +std::shared_ptr +create_sdpa_forward_graph(int64_t const b, + int64_t const h_q, + int64_t const h_k, + int64_t const h_v, + int64_t const s_q, + int64_t const s_kv, + int64_t const d_qk, + int64_t const d_v, + float const attn_scale = 1.0f, + bool const is_inference = false, + bool const causal_mask = false, + bool const alibi_mask = false, + bool const padding_mask = false, + bool has_attn_bias = false) { + // Create a graph and set common global properties. 
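// These graph-level settings act as defaults: the io data type applies to input/output tensors that
// do not set their own type (Q, K, V and O inherit BFLOAT16 here), the intermediate type applies to
// virtual tensors created inside the graph, and the compute type is used for the math itself.
// Individual tensors can still override the default, as the INT32 sequence-length tensors below do.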
+ auto graph = std::make_shared(); + graph->set_io_data_type(fe::DataType_t::BFLOAT16) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + auto Q = graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_uid(Q_UID) + .set_dim({b, h_q, s_q, d_qk}) + .set_stride({h_q * s_q * d_qk, s_q * d_qk, d_qk, 1})); + + auto K = graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_uid(K_UID) + .set_dim({b, h_k, s_kv, d_qk}) + .set_stride({h_k * s_kv * d_qk, s_kv * d_qk, d_qk, 1})); + + auto V = graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_uid(V_UID) + .set_dim({b, h_v, s_kv, d_v}) + .set_stride({h_v * s_kv * d_v, s_kv * d_v, d_v, 1})); + + auto sdpa_options = fe::graph::SDPA_attributes() + .set_name("flash_attention") + .set_is_inference(is_inference) + .set_alibi_mask(alibi_mask) + .set_causal_mask(causal_mask) + .set_attn_scale(attn_scale); + + if (has_attn_bias) { + auto bias = graph->tensor(fe::graph::Tensor_attributes() + .set_name("bias") + .set_uid(BIAS_UID) + .set_dim({b, 1, s_q, s_kv}) + .set_stride({s_q * s_kv, s_q * s_kv, s_kv, 1})); + sdpa_options.set_bias(bias); + } + + if (padding_mask) { + auto seq_q = graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_q") + .set_uid(SEQ_LEN_Q_UID) + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto seq_kv = graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_kv") + .set_uid(SEQ_LEN_KV_UID) + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + sdpa_options.set_padding_mask(padding_mask).set_seq_len_q(seq_q).set_seq_len_kv(seq_kv); + } + + auto [O, Stats] = graph->sdpa(Q, K, V, sdpa_options); + + O->set_output(true).set_dim({b, h_q, s_q, d_v}).set_stride({h_q * d_v, d_v, b * h_q * d_v, 1}).set_uid(O_UID); + + if (is_inference) { + assert(Stats == nullptr); + } else { + Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_uid(STATS_UID); + } + + return graph; +} + +TEST_CASE("Toy sdpa forward", "[graph][sdpa][flash][forward]") { + int64_t b = 3; // batch size + int64_t h_q = 4; // head dim + int64_t h_k = 4; // head dim + int64_t h_v = 4; // head dim + int64_t s_q = 1024; // q tensor is padded to this seq length + int64_t s_kv = 1024; // k and v tensor is padded to this seq length + int64_t d_qk = 128; // hidden dim + int64_t d_v = 128; // hidden dim + bool is_inference = false; + float attn_scale = 0.123f; + bool causal_mask = true; + bool padding_mask = (cudnnGetVersion() >= 8903); + bool alibi_mask = (cudnnGetVersion() >= 8904); + bool has_attn_bias = (cudnnGetVersion() >= 8903); + + if (cudnnGetVersion() < 8903) { + SKIP("Test requires cudnn 8.9.3 or above"); + return; + } + + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + auto graph = create_sdpa_forward_graph(b, + h_q, + h_k, + h_v, + s_q, + s_kv, + d_qk, + d_v, + attn_scale, + is_inference, + causal_mask, + alibi_mask, + padding_mask, + has_attn_bias); + + REQUIRE(graph->build(handle, {fe::HeurMode_t::A}).is_good()); + + //// Build variant pack + Surface q_tensor(b * h_q * s_q * d_qk, false); + Surface k_tensor(b * h_k * d_qk * s_kv, false); + Surface v_tensor(b * h_v * d_v * s_kv, false); + + Surface o_tensor(b * s_q * h_q * d_qk, false); + + std::unordered_map variant_pack = { + {Q_UID, q_tensor.devPtr}, {K_UID, k_tensor.devPtr}, {V_UID, v_tensor.devPtr}, {O_UID, o_tensor.devPtr}}; + + Surface bias_tensor(b * 1 * s_q * s_kv, false); + if 
(has_attn_bias) { + variant_pack[BIAS_UID] = bias_tensor.devPtr; + } + + Surface devActualSeqlenQ(b, false); + Surface devActualSeqlenKV(b, false); + if (padding_mask) { + std::vector hostActualSeqlenQ(b, 20); + std::vector hostActualSeqlenKV(b, 20); + + checkCudaErr(cudaMemcpy(devActualSeqlenQ.devPtr, + hostActualSeqlenQ.data(), + sizeof(hostActualSeqlenQ[0]) * b, + cudaMemcpyHostToDevice)); + checkCudaErr(cudaMemcpy(devActualSeqlenKV.devPtr, + hostActualSeqlenKV.data(), + sizeof(hostActualSeqlenKV[0]) * b, + cudaMemcpyHostToDevice)); + checkCudaErr(cudaDeviceSynchronize()); + + variant_pack[SEQ_LEN_Q_UID] = devActualSeqlenQ.devPtr; + variant_pack[SEQ_LEN_KV_UID] = devActualSeqlenKV.devPtr; + } + + Surface statsTensor(b * h_q * s_q * 1, false); + if (is_inference == false) { + variant_pack[STATS_UID] = statsTensor.devPtr; + } + + Surface workspace(graph->get_workspace_size(), false); + REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); + + checkCudaErr(cudaDeviceSynchronize()); + + cudnnDestroy(handle); +} diff --git a/samples/cpp/sdpa/fp8_bwd.cpp b/samples/cpp/sdpa/fp8_bwd.cpp new file mode 100644 index 0000000..6181517 --- /dev/null +++ b/samples/cpp/sdpa/fp8_bwd.cpp @@ -0,0 +1,391 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. 
+ */ + +#include +#include "../../utils/helpers.h" + +#include +#include + +namespace fe = cudnn_frontend; + +TEST_CASE("sdpa_fp8_bprop", "[graph][sdpa][fp8][backward]") { + namespace fe = cudnn_frontend; + +#if CUDART_VERSION < 12000 + SKIP("Test requires cuda toolkit 12.0 or above"); + return; +#endif + + int64_t b = 2; // batch size + int64_t h = 2; // head dim + int64_t s = 512; // q,k,v tensor is padded to this seq length + int64_t d = 128; // hidden dim + + // bs3hd + auto Q_dQ_O_dO_dims = std::vector({b, h, s, d}); + // QKV_strides + auto Q_dQ_strides = std::vector({s * 3 * h * d, d, 3 * h * d, 1}); // bs3hd + + auto Q_K_V_dQ_dK_dV_bulk_strides = std::vector({s * 3 * h * d, 3 * h * d, h * d, d, 1}); + + auto O_dO_strides = std::vector({s * h * d, d, h * d, 1}); // bshd + + auto K_V_dK_dV_dims{Q_dQ_O_dO_dims}; + auto K_V_dK_dV_strides{Q_dQ_strides}; + + auto MZ_OdO_dims = std::vector({b, h, s, 1}); + auto MZ_OdO_strides = std::vector({h * s, s, 1, 1}); + + fe::graph::Graph mha_graph; + mha_graph.set_io_data_type(fe::DataType_t::FP8_E4M3) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + auto Q = mha_graph.tensor( + fe::graph::Tensor_attributes().set_name("Q").set_dim(K_V_dK_dV_dims).set_stride(K_V_dK_dV_strides)); + auto K = mha_graph.tensor( + fe::graph::Tensor_attributes().set_name("K").set_dim(K_V_dK_dV_dims).set_stride(K_V_dK_dV_strides)); + auto V = mha_graph.tensor( + fe::graph::Tensor_attributes().set_name("V").set_dim(K_V_dK_dV_dims).set_stride(K_V_dK_dV_strides)); + auto O = + mha_graph.tensor(fe::graph::Tensor_attributes().set_name("O").set_dim(Q_dQ_O_dO_dims).set_stride(O_dO_strides)); + auto dO = mha_graph.tensor( + fe::graph::Tensor_attributes().set_name("dO").set_dim(Q_dQ_O_dO_dims).set_stride(O_dO_strides)); + auto Stats = mha_graph.tensor(fe::graph::Tensor_attributes() + .set_name("Stats") + .set_dim(MZ_OdO_dims) + .set_stride(MZ_OdO_strides) + .set_data_type(fe::DataType_t::FLOAT)); + + float attn_scale = 0.123f; + + auto descale_q = mha_graph.tensor(fe::graph::Tensor_attributes() + .set_name("Descale_Q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + auto descale_k = mha_graph.tensor_like(descale_q, "Descale_K"); + auto descale_v = mha_graph.tensor_like(descale_q, "Descale_V"); + auto descale_s = mha_graph.tensor_like(descale_q, "Descale_S"); + auto descale_o = mha_graph.tensor_like(descale_q, "Descale_O"); + auto descale_dO = mha_graph.tensor_like(descale_q, "Descale_dO"); + auto descale_dP = mha_graph.tensor_like(descale_q, "Descale_dP"); + + auto scale_s = mha_graph.tensor_like(descale_q, "Scale_S"); + auto scale_dP = mha_graph.tensor_like(descale_q, "Scale_dP"); + auto scale_dQ = mha_graph.tensor_like(descale_q, "Scale_dQ"); + auto scale_dK = mha_graph.tensor_like(descale_q, "Scale_dK"); + auto scale_dV = mha_graph.tensor_like(descale_q, "Scale_dV"); + + // options/attributes + auto sdpa_fp8_backwards_options = fe::graph::SDPA_fp8_backward_attributes() + .set_name("sdpa_fp8_backward") + .set_causal_mask(true) + .set_attn_scale(attn_scale); + + // output + auto [dQ, dK, dV, Amax_dQ, Amax_dK, Amax_dV, Amax_dP] = mha_graph.sdpa_fp8_backward(Q, + K, + V, + O, + dO, + Stats, + descale_q, + descale_k, + descale_v, + descale_o, + descale_dO, + descale_s, + descale_dP, + scale_s, + scale_dQ, + scale_dK, + scale_dV, + scale_dP, + sdpa_fp8_backwards_options); + + dQ->set_output(true).set_dim(Q_dQ_O_dO_dims).set_stride(O_dO_strides); + 
dK->set_output(true).set_dim(Q_dQ_O_dO_dims).set_stride(O_dO_strides); + dV->set_output(true).set_dim(Q_dQ_O_dO_dims).set_stride(O_dO_strides); + Amax_dQ->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + Amax_dK->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + Amax_dV->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + Amax_dP->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + auto status = mha_graph.validate(); + if ((cudnnGetVersion() >= 90100) && check_device_arch_newer_than("hopper")) { + REQUIRE(status.is_good()); + } else { + REQUIRE(status.get_code() == fe::error_code_t::GRAPH_NOT_SUPPORTED); + cudnnDestroy(handle); + return; + } + + REQUIRE(mha_graph.build_operation_graph(handle).is_good()); + REQUIRE(mha_graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + REQUIRE(mha_graph.check_support(handle).is_good()); + REQUIRE(mha_graph.build_plans(handle).is_good()); + + // Surfaces + auto Q_K_V_dQ_dK_dV_bulk_dims{b * s * 3 * h * d}; + auto dO_O_dims{b * s * h * d}; + Surface qkvTensor{Q_K_V_dQ_dK_dV_bulk_dims, false}; + void* devPtrQ{qkvTensor.devPtr}; + void* devPtrK{qkvTensor.devPtr + h * d}; + void* devPtrV{qkvTensor.devPtr + 2 * h * d}; + + Surface dQdKdVTensor{Q_K_V_dQ_dK_dV_bulk_dims, false}; + void* devPtrdQ{dQdKdVTensor.devPtr}; + void* devPtrdK{dQdKdVTensor.devPtr + h * d}; + void* devPtrdV{dQdKdVTensor.devPtr + 2 * h * d}; + + Surface dOTensor{dO_O_dims, false}; + Surface OTensor{dO_O_dims, false}; + + Surface descale_Q_Tensor{1, false}; + Surface descale_K_Tensor{1, false}; + Surface descale_V_Tensor{1, false}; + Surface descale_S_Tensor{1, false}; + Surface descale_dP_Tensor{1, false}; + Surface descale_dO_Tensor{1, false}; + Surface descale_O_Tensor{1, false}; + + Surface scale_S_Tensor{1, false}; + Surface scale_dQ_Tensor{1, false}; + Surface scale_dK_Tensor{1, false}; + Surface scale_dV_Tensor{1, false}; + Surface scale_dP_Tensor{1, false}; + + Surface AMax_dQ_Tensor{1, false}; + Surface AMax_dK_Tensor{1, false}; + Surface AMax_dV_Tensor{1, false}; + Surface AMax_dP_Tensor{1, false}; + + Surface StatsTensor(b * h * s * 1, false); + + // Variant pack + std::unordered_map, void*> variant_pack{ + {Q, devPtrQ}, + {K, devPtrK}, + {V, devPtrV}, + {O, OTensor.devPtr}, + {dO, dOTensor.devPtr}, + {dQ, devPtrdQ}, + {dK, devPtrdK}, + {dV, devPtrdV}, + {descale_q, descale_Q_Tensor.devPtr}, + {descale_k, descale_K_Tensor.devPtr}, + {descale_v, descale_V_Tensor.devPtr}, + {descale_o, descale_O_Tensor.devPtr}, + {descale_dO, descale_dO_Tensor.devPtr}, + {descale_s, descale_S_Tensor.devPtr}, + {descale_dP, descale_dP_Tensor.devPtr}, + {scale_s, scale_S_Tensor.devPtr}, + {scale_dQ, scale_dQ_Tensor.devPtr}, + {scale_dK, scale_dK_Tensor.devPtr}, + {scale_dV, scale_dV_Tensor.devPtr}, + {scale_dP, scale_dP_Tensor.devPtr}, + {Stats, StatsTensor.devPtr}, + {Amax_dQ, AMax_dQ_Tensor.devPtr}, + {Amax_dK, AMax_dK_Tensor.devPtr}, + {Amax_dV, AMax_dV_Tensor.devPtr}, + {Amax_dP, AMax_dP_Tensor.devPtr}}; + + Surface workspace(mha_graph.get_workspace_size(), false); + REQUIRE(mha_graph.execute(handle, variant_pack, workspace.devPtr).is_good()); + + checkCudaErr(cudaDeviceSynchronize()); + + cudnnDestroy(handle); +} + +TEST_CASE("sdpa_fp8_gqa_bprop", "[graph][sdpa][fp8][backward]") { + namespace fe = cudnn_frontend; + +#if CUDART_VERSION < 12000 + SKIP("Test requires cuda toolkit 12.0 or above"); + return; +#endif 
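+    // Grouped-query attention (GQA): h_qo = 12 query heads share h_kv = 4 key/value heads,
+    // i.e. each K/V head serves 3 query heads, so dQ uses qo_dim/qo_stride while dK and dV
+    // keep the smaller kv_dim/kv_stride shapes.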
+ + int64_t b = 2; // batch size + int64_t h_qo = 12; // query/output head dim + int64_t h_kv = 4; // key/value head dim + int64_t s = 512; // q,k,v tensor is padded to this seq length + int64_t d = 128; // hidden dim + + // construct graph + std::vector qo_dim = {b, h_qo, s, d}; + std::vector kv_dim = {b, h_kv, s, d}; + std::vector qo_stride = {s * h_qo * d, d, h_qo * d, 1}; // bshd + std::vector kv_stride = {s * h_kv * d, d, h_kv * d, 1}; // bshd + + std::vector stats_dim = {b, h_qo, s, 1}; + std::vector stats_stride = {h_qo * s, s, 1, 1}; + + fe::graph::Graph mha_graph; + mha_graph.set_io_data_type(fe::DataType_t::FP8_E4M3) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + auto q = mha_graph.tensor(fe::graph::Tensor_attributes().set_name("Q").set_dim(qo_dim).set_stride(qo_stride)); + auto k = mha_graph.tensor(fe::graph::Tensor_attributes().set_name("K").set_dim(kv_dim).set_stride(kv_stride)); + auto v = mha_graph.tensor(fe::graph::Tensor_attributes().set_name("V").set_dim(kv_dim).set_stride(kv_stride)); + auto o = mha_graph.tensor(fe::graph::Tensor_attributes().set_name("O").set_dim(qo_dim).set_stride(qo_stride)); + auto dO = mha_graph.tensor(fe::graph::Tensor_attributes().set_name("dO").set_dim(qo_dim).set_stride(qo_stride)); + auto stats = mha_graph.tensor(fe::graph::Tensor_attributes() + .set_name("Stats") + .set_dim(stats_dim) + .set_stride(stats_stride) + .set_data_type(fe::DataType_t::FLOAT)); + + float attn_scale = 0.125f; + + auto descale_q = mha_graph.tensor(fe::graph::Tensor_attributes() + .set_name("Descale_Q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + auto descale_k = mha_graph.tensor_like(descale_q, "Descale_K"); + auto descale_v = mha_graph.tensor_like(descale_q, "Descale_V"); + auto descale_s = mha_graph.tensor_like(descale_q, "Descale_S"); + auto descale_o = mha_graph.tensor_like(descale_q, "Descale_O"); + auto descale_dO = mha_graph.tensor_like(descale_q, "Descale_dO"); + auto descale_dP = mha_graph.tensor_like(descale_q, "Descale_dP"); + + auto scale_s = mha_graph.tensor_like(descale_q, "Scale_S"); + auto scale_dP = mha_graph.tensor_like(descale_q, "Scale_dP"); + auto scale_dQ = mha_graph.tensor_like(descale_q, "Scale_dQ"); + auto scale_dK = mha_graph.tensor_like(descale_q, "Scale_dK"); + auto scale_dV = mha_graph.tensor_like(descale_q, "Scale_dV"); + + // clang-format off + auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph.sdpa_fp8_backward( + q, k, v, o, dO, stats, + descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, + scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, + fe::graph::SDPA_fp8_backward_attributes().set_name("sdpa_fp8_backward") + .set_causal_mask(true) + .set_attn_scale(attn_scale) + ); + // clang-format on + + dQ->set_output(true).set_dim(qo_dim).set_stride(qo_stride); + dK->set_output(true).set_dim(kv_dim).set_stride(kv_stride); + dV->set_output(true).set_dim(kv_dim).set_stride(kv_stride); + amax_dQ->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + amax_dK->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + amax_dV->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + amax_dP->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + auto status = mha_graph.validate(); + if ((cudnnGetVersion() >= 90100) && 
check_device_arch_newer_than("hopper")) { + REQUIRE(status.is_good()); + } else { + REQUIRE(status.get_code() == fe::error_code_t::GRAPH_NOT_SUPPORTED); + cudnnDestroy(handle); + return; + } + + REQUIRE(mha_graph.build_operation_graph(handle).is_good()); + REQUIRE(mha_graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + REQUIRE(mha_graph.check_support(handle).is_good()); + REQUIRE(mha_graph.build_plans(handle).is_good()); + + // Surfaces that alllocate GPU memory + Surface q_gpu(b * s * h_qo * d, false); + Surface k_gpu(b * s * h_kv * d, false); + Surface v_gpu(b * s * h_kv * d, false); + Surface o_gpu(b * s * h_qo * d, false); + + Surface stats_gpu(b * h_qo * s * 1, false); + + Surface dQ_gpu(b * s * h_qo * d, false); + Surface dK_gpu(b * s * h_kv * d, false); + Surface dV_gpu(b * s * h_kv * d, false); + Surface dO_gpu(b * s * h_qo * d, false); + + Surface descale_q_gpu(1, false); + Surface descale_k_gpu(1, false); + Surface descale_v_gpu(1, false); + Surface descale_o_gpu(1, false); + Surface descale_s_gpu(1, false); + Surface descale_dP_gpu(1, false); + Surface descale_dO_gpu(1, false); + + Surface scale_s_gpu(1, false); + Surface scale_dQ_gpu(1, false); + Surface scale_dK_gpu(1, false); + Surface scale_dV_gpu(1, false); + Surface scale_dP_gpu(1, false); + + Surface amax_dQ_gpu(1, false); + Surface amax_dK_gpu(1, false); + Surface amax_dV_gpu(1, false); + Surface amax_dP_gpu(1, false); + + // Variant pack + std::unordered_map, void*> variant_pack{ + {q, q_gpu.devPtr}, + {k, k_gpu.devPtr}, + {v, v_gpu.devPtr}, + {o, o_gpu.devPtr}, + + {dQ, dQ_gpu.devPtr}, + {dK, dK_gpu.devPtr}, + {dV, dV_gpu.devPtr}, + {dO, dO_gpu.devPtr}, + + {stats, stats_gpu.devPtr}, + + {descale_q, descale_q_gpu.devPtr}, + {descale_k, descale_k_gpu.devPtr}, + {descale_v, descale_v_gpu.devPtr}, + {descale_o, descale_o_gpu.devPtr}, + {descale_s, descale_s_gpu.devPtr}, + {descale_dP, descale_dP_gpu.devPtr}, + {descale_dO, descale_dO_gpu.devPtr}, + + {scale_s, scale_s_gpu.devPtr}, + {scale_dQ, scale_dQ_gpu.devPtr}, + {scale_dK, scale_dK_gpu.devPtr}, + {scale_dV, scale_dV_gpu.devPtr}, + {scale_dP, scale_dP_gpu.devPtr}, + + {amax_dQ, amax_dQ_gpu.devPtr}, + {amax_dK, amax_dK_gpu.devPtr}, + {amax_dV, amax_dV_gpu.devPtr}, + {amax_dP, amax_dP_gpu.devPtr}}; + + Surface workspace(mha_graph.get_workspace_size(), false); + REQUIRE(mha_graph.execute(handle, variant_pack, workspace.devPtr).is_good()); + + checkCudaErr(cudaDeviceSynchronize()); + + cudnnDestroy(handle); +} \ No newline at end of file diff --git a/samples/cpp/sdpa/fp8_fwd.cpp b/samples/cpp/sdpa/fp8_fwd.cpp new file mode 100644 index 0000000..c2abb32 --- /dev/null +++ b/samples/cpp/sdpa/fp8_fwd.cpp @@ -0,0 +1,155 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include +#include "../../utils/helpers.h" + +#include +#include + +namespace fe = cudnn_frontend; + +TEST_CASE("sdpa_fp8_fprop", "[graph][sdpa][fp8][forward]") { + namespace fe = cudnn_frontend; + +#if CUDART_VERSION < 12000 + SKIP("Test requires cuda toolkit 12.0 or above"); + return; +#endif + + int64_t b = 2; // batch size + int64_t h = 2; // head dim + int64_t s = 512; // q,k,v tensor is padded to this seq length + int64_t d = 128; // hidden dim + + bool is_inference = false; + + fe::graph::Graph mha_graph; + mha_graph.set_io_data_type(fe::DataType_t::FP8_E4M3) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + auto Q_dQ_O_dO_dims = std::vector({b, h, s, d}); + + auto QKV_strides = std::vector({s * 3 * h * d, d, 3 * h * d, 1}); // bs3hd + auto O_dO_strides = std::vector({s * h * d, d, h * d, 1}); // bhsd + + auto Q = + mha_graph.tensor(fe::graph::Tensor_attributes().set_name("Q").set_dim(Q_dQ_O_dO_dims).set_stride(QKV_strides)); + auto K = + mha_graph.tensor(fe::graph::Tensor_attributes().set_name("K").set_dim(Q_dQ_O_dO_dims).set_stride(QKV_strides)); + auto V = + mha_graph.tensor(fe::graph::Tensor_attributes().set_name("V").set_dim(Q_dQ_O_dO_dims).set_stride(QKV_strides)); + + float attn_scale = 0.123f; + + auto descale_q = mha_graph.tensor(fe::graph::Tensor_attributes() + .set_name("Descale_Q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + auto descale_k = mha_graph.tensor_like(descale_q, "Descale_K"); + auto descale_v = mha_graph.tensor_like(descale_q, "Descale_V"); + auto descale_s = mha_graph.tensor_like(descale_q, "Descale_S"); + auto scale_s = mha_graph.tensor_like(descale_q, "Scale_S"); + auto scale_o = mha_graph.tensor_like(descale_q, "Scale_O"); + + auto sdpa_fp8_options = fe::graph::SDPA_fp8_attributes() + .set_name("sdpa_fp8") + .set_is_inference(is_inference) + .set_causal_mask(true) + .set_attn_scale(attn_scale); + + auto [O, Stats, Amax_S, Amax_O] = + mha_graph.sdpa_fp8(Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, sdpa_fp8_options); + + O->set_output(true).set_dim(Q_dQ_O_dO_dims).set_stride(O_dO_strides); + Amax_O->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + Amax_S->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + + // Check that Stats tensor is real, which is only when its training step + if (is_inference) { + REQUIRE(Stats == nullptr); + } else { + Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT); + } + + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + auto status = mha_graph.validate(); + if ((cudnnGetVersion() >= 90100) && check_device_arch_newer_than("hopper")) { + REQUIRE(status.is_good()); + } else { + REQUIRE(status.get_code() == fe::error_code_t::GRAPH_NOT_SUPPORTED); + cudnnDestroy(handle); + return; + } + + REQUIRE(mha_graph.build_operation_graph(handle).is_good()); + auto plans = mha_graph.create_execution_plans({fe::HeurMode_t::A}); + 
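+    // Long-form build pipeline: validate() -> build_operation_graph() -> create_execution_plans()
+    // -> check_support() -> build_plans(); the toy forward sample above collapses the same
+    // sequence into a single graph->build(handle, {fe::HeurMode_t::A}) call.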
REQUIRE(mha_graph.check_support(handle).is_good()); + REQUIRE(mha_graph.build_plans(handle).is_good()); + + //// Build variant pack + Surface qkvTensor(b * s * 3 * h * d, false); + Surface oTensor(b * s * h * d, false); + void* devPtrQ = qkvTensor.devPtr; + void* devPtrK = (qkvTensor.devPtr + h * d); + void* devPtrV = (qkvTensor.devPtr + 2 * h * d); + void* devPtrO = oTensor.devPtr; + + Surface descale_Q_Tensor(1, false); + Surface descale_K_Tensor(1, false); + Surface descale_V_Tensor(1, false); + Surface descale_S_Tensor(1, false); + Surface scale_S_Tensor(1, false); + Surface scale_O_Tensor(1, false); + Surface Amax_S_Tensor(1, false); + Surface Amax_O_Tensor(1, false); + + std::unordered_map, void*> variant_pack = { + {Q, devPtrQ}, + {K, devPtrK}, + {V, devPtrV}, + {O, devPtrO}, + {descale_q, descale_Q_Tensor.devPtr}, + {descale_k, descale_K_Tensor.devPtr}, + {descale_v, descale_V_Tensor.devPtr}, + {descale_s, descale_S_Tensor.devPtr}, + {scale_s, scale_S_Tensor.devPtr}, + {scale_o, scale_O_Tensor.devPtr}, + {Amax_S, Amax_S_Tensor.devPtr}, + {Amax_O, Amax_O_Tensor.devPtr}}; + + Surface stats_tensor(b * h * s * 1, false); + if (is_inference == false) { + variant_pack[Stats] = stats_tensor.devPtr; + } + + Surface workspace(mha_graph.get_workspace_size(), false); + REQUIRE(mha_graph.execute(handle, variant_pack, workspace.devPtr).is_good()); + + checkCudaErr(cudaDeviceSynchronize()); + + cudnnDestroy(handle); +} \ No newline at end of file diff --git a/samples/legacy_samples/fusion_sample.cpp b/samples/legacy_samples/fusion_sample.cpp index d0f2f86..1e5ddcb 100644 --- a/samples/legacy_samples/fusion_sample.cpp +++ b/samples/legacy_samples/fusion_sample.cpp @@ -3953,7 +3953,9 @@ run_bn_bwd_weight(int64_t* xDim, // Create cudnn handle checkCudnnErr(cudnnCreate(&handle_)); - if (check_device_arch_newer_than("ampere") == false) { + // this example is only for Ampere and Hopper cards + bool is_supported = (is_ampere_arch() || is_hopper_arch()); + if (is_supported == false) { cudnn_frontend::set_error_and_throw_exception( nullptr, CUDNN_STATUS_ARCH_MISMATCH, diff --git a/samples/python/00_introduction.ipynb b/samples/python/00_introduction.ipynb index a1a1c4e..e7c8361 100644 --- a/samples/python/00_introduction.ipynb +++ b/samples/python/00_introduction.ipynb @@ -46,7 +46,7 @@ "outputs": [], "source": [ "# get_ipython().system('pip install nvidia-cudnn-cu12')\n", - "# get_ipython().system('CUDNN_PATH=`pip show nvidia-cudnn-cu12 | grep Location | cut -d\":\" -f2 | xargs`/nvidia/cudnn pip install git+https://github.com/NVIDIA/cudnn-frontend.git')\n", + "# get_ipython().system('pip install nvidia-cudnn-frontend')\n", "# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')" ] }, @@ -94,7 +94,12 @@ "metadata": {}, "outputs": [], "source": [ - "graph = cudnn.pygraph(handle = handle, name = \"cudnn_graph_0\", io_data_type = cudnn.data_type.HALF, compute_data_type = cudnn.data_type.FLOAT)" + "graph = cudnn.pygraph(\n", + " handle=handle,\n", + " name=\"cudnn_graph_0\",\n", + " io_data_type=cudnn.data_type.HALF,\n", + " compute_data_type=cudnn.data_type.FLOAT,\n", + ")" ] }, { @@ -116,7 +121,12 @@ "metadata": {}, "outputs": [], "source": [ - "X = graph.tensor(name = \"X\", dim = [8, 64, 56, 56], stride = [56 * 56 * 64, 1, 56 * 64 ,64], data_type=cudnn.data_type.HALF)" + "X = graph.tensor(\n", + " name=\"X\",\n", + " dim=[8, 64, 56, 56],\n", + " stride=[56 * 56 * 64, 1, 56 * 64, 64],\n", + " data_type=cudnn.data_type.HALF,\n", + ")" ] }, { @@ -125,7 
+135,7 @@ "metadata": {}, "outputs": [], "source": [ - "W = graph.tensor(name = \"W\", dim = [32, 64, 3, 3], stride = [3 * 3 * 64, 1, 3 * 64 ,64])" + "W = graph.tensor(name=\"W\", dim=[32, 64, 3, 3], stride=[3 * 3 * 64, 1, 3 * 64, 64])" ] }, { @@ -144,7 +154,14 @@ "metadata": {}, "outputs": [], "source": [ - "Y = graph.conv_fprop(X, W, padding = [1,1], stride = [1,1], dilation = [1,1], compute_data_type = cudnn.data_type.FLOAT)" + "Y = graph.conv_fprop(\n", + " X,\n", + " W,\n", + " padding=[1, 1],\n", + " stride=[1, 1],\n", + " dilation=[1, 1],\n", + " compute_data_type=cudnn.data_type.FLOAT,\n", + ")" ] }, { @@ -206,9 +223,15 @@ "import torch\n", "\n", "\n", - "X_gpu = torch.randn(8, 64, 56, 56, requires_grad=False, device=\"cuda\", dtype=torch.float16).to(memory_format=torch.channels_last)\n", - "W_gpu = torch.randn(32, 64, 3, 3, requires_grad=False, device=\"cuda\", dtype=torch.float16).to(memory_format=torch.channels_last)\n", - "Y_gpu = torch.zeros(8, 32, 3, 3, requires_grad=False, device=\"cuda\", dtype=torch.float16).to(memory_format=torch.channels_last)\n", + "X_gpu = torch.randn(\n", + " 8, 64, 56, 56, requires_grad=False, device=\"cuda\", dtype=torch.float16\n", + ").to(memory_format=torch.channels_last)\n", + "W_gpu = torch.randn(\n", + " 32, 64, 3, 3, requires_grad=False, device=\"cuda\", dtype=torch.float16\n", + ").to(memory_format=torch.channels_last)\n", + "Y_gpu = torch.zeros(\n", + " 8, 32, 3, 3, requires_grad=False, device=\"cuda\", dtype=torch.float16\n", + ").to(memory_format=torch.channels_last)\n", "workspace = torch.empty(graph.get_workspace_size(), device=\"cuda\", dtype=torch.uint8)" ] }, @@ -225,7 +248,7 @@ "metadata": {}, "outputs": [], "source": [ - "graph.execute({X: X_gpu, W: W_gpu, Y: Y_gpu}, workspace, handle= handle)" + "graph.execute({X: X_gpu, W: W_gpu, Y: Y_gpu}, workspace, handle=handle)" ] }, { diff --git a/samples/python/01_matmul_bias.ipynb b/samples/python/01_matmul_bias.ipynb index 0edd399..c539b26 100644 --- a/samples/python/01_matmul_bias.ipynb +++ b/samples/python/01_matmul_bias.ipynb @@ -47,7 +47,7 @@ "source": [ "# get_ipython().system('export CUDA_VERSION=\"12.3\"')\n", "# get_ipython().system('pip install nvidia-cudnn-cu12')\n", - "# get_ipython().system('CUDNN_PATH=`pip show nvidia-cudnn-cu12 | grep Location | cut -d\":\" -f2 | xargs`/nvidia/cudnn pip install git+https://github.com/NVIDIA/cudnn-frontend.git')\n", + "# get_ipython().system('pip install nvidia-cudnn-frontend')\n", "# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')" ] }, @@ -98,15 +98,15 @@ "input_type = torch.float16\n", "\n", "# input tensors\n", - "a = torch.randn(batch, m, k, dtype=input_type, device='cuda')\n", - "b = torch.randn(batch, k, n, dtype=input_type, device='cuda')\n", - "B = torch.randn(1, m, n, dtype=torch.float16, device='cuda')\n", + "a = torch.randn(batch, m, k, dtype=input_type, device=\"cuda\")\n", + "b = torch.randn(batch, k, n, dtype=input_type, device=\"cuda\")\n", + "B = torch.randn(1, m, n, dtype=torch.float16, device=\"cuda\")\n", "\n", "# reference output\n", "c_ref = torch.matmul(a, b) + B\n", "\n", "# place holder for cudnn output\n", - "c = torch.randn_like(c_ref, device='cuda')" + "c = torch.randn_like(c_ref, device=\"cuda\")" ] }, { @@ -122,16 +122,19 @@ "metadata": {}, "outputs": [], "source": [ - "graph = cudnn.pygraph(intermediate_data_type = cudnn.data_type.FLOAT, compute_data_type = cudnn.data_type.FLOAT)\n", + "graph = cudnn.pygraph(\n", + " 
intermediate_data_type=cudnn.data_type.FLOAT,\n", + " compute_data_type=cudnn.data_type.FLOAT,\n", + ")\n", "\n", "a_cudnn_tensor = graph.tensor_like(a)\n", "b_cudnn_tensor = graph.tensor_like(b)\n", "bias_cudnn_tensor = graph.tensor_like(B)\n", "\n", - "c_intermediate = graph.matmul(name = \"matmul\", A = a_cudnn_tensor, B = b_cudnn_tensor)\n", + "c_intermediate = graph.matmul(name=\"matmul\", A=a_cudnn_tensor, B=b_cudnn_tensor)\n", + "\n", + "c_cudnn_tensor = graph.bias(name=\"bias\", input=c_intermediate, bias=bias_cudnn_tensor)\n", "\n", - "c_cudnn_tensor = graph.bias(name = \"bias\", input = c_intermediate, bias = bias_cudnn_tensor)\n", - " \n", "c_cudnn_tensor.set_name(\"c\").set_output(True).set_data_type(cudnn.data_type.HALF)" ] }, @@ -186,7 +189,7 @@ "metadata": {}, "outputs": [], "source": [ - "torch.testing.assert_close(c, c_ref, rtol = 5e-3, atol = 5e-3)" + "torch.testing.assert_close(c, c_ref, rtol=5e-3, atol=5e-3)" ] } ], diff --git a/samples/python/02_sdpa_graph_serialization.ipynb b/samples/python/02_sdpa_graph_serialization.ipynb index af7f156..cd97eae 100644 --- a/samples/python/02_sdpa_graph_serialization.ipynb +++ b/samples/python/02_sdpa_graph_serialization.ipynb @@ -30,7 +30,7 @@ "source": [ "# get_ipython().system('export CUDA_VERSION=\"12.3\"')\n", "# get_ipython().system('pip install nvidia-cudnn-cu12')\n", - "# get_ipython().system('CUDNN_PATH=`pip show nvidia-cudnn-cu12 | grep Location | cut -d\":\" -f2 | xargs`/nvidia/cudnn pip install git+https://github.com/NVIDIA/cudnn-frontend.git')\n", + "# get_ipython().system('pip install nvidia-cudnn-frontend')\n", "# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')" ] }, @@ -69,38 +69,47 @@ "metadata": {}, "outputs": [], "source": [ - "b = 2 # batch size\n", + "b = 2 # batch size\n", "\n", - "s_q = 1024 # query sequence length\n", - "s_kv = 1024 # key+value sequence length\n", + "s_q = 1024 # query sequence length\n", + "s_kv = 1024 # key+value sequence length\n", "\n", - "h = 6 # Query heads\n", + "h = 6 # Query heads\n", "\n", - "d = 64 # query+key embedding dimension per head\n", + "d = 64 # query+key embedding dimension per head\n", "\n", "shape_q = (b, h, s_q, d)\n", "shape_k = (b, h, s_kv, d)\n", "shape_v = (b, h, s_kv, d)\n", "shape_o = (b, h, s_q, d)\n", "\n", - "stride_q = (s_q * h * d, d, h * d, 1)\n", + "stride_q = (s_q * h * d, d, h * d, 1)\n", "stride_k = (s_kv * h * d, d, h * d, 1)\n", "stride_v = (s_kv * h * d, d, h * d, 1)\n", - "stride_o = (s_q * h * d, d, h * d, 1)\n", + "stride_o = (s_q * h * d, d, h * d, 1)\n", "\n", "attn_scale = 0.125\n", "\n", - "q_gpu = torch.randn(b * h * s_q * d, dtype=torch.bfloat16, device=\"cuda\").as_strided(shape_q, stride_q)\n", - "k_gpu = torch.randn(b * h * s_kv * d, dtype=torch.bfloat16, device=\"cuda\").as_strided(shape_k, stride_k)\n", - "v_gpu = torch.randn(b * h * s_kv * d, dtype=torch.bfloat16, device=\"cuda\").as_strided(shape_v, stride_v)\n", - "o_gpu = torch.empty(b * h * s_q * d, dtype=torch.bfloat16, device=\"cuda\").as_strided(shape_o, stride_o)\n", + "q_gpu = torch.randn(b * h * s_q * d, dtype=torch.bfloat16, device=\"cuda\").as_strided(\n", + " shape_q, stride_q\n", + ")\n", + "k_gpu = torch.randn(b * h * s_kv * d, dtype=torch.bfloat16, device=\"cuda\").as_strided(\n", + " shape_k, stride_k\n", + ")\n", + "v_gpu = torch.randn(b * h * s_kv * d, dtype=torch.bfloat16, device=\"cuda\").as_strided(\n", + " shape_v, stride_v\n", + ")\n", + "o_gpu = torch.empty(b * h * s_q * d, dtype=torch.bfloat16, 
device=\"cuda\").as_strided(\n", + " shape_o, stride_o\n", + ")\n", "stats_gpu = torch.empty(b, h, s_q, 1, dtype=torch.float32, device=\"cuda\")\n", "\n", + "\n", "class UIDs(Enum):\n", - " Q_UID = 0\n", - " K_UID = 1\n", - " V_UID = 2\n", - " O_UID = 3\n", + " Q_UID = 0\n", + " K_UID = 1\n", + " V_UID = 2\n", + " O_UID = 3\n", " STATS_UID = 4" ] }, @@ -123,29 +132,34 @@ " io_data_type=cudnn.data_type.HALF,\n", " intermediate_data_type=cudnn.data_type.FLOAT,\n", " compute_data_type=cudnn.data_type.FLOAT,\n", - " handle = handle)\n", - " \n", + " handle=handle,\n", + " )\n", + "\n", " q = graph.tensor_like(q_gpu)\n", " k = graph.tensor_like(k_gpu)\n", " v = graph.tensor_like(v_gpu)\n", - " \n", - " o, stats = graph.sdpa(name=\"sdpa\",\n", - " q=q, k=k, v=v,\n", + "\n", + " o, stats = graph.sdpa(\n", + " name=\"sdpa\",\n", + " q=q,\n", + " k=k,\n", + " v=v,\n", " is_inference=False,\n", " attn_scale=attn_scale,\n", - " use_causal_mask=True)\n", - " \n", + " use_causal_mask=True,\n", + " )\n", + "\n", " o.set_output(True).set_dim(shape_o).set_stride(stride_o)\n", " stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)\n", - " \n", + "\n", " q.set_uid(UIDs.Q_UID.value)\n", " k.set_uid(UIDs.K_UID.value)\n", " v.set_uid(UIDs.V_UID.value)\n", " o.set_uid(UIDs.O_UID.value)\n", " stats.set_uid(UIDs.STATS_UID.value)\n", - " \n", + "\n", " graph.validate()\n", - " \n", + "\n", " return graph" ] }, @@ -163,11 +177,11 @@ "outputs": [], "source": [ "def check_support():\n", - " \n", + "\n", " graph = build_and_validate_graph_helper()\n", - " \n", + "\n", " graph.build_operation_graph()\n", - " \n", + "\n", " graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])\n", "\n", " graph.check_support()" @@ -188,15 +202,15 @@ "source": [ "def serialize():\n", " graph = build_and_validate_graph_helper()\n", - " \n", + "\n", " graph.build_operation_graph()\n", - " \n", + "\n", " graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])\n", "\n", " graph.check_support()\n", - " \n", + "\n", " graph.build_plans()\n", - " \n", + "\n", " return graph.serialize()" ] }, @@ -214,11 +228,11 @@ "outputs": [], "source": [ "def deserialize(payload):\n", - " \n", + "\n", " graph = cudnn.pygraph()\n", - " \n", + "\n", " graph.deserialize(payload)\n", - " \n", + "\n", " return graph" ] }, @@ -239,9 +253,11 @@ "\n", "data = serialize()\n", "\n", - "deserialized_graph = deserialize(data)\n", + "deserialized_graph = deserialize(data)\n", "\n", - "workspace = torch.empty(deserialized_graph.get_workspace_size(), device=\"cuda\", dtype=torch.uint8)\n", + "workspace = torch.empty(\n", + " deserialized_graph.get_workspace_size(), device=\"cuda\", dtype=torch.uint8\n", + ")\n", "\n", "variant_pack = {\n", " UIDs.Q_UID.value: q_gpu,\n", diff --git a/samples/python/03_mixed_precision_matmul.ipynb b/samples/python/03_mixed_precision_matmul.ipynb index ca95c7c..b1ad9ac 100644 --- a/samples/python/03_mixed_precision_matmul.ipynb +++ b/samples/python/03_mixed_precision_matmul.ipynb @@ -47,7 +47,7 @@ "source": [ "# get_ipython().system('export CUDA_VERSION=\"12.3\"')\n", "# get_ipython().system('pip install nvidia-cudnn-cu12')\n", - "# get_ipython().system('CUDNN_PATH=`pip show nvidia-cudnn-cu12 | grep Location | cut -d\":\" -f2 | xargs`/nvidia/cudnn pip install git+https://github.com/NVIDIA/cudnn-frontend.git')\n", + "# get_ipython().system('pip install nvidia-cudnn-frontend')\n", "# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')" ] }, 
@@ -91,28 +91,31 @@ "# input data types can be different\n", "input_type_a = torch.int8\n", "input_type_b = torch.bfloat16\n", - "output_type = torch.bfloat16\n", + "output_type = torch.bfloat16\n", "\n", "# direct input data type for the matmul operation\n", "mma_data_type = torch.bfloat16\n", "\n", "# input tensors\n", "if input_type_a != torch.int8:\n", - " a = 2 * torch.randn(batch, m, k, dtype=input_type_a, device='cuda') - 0.5\n", + " a = 2 * torch.randn(batch, m, k, dtype=input_type_a, device=\"cuda\") - 0.5\n", "else:\n", - " a = torch.randint(4, (batch, m, k), dtype=input_type_a, device='cuda') - 1\n", + " a = torch.randint(4, (batch, m, k), dtype=input_type_a, device=\"cuda\") - 1\n", "\n", "if input_type_b != torch.int8:\n", - " b_row_major = 3 * torch.randn(batch, k, n, dtype=input_type_b, device='cuda') - 1.25\n", + " b_row_major = 3 * torch.randn(batch, k, n, dtype=input_type_b, device=\"cuda\") - 1.25\n", "else:\n", - " b_row_major = torch.randint(3, (batch, k, n), dtype=input_type_b, device='cuda').contiguous() - 2\n", + " b_row_major = (\n", + " torch.randint(3, (batch, k, n), dtype=input_type_b, device=\"cuda\").contiguous()\n", + " - 2\n", + " )\n", "b = torch.as_strided(b_row_major, (batch, k, n), (n * k, 1, n))\n", "\n", "# reference output\n", "c_ref = torch.matmul(a.to(mma_data_type), b.to(mma_data_type)).to(output_type)\n", "\n", "# place holder for cudnn output\n", - "c = torch.randn_like(c_ref, device='cuda')" + "c = torch.randn_like(c_ref, device=\"cuda\")" ] }, { @@ -135,7 +138,9 @@ "\n", "# cudnn will do the following conversion path: input_data_type -> compute_data_type -> output_data_type\n", "# compute_data_type can be int32 as well\n", - "a_cudnn_tensor_casted = graph.identity(input = a_cudnn_tensor, compute_data_type=cudnn.data_type.FLOAT)\n", + "a_cudnn_tensor_casted = graph.identity(\n", + " input=a_cudnn_tensor, compute_data_type=cudnn.data_type.FLOAT\n", + ")\n", "a_cudnn_tensor_casted.set_data_type(mma_data_type)\n", "\n", "# here we omit the code casting tensor b to the mma_data_type\n", @@ -143,7 +148,12 @@ "# user can also cast tensor b if it has a different input_type from the mma_data_type\n", "\n", "# compute_data_type should be set to int32 if the mma_data_type is int8\n", - "c_cudnn_tensor = graph.matmul(name = \"matmul\", A = a_cudnn_tensor_casted, B = b_cudnn_tensor, compute_data_type = cudnn.data_type.FLOAT)\n", + "c_cudnn_tensor = graph.matmul(\n", + " name=\"matmul\",\n", + " A=a_cudnn_tensor_casted,\n", + " B=b_cudnn_tensor,\n", + " compute_data_type=cudnn.data_type.FLOAT,\n", + ")\n", "c_cudnn_tensor.set_name(\"c\").set_output(True).set_data_type(output_type)" ] }, @@ -197,7 +207,7 @@ "metadata": {}, "outputs": [], "source": [ - "torch.testing.assert_close(c, c_ref, rtol = 5e-3, atol = 5e-3)" + "torch.testing.assert_close(c, c_ref, rtol=5e-3, atol=5e-3)" ] } ], diff --git a/samples/python/50_scaled_dot_product_attention.ipynb b/samples/python/50_scaled_dot_product_attention.ipynb index 872b39c..f24f8af 100644 --- a/samples/python/50_scaled_dot_product_attention.ipynb +++ b/samples/python/50_scaled_dot_product_attention.ipynb @@ -49,7 +49,7 @@ "source": [ "# get_ipython().system('export CUDA_VERSION=\"12.3\"')\n", "# get_ipython().system('pip install nvidia-cudnn-cu12')\n", - "# get_ipython().system('CUDNN_PATH=`pip show nvidia-cudnn-cu12 | grep Location | cut -d\":\" -f2 | xargs`/nvidia/cudnn pip install git+https://github.com/NVIDIA/cudnn-frontend.git')\n", + "# get_ipython().system('pip install nvidia-cudnn-frontend')\n", "# 
get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')" ] }, @@ -67,8 +67,12 @@ "handle = cudnn.create_handle()\n", "\n", "assert torch.cuda.is_available()\n", - "assert torch.cuda.get_device_capability()[0] >= 8, \"SDPA operation is only supported on SM80 architecture (Ampere) or above\"\n", - "assert cudnn.backend_version() >= 8903, \"SDPA operation is only supported cuDNN version 8.9.3 or above\"" + "assert (\n", + " torch.cuda.get_device_capability()[0] >= 8\n", + "), \"SDPA operation is only supported on SM80 architecture (Ampere) or above\"\n", + "assert (\n", + " cudnn.backend_version() >= 8903\n", + "), \"SDPA operation is only supported cuDNN version 8.9.3 or above\"" ] }, { @@ -88,10 +92,10 @@ "metadata": {}, "outputs": [], "source": [ - "b = 4 # batch size\n", - "h = 12 # query number of heads\n", - "s = 1024 # maximum sequence length\n", - "d = 64 # embedding dimension per head\n", + "b = 4 # batch size\n", + "h = 12 # query number of heads\n", + "s = 1024 # maximum sequence length\n", + "d = 64 # embedding dimension per head\n", "\n", "attn_scale = 1.0 / math.sqrt(d)" ] @@ -148,7 +152,9 @@ "# causal mask is enabled\n", "o, _ = graph.sdpa(\n", " name=\"sdpa\",\n", - " q=q, k=k, v=v,\n", + " q=q,\n", + " k=k,\n", + " v=v,\n", " is_inference=True,\n", " attn_scale=attn_scale,\n", " use_causal_mask=True,\n", @@ -220,7 +226,9 @@ "k_ref = k_gpu.detach().float().requires_grad_()\n", "v_ref = v_gpu.detach().float().requires_grad_()\n", "\n", - "o_ref = torch.nn.functional.scaled_dot_product_attention(q_ref, k_ref, v_ref, is_causal=True, scale=attn_scale)\n", + "o_ref = torch.nn.functional.scaled_dot_product_attention(\n", + " q_ref, k_ref, v_ref, is_causal=True, scale=attn_scale\n", + ")\n", "torch.testing.assert_close(o_ref, o_gpu.float(), atol=5e-3, rtol=3e-3)" ] } diff --git a/samples/python/51_scaled_dot_product_attention_backward.ipynb b/samples/python/51_scaled_dot_product_attention_backward.ipynb index d294872..8ee2f02 100644 --- a/samples/python/51_scaled_dot_product_attention_backward.ipynb +++ b/samples/python/51_scaled_dot_product_attention_backward.ipynb @@ -45,7 +45,7 @@ "source": [ "# get_ipython().system('export CUDA_VERSION=\"12.3\"')\n", "# get_ipython().system('pip install nvidia-cudnn-cu12')\n", - "# get_ipython().system('CUDNN_PATH=`pip show nvidia-cudnn-cu12 | grep Location | cut -d\":\" -f2 | xargs`/nvidia/cudnn pip install git+https://github.com/NVIDIA/cudnn-frontend.git')\n", + "# get_ipython().system('pip install nvidia-cudnn-frontend')\n", "# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')" ] }, @@ -63,8 +63,12 @@ "handle = cudnn.create_handle()\n", "\n", "assert torch.cuda.is_available()\n", - "assert torch.cuda.get_device_capability()[0] >= 8, \"SDPA operation is only supported on SM80 architecture (Ampere) or above\"\n", - "assert cudnn.backend_version() >= 8903, \"SDPA operation is only supported cuDNN version 8.9.3 or above\"" + "assert (\n", + " torch.cuda.get_device_capability()[0] >= 8\n", + "), \"SDPA operation is only supported on SM80 architecture (Ampere) or above\"\n", + "assert (\n", + " cudnn.backend_version() >= 8903\n", + "), \"SDPA operation is only supported cuDNN version 8.9.3 or above\"" ] }, { @@ -84,10 +88,10 @@ "metadata": {}, "outputs": [], "source": [ - "b = 4 # batch size\n", - "h = 12 # query number of heads\n", - "s = 1024 # maximum sequence length\n", - "d = 64 # embedding dimension per head\n", + "b = 4 # batch 
size\n", + "h = 12 # query number of heads\n", + "s = 1024 # maximum sequence length\n", + "d = 64 # embedding dimension per head\n", "\n", "attn_scale = 1.0 / math.sqrt(d)" ] @@ -169,7 +173,9 @@ "# causal mask is enabled\n", "o_forward, stats_forward = graph_forward.sdpa(\n", " name=\"sdpa\",\n", - " q=q_forward, k=k_forward, v=v_forward,\n", + " q=q_forward,\n", + " k=k_forward,\n", + " v=v_forward,\n", " is_inference=False,\n", " attn_scale=attn_scale,\n", " use_causal_mask=True,\n", @@ -214,8 +220,12 @@ "\n", "dQ_backward, dK_backward, dV_backward = graph_backward.sdpa_backward(\n", " name=\"sdpa_backward\",\n", - " q=q_backward, k=k_backward, v=v_backward,\n", - " o=o_backward, dO=dO_backward, stats=stats_backward,\n", + " q=q_backward,\n", + " k=k_backward,\n", + " v=v_backward,\n", + " o=o_backward,\n", + " dO=dO_backward,\n", + " stats=stats_backward,\n", " attn_scale=attn_scale,\n", " use_causal_mask=True,\n", ")\n", @@ -323,10 +333,14 @@ "v_ref = v_gpu.detach().float().requires_grad_()\n", "dO_ref = dO_gpu.detach().float()\n", "\n", - "o_ref = torch.nn.functional.scaled_dot_product_attention(q_ref, k_ref, v_ref, is_causal=True, scale=attn_scale)\n", + "o_ref = torch.nn.functional.scaled_dot_product_attention(\n", + " q_ref, k_ref, v_ref, is_causal=True, scale=attn_scale\n", + ")\n", "torch.testing.assert_close(o_ref, o_gpu.float(), atol=5e-3, rtol=3e-3)\n", "\n", - "dQ_ref, dK_ref, dV_ref = torch.autograd.grad(outputs=[o_ref], inputs=[q_ref, k_ref, v_ref], grad_outputs=[dO_ref])\n", + "dQ_ref, dK_ref, dV_ref = torch.autograd.grad(\n", + " outputs=[o_ref], inputs=[q_ref, k_ref, v_ref], grad_outputs=[dO_ref]\n", + ")\n", "torch.testing.assert_close(dQ_ref, dQ_gpu.float(), atol=5e-3, rtol=3e-3)\n", "torch.testing.assert_close(dK_ref, dK_gpu.float(), atol=5e-3, rtol=3e-3)\n", "torch.testing.assert_close(dV_ref, dV_gpu.float(), atol=5e-3, rtol=3e-3)" diff --git a/samples/utils/helpers.h b/samples/utils/helpers.h index 63e18ac..8dc60c0 100644 --- a/samples/utils/helpers.h +++ b/samples/utils/helpers.h @@ -300,6 +300,10 @@ struct Surface { T_ELEM* hostPtr = NULL; int64_t n_elems = 0; + protected: + explicit Surface() {} + + public: explicit Surface(int64_t n_elems, [[maybe_unused]] bool hasRef) : n_elems(n_elems) { checkCudaErr(cudaMalloc((void**)&(devPtr), (size_t)((n_elems) * sizeof(devPtr[0])))); hostPtr = (T_ELEM*)calloc((size_t)n_elems, sizeof(hostPtr[0])); @@ -332,6 +336,32 @@ struct Surface { checkCudaErr(cudaDeviceSynchronize()); } + Surface(const Surface& other) : n_elems(n_elems) { + checkCudaErr(cudaMalloc((void**)&(devPtr), (size_t)((n_elems) * sizeof(devPtr[0])))); + hostPtr = (T_ELEM*)calloc((size_t)n_elems, sizeof(hostPtr[0])); + std::copy(other.hostPtr, other.hostPtr + n_elems, hostPtr); + checkCudaErr(cudaMemcpy(devPtr, hostPtr, size_t(sizeof(hostPtr[0]) * n_elems), cudaMemcpyHostToDevice)); + checkCudaErr(cudaDeviceSynchronize()); + } + + Surface(Surface&& other) noexcept : Surface() { swap(*this, other); } + + Surface& + operator=(Surface other) { + swap(*this, other); + + return *this; + } + + friend void + swap(Surface& first, Surface& second) { + using std::swap; + + swap(first.n_elems, second.n_elems); + swap(first.hostPtr, second.hostPtr); + swap(first.devPtr, second.devPtr); + } + ~Surface() { if (devPtr) { cudaFree(devPtr); diff --git a/test/python_fe/conftest.py b/test/python_fe/conftest.py index a5d11ab..f3f5b9f 100644 --- a/test/python_fe/conftest.py +++ b/test/python_fe/conftest.py @@ -28,3 +28,8 @@ def pytest_addoption(parser): parser.addoption( 
"--mha_h_v", default=None, help="[test_mhas.py] value number of heads" ) + parser.addoption( + "--mha_deterministic", + default=None, + help="[test_mhas.py] force deterministic algorithm", + ) diff --git a/test/python_fe/test_apply_rope.py b/test/python_fe/test_apply_rope.py index 5508645..0412dcc 100644 --- a/test/python_fe/test_apply_rope.py +++ b/test/python_fe/test_apply_rope.py @@ -77,7 +77,7 @@ def test_apply_rope(): sin2_gpu = sin_gpu[..., rope_n_elem // 2 :] handle = cudnn.create_handle() - stream = torch.cuda.Stream().cuda_stream + stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=handle, stream=stream) graph = cudnn.pygraph( @@ -129,5 +129,8 @@ def test_apply_rope(): handle=handle, ) + torch.cuda.synchronize() # Compare torch.testing.assert_close(Y_expected, x_gpu, atol=1e-2, rtol=1e-2) + + cudnn.destroy_handle(handle) diff --git a/test/python_fe/test_batchnorm.py b/test/python_fe/test_batchnorm.py index 4535ffb..c13b56e 100644 --- a/test/python_fe/test_batchnorm.py +++ b/test/python_fe/test_batchnorm.py @@ -57,7 +57,7 @@ def test_bn_relu_with_mask(): ) handle = cudnn.create_handle() - stream = torch.cuda.Stream().cuda_stream + stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=handle, stream=stream) # Cudnn code @@ -181,12 +181,15 @@ def test_bn_relu_with_mask(): ) # Compare + torch.cuda.synchronize() print("Comparing outputs") torch.testing.assert_close(Y_expected, Y_actual, atol=1e-3, rtol=1e-3) torch.testing.assert_close(mean_expected, saved_mean_actual, atol=1e-3, rtol=1e-3) torch.testing.assert_close( inv_var_expected, saved_inv_var_actual, atol=1e-3, rtol=1e-3 ) + + cudnn.destroy_handle(handle) # torch.testing.assert_close(mask_expected, mask_actual) @@ -220,7 +223,7 @@ def test_drelu_dadd_dbn(): ).to(memory_format=torch.channels_last) handle = cudnn.create_handle() - stream = torch.cuda.Stream().cuda_stream + stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=handle, stream=stream) # Cudnn code @@ -342,7 +345,7 @@ def test_bn_infer_drelu_dbn(): ).to(memory_format=torch.channels_last) handle = cudnn.create_handle() - stream = torch.cuda.Stream().cuda_stream + stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=handle, stream=stream) # Cudnn code diff --git a/test/python_fe/test_conv_bias.py b/test/python_fe/test_conv_bias.py index 26ff782..aa1e674 100644 --- a/test/python_fe/test_conv_bias.py +++ b/test/python_fe/test_conv_bias.py @@ -7,13 +7,23 @@ class CSBR(torch.nn.Module): - def forward(self, x, w, b=None, padding=[1, 1], stride=[1, 1], dilation=[1, 1]): + def forward( + self, + x, + w, + b=None, + padding=[1, 1], + stride=[1, 1], + dilation=[1, 1], + lower_clip=0.0, + upper_clip=128, + ): if b is not None: b = b.reshape(-1) # Conv2d needs a 1D tensor conv_output = torch.nn.functional.conv2d( x, w, bias=b, padding=padding, stride=stride, dilation=dilation ) - return torch.nn.functional.relu(conv_output) + return torch.clamp(conv_output, min=lower_clip, max=upper_clip) @torch_fork_set_rng(seed=0) @@ -34,11 +44,18 @@ def test_conv_bias_relu(): dilation = [1, 1] model = CSBR().eval().to("cuda").to(torch.float16) Y_expected = model( - X_gpu, W_gpu, b=B_gpu, padding=padding, stride=stride, dilation=dilation + X_gpu, + W_gpu, + b=B_gpu, + padding=padding, + stride=stride, + dilation=dilation, + lower_clip=0.5, + upper_clip=0.55, ) handle = cudnn.create_handle() - stream = torch.cuda.Stream().cuda_stream + stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=handle, 
stream=stream) graph = cudnn.pygraph( @@ -69,7 +86,7 @@ def test_conv_bias_relu(): bias_output = graph.bias(name="bias", input=conv_output, bias=B) - Y = graph.relu(name="relu", input=bias_output) + Y = graph.relu(name="relu", input=bias_output, lower_clip=0.5, upper_clip=0.55) Y.set_output(True) graph.validate() @@ -85,6 +102,8 @@ def test_conv_bias_relu(): Y_actual = torch.zeros_like(Y_expected) graph.execute({X: X_gpu, W: W_gpu, B: B_gpu, Y: Y_actual}, workspace, handle=handle) + torch.cuda.synchronize() + torch.testing.assert_close(Y_expected, Y_actual, atol=0.05, rtol=1e-2) cudnn.destroy_handle(handle) @@ -104,10 +123,18 @@ def test_conv_relu(): stride = [2, 3] dilation = [1, 1] model = CSBR().eval().to("cuda").to(torch.float16) - Y_expected = model(X_gpu, W_gpu, padding=padding, stride=stride, dilation=dilation) + Y_expected = model( + X_gpu, + W_gpu, + padding=padding, + stride=stride, + dilation=dilation, + lower_clip=0.5, + upper_clip=0.55, + ) handle = cudnn.create_handle() - stream = torch.cuda.Stream().cuda_stream + stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=handle, stream=stream) # Cudnn code @@ -129,7 +156,7 @@ def test_conv_relu(): image=X, weight=W, padding=padding, stride=stride, dilation=dilation ) - Y = graph.relu(name="relu", input=conv_output) + Y = graph.relu(name="relu", input=conv_output, lower_clip=0.5, upper_clip=0.55) Y.set_output(True) graph.validate() @@ -146,7 +173,9 @@ def test_conv_relu(): handle = cudnn.create_handle() graph.execute({X: X_gpu, W: W_gpu, Y: Y_actual}, workspace, handle=handle) # Compare + torch.cuda.synchronize() torch.testing.assert_close(Y_expected, Y_actual, atol=1e-3, rtol=1e-3) + cudnn.destroy_handle(handle) @torch_fork_set_rng(seed=0) @@ -187,7 +216,7 @@ def test_conv3d_bias_leaky_relu(): ) handle = cudnn.create_handle() - stream = torch.cuda.Stream().cuda_stream + stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=handle, stream=stream) graph = cudnn.pygraph( @@ -234,6 +263,7 @@ def test_conv3d_bias_leaky_relu(): torch.cuda.synchronize() torch.testing.assert_close(Y_expected, Y_actual, atol=1e-2, rtol=1e-2) + cudnn.destroy_handle(handle) @torch_fork_set_rng(seed=0) @@ -256,7 +286,7 @@ def dleaky_relu(grad: torch.Tensor, mask: torch.Tensor, negative_slope: float): Y_expected = dleaky_relu(loss_gpu, input_gpu, negative_slope) handle = cudnn.create_handle() - stream = torch.cuda.Stream().cuda_stream + stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=handle, stream=stream) graph = cudnn.pygraph( @@ -297,7 +327,9 @@ def dleaky_relu(grad: torch.Tensor, mask: torch.Tensor, negative_slope: float): {loss: loss_gpu, input: input_gpu, Y: Y_actual}, workspace, handle=handle ) + torch.cuda.synchronize() torch.testing.assert_close(Y_expected, Y_actual, atol=1e-4, rtol=1e-4) + cudnn.destroy_handle(handle) @pytest.mark.skipif( @@ -338,7 +370,7 @@ def test_conv_int8(): compare_output = False handle = cudnn.create_handle() - stream = torch.cuda.Stream().cuda_stream + stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=handle, stream=stream) graph = cudnn.pygraph( @@ -378,6 +410,8 @@ def test_conv_int8(): if compare_output: torch.testing.assert_close(Y_expected, Y_actual, atol=1e-2, rtol=1e-2) + cudnn.destroy_handle(handle) + if __name__ == "__main__": # test_conv_int8() diff --git a/test/python_fe/test_conv_genstats.py b/test/python_fe/test_conv_genstats.py index 7dace26..5735428 100644 --- a/test/python_fe/test_conv_genstats.py +++ 
b/test/python_fe/test_conv_genstats.py @@ -62,7 +62,7 @@ def test_conv_genstats(): ) handle = cudnn.create_handle() - stream = torch.cuda.Stream().cuda_stream + stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=handle, stream=stream) # Cudnn code @@ -128,9 +128,11 @@ def test_conv_genstats(): ) # Compare + torch.cuda.synchronize() torch.testing.assert_close(sum_expected, sum_dev, atol=0.5, rtol=1e-2) torch.testing.assert_close(sq_sum_expected, sq_sum_dev, atol=1e-3, rtol=1e-3) torch.testing.assert_close(Y_expected, Y_actual, atol=1e-3, rtol=1e-3) + cudnn.destroy_handle(handle) if __name__ == "__main__": diff --git a/test/python_fe/test_conv_reduction.py b/test/python_fe/test_conv_reduction.py index a91a4e5..3df5392 100644 --- a/test/python_fe/test_conv_reduction.py +++ b/test/python_fe/test_conv_reduction.py @@ -28,7 +28,7 @@ def test_reduction(): Y_expected = conv_output.sum(dim=1) handle = cudnn.create_handle() - stream = torch.cuda.Stream().cuda_stream + stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=handle, stream=stream) # Cudnn code @@ -66,9 +66,12 @@ def test_reduction(): graph.execute({X: X_gpu, Weight: W_gpu, Y: Y_actual}, workspace, handle=handle) + torch.cuda.synchronize() # Compare torch.testing.assert_close(Y_expected, Y_actual, atol=1e-3, rtol=1e-3) + cudnn.destroy_handle(handle) + if __name__ == "__main__": test_reduction() diff --git a/test/python_fe/test_instancenorm.py b/test/python_fe/test_instancenorm.py index a5dbd11..fd354f0 100644 --- a/test/python_fe/test_instancenorm.py +++ b/test/python_fe/test_instancenorm.py @@ -71,7 +71,7 @@ def test_in(param_extract): print("Building cudnn graph") handle = cudnn.create_handle() - stream = torch.cuda.Stream().cuda_stream + stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=handle, stream=stream) graph = cudnn.pygraph( @@ -127,11 +127,13 @@ def test_in(param_extract): handle=handle, ) + torch.cuda.synchronize() print("Comparing with reference") torch.testing.assert_close(Y_expected, Y_actual, atol=atol, rtol=rtol) torch.testing.assert_close(mean_expected, mean_actual, atol=atol, rtol=rtol) torch.testing.assert_close(inv_var_expected, inv_var_actual, atol=atol, rtol=rtol) print("Success!!") + cudnn.destroy_handle(handle) target = torch.randn_like(Y_expected) criterion = torch.nn.MSELoss() @@ -145,7 +147,7 @@ def test_in(param_extract): loss.backward() handle = cudnn.create_handle() - stream = torch.cuda.Stream().cuda_stream + stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=handle, stream=stream) bwd_graph = cudnn.pygraph( @@ -214,6 +216,7 @@ def test_in(param_extract): torch.testing.assert_close(x_gpu.grad, DX_actual, atol=2e-3, rtol=2e-3) torch.testing.assert_close(scale_gpu.grad, DScale_actual, atol=2e-3, rtol=2e-3) torch.testing.assert_close(bias_gpu.grad, Dbias_actual, atol=2e-3, rtol=2e-3) + cudnn.destroy_handle(handle) print("Success!!") diff --git a/test/python_fe/test_layernorm.py b/test/python_fe/test_layernorm.py index ea1fb9e..5e2e131 100644 --- a/test/python_fe/test_layernorm.py +++ b/test/python_fe/test_layernorm.py @@ -80,7 +80,7 @@ def test_layernorm(param_extract): ) handle = cudnn.create_handle() - stream = torch.cuda.Stream().cuda_stream + stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=handle, stream=stream) graph = cudnn.pygraph( @@ -223,9 +223,12 @@ def test_layernorm(param_extract): handle=handle, ) + torch.cuda.synchronize() + torch.testing.assert_close(x_gpu.grad, DX_actual, atol=2e-4, 
rtol=2e-4) torch.testing.assert_close(scale_gpu.grad, DScale_actual, atol=2e-4, rtol=2e-4) torch.testing.assert_close(bias_gpu.grad, Dbias_actual, atol=2e-4, rtol=2e-4) + cudnn.destroy_handle(handle) if __name__ == "__main__": diff --git a/test/python_fe/test_matmul_bias_relu.py b/test/python_fe/test_matmul_bias_relu.py index 9394685..ca81a2f 100644 --- a/test/python_fe/test_matmul_bias_relu.py +++ b/test/python_fe/test_matmul_bias_relu.py @@ -56,7 +56,7 @@ def test_int8_bf16_matmul(): ) handle = cudnn.create_handle() - stream = torch.cuda.Stream().cuda_stream + stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=handle, stream=stream) # Make cudnn graph @@ -88,8 +88,10 @@ def test_int8_bf16_matmul(): ) graph.execute({A: A_gpu, B: B_gpu, C: C_actual}, workspace, handle=handle) + torch.cuda.synchronize() # compare'em torch.testing.assert_close(C_expected, C_actual) + cudnn.destroy_handle(handle) A_data_type_options = [torch.int8, torch.bfloat16, torch.float16] @@ -149,7 +151,7 @@ def test_mixed_precision_matmul(A_data_type, B_data_type, MMA_data_type): B_gpu = torch.as_strided(B_gpu_strided, (B, K, N), (N * K, 1, N)) handle = cudnn.create_handle() - stream = torch.cuda.Stream().cuda_stream + stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=handle, stream=stream) # Make cudnn graph @@ -198,8 +200,10 @@ def test_mixed_precision_matmul(A_data_type, B_data_type, MMA_data_type): ) graph.execute({A: A_gpu, B: B_gpu, C: C_actual}, workspace, handle=handle) + torch.cuda.synchronize() # compare'em torch.testing.assert_close(C_expected, C_actual, atol=1e-4, rtol=1e-4) + cudnn.destroy_handle(handle) problem_size_options = [(1, 128, 768), (16, 512, 1600), (1, 128, 1024)] @@ -240,7 +244,7 @@ def test_matmul_bias_relu(param_extract): ) handle = cudnn.create_handle() - stream = torch.cuda.Stream().cuda_stream + stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=handle, stream=stream) graph = cudnn.pygraph( @@ -293,6 +297,8 @@ def test_matmul_bias_relu(param_extract): torch.testing.assert_close(Y_expected, Y_actual, atol=atol, rtol=rtol) + cudnn.destroy_handle(handle) + if __name__ == "__main__": test_matmul_bias_relu(((1, 128, 1600), torch.float16)) diff --git a/test/python_fe/test_mhas.py b/test/python_fe/test_mhas.py index c484800..95ff342 100644 --- a/test/python_fe/test_mhas.py +++ b/test/python_fe/test_mhas.py @@ -10,12 +10,14 @@ from test_utils import torch_fork_set_rng input_type_options = [torch.float16, torch.bfloat16] -layout_options = ["non_interleaved", "bs3hd", "sbh3d"] +layout_options = ["bshd_bshd_bshd", "bs3hd", "sbh3d"] head_group_options = ["multi_head", "group_query", "multi_query"] bias_options = [False, True] alibi_mask_options = [False, True] padding_mask_options = [False, True] causal_mask_options = [False, True] +causal_mask_bottom_right_options = [False, True] +sliding_window_mask_options = [False, True] dropout_options = [False, True] ragged_options = [False, True] is_infer_options = [False, True] @@ -50,6 +52,8 @@ def compute_ref( is_alibi=False, padding=None, is_causal=False, + is_causal_bottom_right=False, + sliding_window_length=None, dropout_prob=0.0, dropout_mask=None, compute_stats=False, @@ -79,6 +83,16 @@ def compute_ref( v = v.expand(-1, -1, h_q // h_v, -1, -1) v = v.reshape(v.size(0), -1, v.size(3), v.size(4)) + if is_causal_bottom_right: + causal_mask_bottom_right_zero = torch.ones( + 1, 1, s_q, 1, dtype=torch.bool, device=device + ) + causal_mask_bottom_right_zero[:, :, : s_q - s_kv, :] = False + q = 
q * causal_mask_bottom_right_zero + if sliding_window_length is not None: + swa_mask_zero = torch.ones(1, 1, s_q, 1, dtype=torch.bool, device=device) + swa_mask_zero[:, :, s_kv + sliding_window_length - 1 :, :] = False + q = q * swa_mask_zero # generate masks to compute reference values for padding mask # (also called variable sequence length) if padding is not None: @@ -94,8 +108,6 @@ def compute_ref( v_mask[i, :, n:, :] = False s_mask[i, :, :, n:] = True p_mask[i, :, m:, :] = False - - if padding is not None: q = q * q_mask k = k * k_mask v = v * v_mask @@ -138,12 +150,28 @@ def compute_ref( if padding is not None: s = s.masked_fill(s_mask, float("-inf")) if is_causal: - causal_mask = torch.ones(s_q, s_kv, dtype=torch.bool, device=device).triu_( - diagonal=1 - ) + causal_mask = torch.ones(s_q, s_kv, dtype=torch.bool, device=device) + causal_mask.triu_(diagonal=1) s = s.masked_fill(causal_mask, float("-inf")) + if is_causal_bottom_right: + causal_mask_bottom_right = torch.ones( + s_q, s_kv, dtype=torch.bool, device=device + ) + causal_mask_bottom_right.triu_(diagonal=s_kv - s_q + 1) + causal_mask_bottom_right &= causal_mask_bottom_right_zero.view(s_q, 1) + s = s.masked_fill(causal_mask_bottom_right, float("-inf")) + if sliding_window_length is not None: + assert is_causal == True + swa_mask = torch.ones(s_q, s_kv, dtype=torch.bool, device=device) + swa_mask.tril_(diagonal=-1 * sliding_window_length) + swa_mask &= swa_mask_zero.view(s_q, 1) + s = s.masked_fill(swa_mask, float("-inf")) p = torch.softmax(s, dim=-1) + if is_causal_bottom_right: + p = p * causal_mask_bottom_right_zero + if sliding_window_length is not None: + p = p * swa_mask_zero if padding is not None: p = p * p_mask @@ -168,7 +196,23 @@ def compute_ref( return o -def generate_layout(layout, head_group, shape_q, shape_k, shape_v, shape_o): +# Generator for layout combinations +# | layout | GQA | Packed | GQA and Packed | +# |-----------------|-----------------|-------------|----------------| +# | bshd_bshd_bshd | bshd_bshd_bshd | thd_thd_thd | thd_thd_thd | +# | bs3hd | bshd_bs2hd | t3hd | thd_t2hd | +# | sbh3d | sbhd_sbh2d | | | +def generate_layout( + layout, + head_group, + shape_q, + shape_k, + shape_v, + shape_o, + is_packed=False, + seq_len_q=None, + seq_len_kv=None, +): b, h_q, s_q, d_qk = shape_q b, h_k, s_kv, d_qk = shape_k b, h_v, s_kv, d_v = shape_v @@ -179,8 +223,83 @@ def generate_layout(layout, head_group, shape_q, shape_k, shape_v, shape_o): assert shape_v == (b, h_v, s_kv, d_v) assert shape_o == (b, h_q, s_q, d_v) - if layout == "sbh3d": + if layout == "bshd_bshd_bshd": + if not is_packed: + # bshd_bshd_bshd + stride_q = (s_q * h_q * d_qk, d_qk, h_q * d_qk, 1) + stride_k = (s_kv * h_k * d_qk, d_qk, h_k * d_qk, 1) + stride_v = (s_kv * h_v * d_v, d_v, h_v * d_v, 1) + stride_o = (s_q * h_q * d_v, d_v, h_q * d_v, 1) + offset_q = 0 + offset_k = offset_q + b * s_q * h_q * d_qk + offset_v = offset_k + b * s_kv * h_k * d_qk + else: + # thd_thd_thd + assert seq_len_q is not None + assert seq_len_kv is not None + t_q = torch.sum(seq_len_q) + t_kv = torch.sum(seq_len_kv) + stride_q = (s_q * h_q * d_qk, d_qk, h_q * d_qk, 1) + stride_k = (s_kv * h_k * d_qk, d_qk, h_k * d_qk, 1) + stride_v = (s_kv * h_v * d_v, d_v, h_v * d_v, 1) + stride_o = (s_q * h_q * d_v, d_v, h_q * d_v, 1) + offset_q = 0 + offset_k = offset_q + t_q * h_q * d_qk + offset_v = offset_k + t_kv * h_k * d_qk + elif layout == "bs3hd": + if not is_packed: + if head_group == "multi_head": + # bs3hd + assert (h_q == h_k == h_v) and (s_q == s_kv) and (d_qk == 
d_v) + h, s, d = h_q, s_q, d_qk + stride_q = (s * 3 * h * d, d, 3 * h * d, 1) + stride_k = (s * 3 * h * d, d, 3 * h * d, 1) + stride_v = (s * 3 * h * d, d, 3 * h * d, 1) + stride_o = (s * h * d, d, h * d, 1) + offset_q = 0 + offset_k = offset_q + h * d + offset_v = offset_k + h * d + else: + # bshd_bs2hd + assert (h_k == h_v) and (s_q == s_kv) and (d_qk == d_v) + h_kv, s, d = h_k, s_q, d_qk + stride_q = (s * h_q * d, d, h_q * d, 1) + stride_k = (s * 2 * h_kv * d, d, 2 * h_kv * d, 1) + stride_v = (s * 2 * h_kv * d, d, 2 * h_kv * d, 1) + stride_o = (s * h_q * d, d, h_q * d, 1) + offset_q = 0 + offset_k = offset_q + s * b * h_q * d + offset_v = offset_k + h_kv * d + else: # is_packed + assert seq_len_q is not None + assert seq_len_kv is not None + t_q = torch.sum(seq_len_q) + t_kv = torch.sum(seq_len_kv) + if head_group == "multi_head": + # t3hd + assert (h_q == h_k == h_v) and (s_q == s_kv) and (d_qk == d_v) + h, s, d = h_q, s_q, d_qk + stride_q = (s * 3 * h * d, d, 3 * h * d, 1) + stride_k = (s * 3 * h * d, d, 3 * h * d, 1) + stride_v = (s * 3 * h * d, d, 3 * h * d, 1) + stride_o = (s * h * d, d, h * d, 1) + offset_q = 0 + offset_k = offset_q + h * d + offset_v = offset_k + h * d + else: + # thd_t2hd + assert (h_k == h_v) and (s_q == s_kv) and (d_qk == d_v) + h_kv, s, d = h_k, s_q, d_qk + stride_q = (s * h_q * d, d, h_q * d, 1) + stride_k = (s * 2 * h_kv * d, d, 2 * h_kv * d, 1) + stride_v = (s * 2 * h_kv * d, d, 2 * h_kv * d, 1) + stride_o = (s * h_q * d, d, h_q * d, 1) + offset_q = 0 + offset_k = offset_q + t_q * h_q * d + offset_v = offset_k + h_kv * d + elif layout == "sbh3d": if head_group == "multi_head": + # sbh3d assert (h_q == h_k == h_v) and (s_q == s_kv) and (d_qk == d_v) h, s, d = h_q, s_q, d_qk stride_q = (h * 3 * d, 3 * d, b * h * 3 * d, 1) @@ -191,8 +310,7 @@ def generate_layout(layout, head_group, shape_q, shape_k, shape_v, shape_o): offset_k = offset_q + d offset_v = offset_k + d else: - # group_query and multi_query - # sbhd + sbh2d + # sbhd_sbh2d assert (h_k == h_v) and (s_q == s_kv) and (d_qk == d_v) h_kv, s, d = h_k, s_q, d_qk stride_q = (h_q * d, d, b * h_q * d, 1) @@ -202,104 +320,119 @@ def generate_layout(layout, head_group, shape_q, shape_k, shape_v, shape_o): offset_q = 0 offset_k = offset_q + s * b * h_q * d offset_v = offset_k + d - elif layout == "bs3hd": - if head_group == "multi_head": - assert (h_q == h_k == h_v) and (s_q == s_kv) and (d_qk == d_v) - h, s, d = h_q, s_q, d_qk - stride_q = (s * 3 * h * d, d, 3 * h * d, 1) - stride_k = (s * 3 * h * d, d, 3 * h * d, 1) - stride_v = (s * 3 * h * d, d, 3 * h * d, 1) - stride_o = (s * h * d, d, h * d, 1) - offset_q = 0 - offset_k = offset_q + h_q * d - offset_v = offset_k + h_k * d - else: - # group_query and multi_query - # bshd + bs2hd - assert (h_k == h_v) and (s_q == s_kv) and (d_qk == d_v) - h_kv, s, d = h_k, s_q, d_qk - stride_q = (s * h_q * d, d, h_q * d, 1) - stride_k = (s * 2 * h_kv * d, d, 2 * h_kv * d, 1) - stride_v = (s * 2 * h_kv * d, d, 2 * h_kv * d, 1) - stride_o = (s * h_q * d, d, h_q * d, 1) - offset_q = 0 - offset_k = offset_q + s * b * h_q * d - offset_v = offset_k + h_kv * d else: - # bshd non_interleaved layout - stride_q = (s_q * h_q * d_qk, d_qk, h_q * d_qk, 1) - stride_k = (s_kv * h_k * d_qk, d_qk, h_k * d_qk, 1) - stride_v = (s_kv * h_v * d_v, d_v, h_v * d_v, 1) - stride_o = (s_q * h_q * d_v, d_v, h_q * d_v, 1) - offset_q = 0 - offset_k = offset_q + b * s_q * h_q * d_qk - offset_v = offset_k + b * s_kv * h_k * d_qk + raise ValueError("layout must be 'bshd_bshd_bshd', 'bs3hd', or 'sbh3d'") 
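    # A hedged sketch (not part of the original test) of how the stride/offset tuples
    # computed above are typically consumed: the caller views one packed allocation as
    # separate (b, h, s, d) tensors, e.g. for the plain bshd_bshd_bshd case something like
    #   q_view = torch.as_strided(qkv_buffer, (b, h_q, s_q, d_qk),
    #                             (s_q * h_q * d_qk, d_qk, h_q * d_qk, 1),
    #                             storage_offset=offset_q)
    # where `qkv_buffer` is a hypothetical flat buffer large enough to hold Q, K and V.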
return stride_q, stride_k, stride_v, stride_o, offset_q, offset_k, offset_v -def compute_exclusive_prefix_sum(tensor): +def generate_ragged_offset( + layout, head_group, shape_q, shape_k, shape_v, shape_o, seq_len_q, seq_len_kv +): + b, h_q, s_q, d_qk = shape_q + b, h_k, s_kv, d_qk = shape_k + b, h_v, s_kv, d_v = shape_v + b, h_q, s_q, d_v = shape_o + + assert shape_q == (b, h_q, s_q, d_qk) + assert shape_k == (b, h_k, s_kv, d_qk) + assert shape_v == (b, h_v, s_kv, d_v) + assert shape_o == (b, h_q, s_q, d_v) + + # Compute the exclusive prefix sum for ragged sequence dimension # tensor has shape (B, 1, 1, 1) # output has shape (B+1, 1, 1, 1) # ex) tensor = [[[[2, 4, 1, 6]]]] # output = [[[[0, 2, 6, 7, 13]]]] - assert tensor.size(1) == tensor.size(2) == tensor.size(3) == 1 - return torch.cat( - ( - torch.zeros(1, 1, 1, 1, dtype=tensor.dtype, device=tensor.device), - torch.cumsum(tensor, dim=0), + def compute_exclusive_prefix_sum(tensor): + assert tensor.size(1) == tensor.size(2) == tensor.size(3) == 1 + return torch.cat( + ( + torch.zeros(1, 1, 1, 1, dtype=tensor.dtype, device=tensor.device), + torch.cumsum(tensor, dim=0), + ) ) - ) + if layout == "bshd_bshd_bshd": + # thd_thd_thd + q_ragged_offset = compute_exclusive_prefix_sum(seq_len_q) * h_q * d_qk + k_ragged_offset = compute_exclusive_prefix_sum(seq_len_kv) * h_k * d_qk + v_ragged_offset = compute_exclusive_prefix_sum(seq_len_kv) * h_v * d_v + o_ragged_offset = compute_exclusive_prefix_sum(seq_len_q) * h_q * d_v + elif layout == "bs3hd": + if head_group == "multi_head": + # t3hd + assert torch.equal(seq_len_q, seq_len_kv) + assert (h_q == h_k == h_v) and (d_qk == d_v) + seq_len, h, d = seq_len_q, h_q, d_qk + q_ragged_offset = compute_exclusive_prefix_sum(seq_len) * 3 * h * d + k_ragged_offset = compute_exclusive_prefix_sum(seq_len) * 3 * h * d + v_ragged_offset = compute_exclusive_prefix_sum(seq_len) * 3 * h * d + o_ragged_offset = compute_exclusive_prefix_sum(seq_len) * h * d + else: + # thd_t2hd + assert (h_k == h_v) and (d_qk == d_v) + seq_len, h_kv, d = seq_len_q, h_k, d_qk + q_ragged_offset = compute_exclusive_prefix_sum(seq_len_q) * h_q * d + k_ragged_offset = compute_exclusive_prefix_sum(seq_len_kv) * 2 * h_kv * d + v_ragged_offset = compute_exclusive_prefix_sum(seq_len_kv) * 2 * h_kv * d + o_ragged_offset = compute_exclusive_prefix_sum(seq_len_q) * h_q * d + else: + raise ValueError() -def convert_ragged_to_uniform(ragged_tensor, ragged_offset): + q_ragged_offset = q_ragged_offset.to(dtype=seq_len_q.dtype) + k_ragged_offset = k_ragged_offset.to(dtype=seq_len_kv.dtype) + v_ragged_offset = v_ragged_offset.to(dtype=seq_len_kv.dtype) + o_ragged_offset = o_ragged_offset.to(dtype=seq_len_q.dtype) + + return q_ragged_offset, k_ragged_offset, v_ragged_offset, o_ragged_offset + + +def convert_ragged_to_uniform(ragged_tensor, seq_len): # limitations: - # 1. tensor is non-interleaved with bhsd dim order and bshd stride order + # 1. tensor is bhsd dim order and bshd stride order (may be interleaved) # 2. 
ragged tensor is packed and in-order, therefore # ragged offset is monatomically increasing assert ragged_tensor.dim() == 4 b, h, s, d = ragged_tensor.size() b_stride, h_stride, s_stride, d_stride = ragged_tensor.stride() assert b_stride >= s_stride >= h_stride >= d_stride - assert ragged_offset.dim() == 4 and (b + 1, 1, 1, 1) == ragged_offset.size() + assert seq_len.dim() == 4 and (b, 1, 1, 1) == seq_len.size() # ragged offset is given in 4D, convert to 1D locally - ragged_offset = ragged_offset.flatten() + seq_len = seq_len.flatten() # convert bhsd to bshd and flatten - ragged_tensor_flat = torch.einsum("bhsd->bshd", ragged_tensor).flatten() - uniform_tensor_flat = torch.zeros_like(ragged_tensor_flat) + uniform_tensor = torch.zeros(b, s, h, d).to( + dtype=ragged_tensor.dtype, device=ragged_tensor.device + ) + ragged_tensor_thd = torch.einsum("bhsd->bshd", ragged_tensor).reshape(b * s, h, d) # copy - for i, num_elements in enumerate(ragged_offset[1:] - ragged_offset[:-1]): - unif_a = i * s * h * d - unif_b = unif_a + num_elements - ragg_a = ragged_offset[i] - ragg_b = ragg_a + num_elements - uniform_tensor_flat[unif_a:unif_b] = ragged_tensor_flat[ragg_a:ragg_b] - - # unflatten and convert bshd to bhsd - uniform_tensor = uniform_tensor_flat.view(b, s, h, d) + t = 0 + for b, s in enumerate(seq_len): + uniform_tensor[b, 0:s, :, :] = ragged_tensor_thd[t : t + s, :, :] + t += s + + # convert back to bshd to bhsd uniform_tensor = torch.einsum("bshd->bhsd", uniform_tensor) return uniform_tensor +# fmt: off @pytest.mark.parametrize("is_infer", is_infer_options, ids=lambda p: f"infer{int(p)}") @pytest.mark.parametrize("is_ragged", ragged_options, ids=lambda p: f"ragged{int(p)}") -@pytest.mark.parametrize( - "is_dropout", dropout_options, ids=lambda p: f"dropout{int(p)}" -) -@pytest.mark.parametrize( - "is_causal", causal_mask_options, ids=lambda p: f"causal{int(p)}" -) -@pytest.mark.parametrize( - "is_padding", padding_mask_options, ids=lambda p: f"padding{int(p)}" -) +@pytest.mark.parametrize("is_dropout", dropout_options, ids=lambda p: f"dropout{int(p)}") +@pytest.mark.parametrize("is_sliding_window", sliding_window_mask_options, ids=lambda p: f"sliding_window{int(p)}") +@pytest.mark.parametrize("is_causal_bottom_right", causal_mask_bottom_right_options, ids=lambda p: f"causal_bottom_right{int(p)}") +@pytest.mark.parametrize("is_causal", causal_mask_options, ids=lambda p: f"causal{int(p)}") +@pytest.mark.parametrize("is_padding", padding_mask_options, ids=lambda p: f"padding{int(p)}") @pytest.mark.parametrize("is_alibi", alibi_mask_options, ids=lambda p: f"alibi{int(p)}") @pytest.mark.parametrize("is_bias", bias_options, ids=lambda p: f"bias{int(p)}") @pytest.mark.parametrize("head_group", head_group_options) @pytest.mark.parametrize("layout", layout_options) @pytest.mark.parametrize("input_type", input_type_options, ids=lambda p: str(p)) +# fmt: on @torch_fork_set_rng(seed=0) def test_sdpa( input_type, @@ -309,9 +442,12 @@ def test_sdpa( is_alibi, is_padding, is_causal, + is_causal_bottom_right, + is_sliding_window, is_dropout, is_ragged, is_infer, + request, arg_params, ): @@ -335,11 +471,14 @@ def test_sdpa( if is_ragged and cudnn_version < "9": pytest.skip("Ragged tensor is only supported 9.0.0 onwards") + if is_ragged and layout == "bs3hd" and cudnn_version < "9.1.0": + pytest.skip("t3hd is only supported on 9.1.0 onwards") + if is_ragged and torch.cuda.get_device_capability()[0] < 9: pytest.skip("Ragged tensor is only supported hopper") - if is_ragged and layout != "non_interleaved": - 
pytest.skip("Ragged tensor is only tested with non-interleaved bshd layout") + if is_ragged and not (layout == "bshd_bshd_bshd" or layout == "bs3hd"): + pytest.skip("Ragged tensor is only tested with thd_thd_thd and t3hd") if is_ragged and not is_padding: pytest.skip("Ragged tensor is only tested with packed variable length tensors") @@ -352,7 +491,7 @@ def test_sdpa( # key+value sequence length s_kv = ( random.choice([8, 16, 24, 32, 256, 512, 1024, 2048]) - if layout == "non_interleaved" + if layout == "bshd_bshd_bshd" else s_q ) # query+key embedding dimension per head @@ -360,7 +499,7 @@ def test_sdpa( # value embedding dimension per head d_v = ( random.choice([64, 96, 128]) - if (layout == "non_interleaved" and not is_ragged) + if (layout == "bshd_bshd_bshd" and not is_ragged) else d_qk ) # number of heads @@ -370,7 +509,7 @@ def test_sdpa( h_v = 6 elif head_group == "group_query": h_k = random.choice([6, 3, 2, 1]) - h_v = random.choice([6, 3, 2, 1]) if layout == "non_interleaved" else h_k + h_v = random.choice([6, 3, 2, 1]) if layout == "bshd_bshd_bshd" else h_k elif head_group == "multi_query": h_k = 1 h_v = 1 @@ -381,6 +520,8 @@ def test_sdpa( b = int(arg_params.mha_b) if arg_params.mha_b != None else b s_q = int(arg_params.mha_s_q) if arg_params.mha_s_q != None else s_q s_kv = int(arg_params.mha_s_kv) if arg_params.mha_s_kv != None else s_kv + if is_sliding_window: + s_kv = s_q d_qk = int(arg_params.mha_d_qk) if arg_params.mha_d_qk != None else d_qk d_v = int(arg_params.mha_d_v) if arg_params.mha_d_v != None else d_v h_q = int(arg_params.mha_h_q) if arg_params.mha_h_q != None else h_q @@ -390,29 +531,14 @@ def test_sdpa( if d_qk != d_v and cudnn_version < "8.9.6": pytest.skip("d_qk != d_v is only supported on 8.9.6 onwards.") - if cudnn_version < "9": - if ((s_q % 64 != 0) or (s_kv % 64 != 0)) and (is_padding or is_dropout): - pytest.skip( - "s_q not a multiple of 64 with padding/dropout is not supported with cudnn version 9.0.0" - ) - - if cudnn_version < "8.9.6": - pytest.skip( - "d not a multiple of 64, not-multiple-of-64 seq_kv is not supported below 8.9.6" - ) - - if (d_qk % 64 != 0) and cudnn_version < "8.9.6": - pytest.skip("d not a multiple of 64 is not supported below 8.9.6") - - if (d_qk % 64 != 0) and cudnn_version < "8.9.6": - pytest.skip("d not a multiple of 64 is not supported below 8.9.6") - if d_qk != d_v and is_ragged and cudnn_version < "9.1": pytest.skip("d_qk != d_v is not supported with ragged offset") + print("\n=============== TEST CMD TO REPRODUCE ===============") print( - f"--mha_b={b} --mha_s_q={s_q} --mha_s_kv={s_kv} --mha_d_qk={d_qk} --mha_d_v={d_v} --mha_h_q={h_q} --mha_h_k={h_k} --mha_h_v={h_v}" + f"pytest {request.node.nodeid} --mha_b={b} --mha_s_q={s_q} --mha_s_kv={s_kv} --mha_d_qk={d_qk} --mha_d_v={d_v} --mha_h_q={h_q} --mha_h_k={h_k} --mha_h_v={h_v}" ) + print("=====================================================") attn_scale = 0.125 dropout_prob = 0.1 if is_dropout else 0.0 @@ -452,9 +578,13 @@ def test_sdpa( else None ) seq_len_kv_gpu = ( - torch.randint(1, s_kv + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda") - if is_padding - else None + ( + torch.randint(1, s_kv + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda") + if is_padding + else None + ) + if not (layout == "bs3hd" and head_group == "multi_head") + else seq_len_q_gpu ) if is_dropout: @@ -467,26 +597,22 @@ def test_sdpa( else None ) - q_ragged_offset_gpu = ( - (compute_exclusive_prefix_sum(seq_len_q_gpu) * h_q * d_qk).int() - if is_ragged - else None - ) - k_ragged_offset_gpu = ( 
- (compute_exclusive_prefix_sum(seq_len_kv_gpu) * h_k * d_qk).int() - if is_ragged - else None - ) - v_ragged_offset_gpu = ( - (compute_exclusive_prefix_sum(seq_len_kv_gpu) * h_v * d_v).int() - if is_ragged - else None - ) - o_ragged_offset_gpu = ( - (compute_exclusive_prefix_sum(seq_len_q_gpu) * h_q * d_v).int() - if is_ragged - else None - ) + if is_ragged: + ( + q_ragged_offset_gpu, + k_ragged_offset_gpu, + v_ragged_offset_gpu, + o_ragged_offset_gpu, + ) = generate_ragged_offset( + layout, + head_group, + shape_q, + shape_k, + shape_v, + shape_o, + seq_len_q_gpu, + seq_len_kv_gpu, + ) o_gpu = torch.empty( b * h_q * s_q * d_v, dtype=input_type, device="cuda" @@ -498,7 +624,7 @@ def test_sdpa( ) handle = cudnn.create_handle() - stream = torch.cuda.Stream().cuda_stream + stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=handle, stream=stream) # cuDNN graph @@ -535,6 +661,10 @@ def test_sdpa( k.set_ragged_offset(k_ragged_offset) v.set_ragged_offset(v_ragged_offset) + sliding_window_length = None + if is_sliding_window: + sliding_window_length = s_kv // 4 + o, stats = graph.sdpa( name="sdpa", q=q, @@ -548,6 +678,8 @@ def test_sdpa( seq_len_q=seq_len_q, seq_len_kv=seq_len_kv, use_causal_mask=is_causal, + use_causal_mask_bottom_right=is_causal_bottom_right, + sliding_window_length=sliding_window_length, dropout=dropout_tuple if is_dropout else None, rng_dump=rng_dump, ) @@ -559,7 +691,15 @@ def test_sdpa( if is_infer == False: stats.set_output(True).set_data_type(cudnn.data_type.FLOAT) - graph.validate() + try: + graph.validate() + except cudnn.cudnnGraphNotSupportedError as e: + cudnn.destroy_handle(handle) + pytest.xfail(repr(e)) + except Exception as e: + cudnn.destroy_handle(handle) + pytest.fail(repr(e)) + graph.build_operation_graph() graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) graph.check_support() @@ -572,10 +712,10 @@ def test_sdpa( bias: bias_gpu, seq_len_q: seq_len_q_gpu, seq_len_kv: seq_len_kv_gpu, - q_ragged_offset: q_ragged_offset_gpu, - k_ragged_offset: k_ragged_offset_gpu, - v_ragged_offset: v_ragged_offset_gpu, - o_ragged_offset: o_ragged_offset_gpu, + q_ragged_offset: q_ragged_offset_gpu if is_ragged else None, + k_ragged_offset: k_ragged_offset_gpu if is_ragged else None, + v_ragged_offset: v_ragged_offset_gpu if is_ragged else None, + o_ragged_offset: o_ragged_offset_gpu if is_ragged else None, o: o_gpu, stats: stats_gpu, rng_dump: rng_dump_gpu, @@ -592,24 +732,24 @@ def test_sdpa( torch.cuda.synchronize() # compare with torch autograd reference - q_ref = q_gpu.detach().float() - k_ref = k_gpu.detach().float() - v_ref = v_gpu.detach().float() + q_ref = q_gpu.float() + k_ref = k_gpu.float() + v_ref = v_gpu.float() if is_ragged: - q_ref = convert_ragged_to_uniform(q_ref, q_ragged_offset_gpu.detach()) - k_ref = convert_ragged_to_uniform(k_ref, k_ragged_offset_gpu.detach()) - v_ref = convert_ragged_to_uniform(v_ref, v_ragged_offset_gpu.detach()) + q_ref = convert_ragged_to_uniform(q_ref, seq_len_q_gpu.detach()) + k_ref = convert_ragged_to_uniform(k_ref, seq_len_kv_gpu.detach()) + v_ref = convert_ragged_to_uniform(v_ref, seq_len_kv_gpu.detach()) if is_bias: - bias_ref = bias_gpu.detach().float() + bias_ref = bias_gpu.float() if is_padding: - seq_len_q_ref = seq_len_q_gpu.detach().flatten() - seq_len_kv_ref = seq_len_kv_gpu.detach().flatten() + seq_len_q_ref = seq_len_q_gpu.flatten() + seq_len_kv_ref = seq_len_kv_gpu.flatten() if is_dropout: - rng_dump_ref = rng_dump_gpu.detach().float() + rng_dump_ref = rng_dump_gpu.float() 
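    # A hedged illustration (example shapes only, not part of the test) of the two new
    # mask variants that compute_ref applies below:
    #   s_q, s_kv = 3, 5
    #   brcm = torch.ones(s_q, s_kv, dtype=torch.bool).triu_(diagonal=s_kv - s_q + 1)
    #   # bottom-right aligned causal: row i masks out keys j with j - i > s_kv - s_q
    #   swa = torch.ones(s_q, s_q, dtype=torch.bool).tril_(diagonal=-sliding_window_length)
    #   # sliding window (combined with the causal mask): row i additionally masks out
    #   # keys j with j <= i - sliding_window_length, so only the last
    #   # sliding_window_length positions at or before the query contribute.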
ret = compute_ref( q_ref, @@ -620,6 +760,8 @@ def test_sdpa( is_alibi=is_alibi, padding=(seq_len_q_ref, seq_len_kv_ref) if is_padding else None, is_causal=is_causal, + is_causal_bottom_right=is_causal_bottom_right, + sliding_window_length=sliding_window_length, compute_stats=(is_infer == False), dropout_prob=dropout_prob, dropout_mask=rng_dump_ref if is_dropout else None, @@ -630,7 +772,7 @@ def test_sdpa( o_ref = ret if is_ragged: - o_gpu = convert_ragged_to_uniform(o_gpu, o_ragged_offset_gpu.detach()) + o_gpu = convert_ragged_to_uniform(o_gpu, seq_len_q_gpu.detach()) if is_padding: # zero out padded region of the output for comparison @@ -645,22 +787,22 @@ def test_sdpa( if is_infer == False: torch.testing.assert_close(stats_ref, stats_gpu, atol=2e-2, rtol=2e-2) + cudnn.destroy_handle(handle) + +# fmt: off @pytest.mark.parametrize("is_ragged", ragged_options, ids=lambda p: f"ragged{int(p)}") -@pytest.mark.parametrize( - "is_dropout", dropout_options, ids=lambda p: f"dropout{int(p)}" -) -@pytest.mark.parametrize( - "is_causal", causal_mask_options, ids=lambda p: f"causal{int(p)}" -) -@pytest.mark.parametrize( - "is_padding", padding_mask_options, ids=lambda p: f"padding{int(p)}" -) +@pytest.mark.parametrize("is_dropout", dropout_options, ids=lambda p: f"dropout{int(p)}") +@pytest.mark.parametrize("is_sliding_window", sliding_window_mask_options, ids=lambda p: f"sliding_window{int(p)}") +@pytest.mark.parametrize("is_causal_bottom_right", causal_mask_bottom_right_options, ids=lambda p: f"causal_bottom_right{int(p)}") +@pytest.mark.parametrize("is_causal", causal_mask_options, ids=lambda p: f"causal{int(p)}") +@pytest.mark.parametrize("is_padding", padding_mask_options, ids=lambda p: f"padding{int(p)}") @pytest.mark.parametrize("is_alibi", alibi_mask_options, ids=lambda p: f"alibi{int(p)}") @pytest.mark.parametrize("is_bias", bias_options, ids=lambda p: f"bias{int(p)}") @pytest.mark.parametrize("head_group", head_group_options) @pytest.mark.parametrize("layout", layout_options) @pytest.mark.parametrize("input_type", input_type_options, ids=lambda p: str(p)) +# fmt: on @torch_fork_set_rng(seed=0) def test_sdpa_backward( input_type, @@ -670,8 +812,11 @@ def test_sdpa_backward( is_alibi, is_padding, is_causal, + is_causal_bottom_right, + is_sliding_window, is_dropout, is_ragged, + request, arg_params, ): @@ -686,7 +831,7 @@ def test_sdpa_backward( if is_bias and cudnn_version < "8.9.6": pytest.skip("dBias is only supported 8.9.6 onwards.") - if is_bias and torch.cuda.get_device_capability()[0] < 9: + if is_bias and cudnn_version < "9" and torch.cuda.get_device_capability()[0] < 9: pytest.skip("dBias is only supported on hopper onwards.") if is_bias and is_padding: @@ -710,18 +855,18 @@ def test_sdpa_backward( if is_ragged and torch.cuda.get_device_capability()[0] < 9: pytest.skip("Ragged tensor is only supported hopper") - if is_ragged and layout != "non_interleaved": - pytest.skip("Ragged tensor is only tested with non-interleaved bshd layout") + if is_ragged and not (layout == "bshd_bshd_bshd" or layout == "bs3hd"): + pytest.skip("Ragged tensor is only tested with thd_thd_thd and t3hd") if is_ragged and head_group != "multi_head": pytest.skip("Ragged offset is only supported with multi_head") + if is_ragged and layout == "bs3hd" and cudnn_version < "9.1.0": + pytest.skip("t3hd is only supported on 9.1.0 onwards") + if is_ragged and not is_padding: pytest.skip("Ragged tensor is only tested with packed variable length tensors") - # test both dP workspace optimization by lowering dP workspace 
limit to 8MB - os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = str(8 * 1024 * 1024) - # -------------------------- default randomized parameter testing ------------------------ # batch size b = 2 @@ -730,7 +875,7 @@ def test_sdpa_backward( # key+value sequence length s_kv = ( random.choice([8, 16, 24, 32, 256, 512, 1024]) - if layout == "non_interleaved" + if layout == "bshd_bshd_bshd" else s_q ) # query+key embedding dimension per head @@ -738,7 +883,7 @@ def test_sdpa_backward( # value embedding dimension per head d_v = ( random.choice([64, 96, 128]) - if (layout == "non_interleaved" and not is_ragged) + if (layout == "bshd_bshd_bshd" and not is_ragged) else d_qk ) # number of heads @@ -748,38 +893,17 @@ def test_sdpa_backward( h_v = 6 elif head_group == "group_query": h_k = random.choice([6, 3, 2, 1]) - h_v = random.choice([6, 3, 2, 1]) if layout == "non_interleaved" else h_k + h_v = random.choice([6, 3, 2, 1]) if layout == "bshd_bshd_bshd" else h_k elif head_group == "multi_query": h_k = 1 h_v = 1 else: assert False, "Head group must be either MHA, GQA, or MQA" - if d_qk != d_v and cudnn_version < "8.9.6": - pytest.skip("d_qk != d_v is only supported on 8.9.6 onwards.") - + # test both deterministic and nondeterministic implementation if cudnn_version < "9": - if s_q < 64: - pytest.skip("s_q less than 64 is not supported before cudnn 9.0.0") - - if ((s_q % 64 != 0) or (s_kv % 64 != 0)) and (is_padding or is_dropout): - pytest.skip( - "s_q not a multiple of 64 with padding/dropout is not supported with cudnn version 9.0.0" - ) - - if ((s_q % 64 != 0) or (s_kv % 64 != 0)) and is_bias: - pytest.skip( - "cudnn backend does not support bias with non-64-aligned seq_q or seq_kv." - ) - - if (s_kv % 64 != 0) and cudnn_version < "8.9.6": - pytest.skip("not-multiple-of-64 seq_kv is not supported below 8.9.6") - - if (d_qk % 64 != 0) and cudnn_version < "8.9.6": - pytest.skip("d not a multiple of 64 is not supported below 8.9.6") - - if d_qk != d_v and is_ragged and cudnn_version < "9.1": - pytest.skip("d_qk != d_v is not supported with ragged offset") + os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0" + is_deterministic = random.choice([True, False]) # -------------------------- override test parameters if args are provided ---------------- b = int(arg_params.mha_b) if arg_params.mha_b != None else b @@ -790,10 +914,35 @@ def test_sdpa_backward( h_q = int(arg_params.mha_h_q) if arg_params.mha_h_q != None else h_q h_k = int(arg_params.mha_h_k) if arg_params.mha_h_k != None else h_k h_v = int(arg_params.mha_h_v) if arg_params.mha_h_v != None else h_v + is_deterministic = ( + bool(int(arg_params.mha_deterministic)) + if arg_params.mha_deterministic != None + else is_deterministic + ) + + if d_qk != d_v and cudnn_version < "8.9.6": + pytest.skip("d_qk != d_v is only supported on 8.9.6 onwards.") + + if ((s_q % 64 != 0) or (s_kv % 64 != 0)) and is_bias: + pytest.skip( + "cudnn backend does not support bias with non-64-aligned seq_q or seq_kv." 
+ ) + if d_qk != d_v and is_ragged and cudnn_version < "9.1": + pytest.skip("d_qk != d_v is not supported with ragged offset") + + if ( + is_deterministic + and cudnn_version < "9" + and torch.cuda.get_device_capability()[0] < 9 + ): + pytest.skip("Ampere deterministic implementation is not supported below 9.0.0") + + print("\n=============== TEST CMD TO REPRODUCE ===============") print( - f"--mha_b={b} --mha_s_q={s_q} --mha_s_kv={s_kv} --mha_d_qk={d_qk} --mha_d_v={d_v} --mha_h_q={h_q} --mha_h_k={h_k} --mha_h_v={h_v}" + f"pytest {request.node.nodeid} --mha_b={b} --mha_s_q={s_q} --mha_s_kv={s_kv} --mha_d_qk={d_qk} --mha_d_v={d_v} --mha_h_q={h_q} --mha_h_k={h_k} --mha_h_v={h_v} --mha_deterministic={int(is_deterministic)}" ) + print("=====================================================") attn_scale = 0.125 dropout_prob = 0.1 if is_dropout else 0.0 @@ -847,9 +996,13 @@ def test_sdpa_backward( else None ) seq_len_kv_gpu = ( - torch.randint(1, s_kv + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda") - if is_padding - else None + ( + torch.randint(1, s_kv + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda") + if is_padding + else None + ) + if not (layout == "bs3hd" and head_group == "multi_head") + else seq_len_q_gpu ) if is_dropout: @@ -862,26 +1015,22 @@ def test_sdpa_backward( else None ) - q_ragged_offset_gpu = ( - (compute_exclusive_prefix_sum(seq_len_q_gpu) * h_q * d_qk).int() - if is_ragged - else None - ) - k_ragged_offset_gpu = ( - (compute_exclusive_prefix_sum(seq_len_kv_gpu) * h_k * d_qk).int() - if is_ragged - else None - ) - v_ragged_offset_gpu = ( - (compute_exclusive_prefix_sum(seq_len_kv_gpu) * h_v * d_v).int() - if is_ragged - else None - ) - o_ragged_offset_gpu = ( - (compute_exclusive_prefix_sum(seq_len_q_gpu) * h_q * d_v).int() - if is_ragged - else None - ) + if is_ragged: + ( + q_ragged_offset_gpu, + k_ragged_offset_gpu, + v_ragged_offset_gpu, + o_ragged_offset_gpu, + ) = generate_ragged_offset( + layout, + head_group, + shape_q, + shape_k, + shape_v, + shape_o, + seq_len_q_gpu, + seq_len_kv_gpu, + ) o_gpu = torch.empty( b * h_q * s_q * d_v, dtype=input_type, device="cuda" @@ -889,7 +1038,7 @@ def test_sdpa_backward( stats_gpu = torch.empty(b, h_q, s_q, 1, dtype=torch.float32, device="cuda") handle = cudnn.create_handle() - stream = torch.cuda.Stream().cuda_stream + stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=handle, stream=stream) # forward cuDNN graph @@ -926,6 +1075,10 @@ def test_sdpa_backward( k.set_ragged_offset(k_ragged_offset) v.set_ragged_offset(v_ragged_offset) + sliding_window_length = None + if is_sliding_window: + sliding_window_length = s_kv // 4 + o, stats = graph.sdpa( name="sdpa", q=q, @@ -939,6 +1092,8 @@ def test_sdpa_backward( seq_len_q=seq_len_q, seq_len_kv=seq_len_kv, use_causal_mask=is_causal, + use_causal_mask_bottom_right=is_causal_bottom_right, + sliding_window_length=sliding_window_length, dropout=dropout_tuple if is_dropout else None, rng_dump=rng_dump, ) @@ -949,7 +1104,15 @@ def test_sdpa_backward( stats.set_output(True).set_data_type(cudnn.data_type.FLOAT) - graph.validate() + try: + graph.validate() + except cudnn.cudnnGraphNotSupportedError as e: + cudnn.destroy_handle(handle) + pytest.xfail(repr(e)) + except Exception as e: + cudnn.destroy_handle(handle) + pytest.fail(repr(e)) + graph.build_operation_graph() graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) graph.check_support() @@ -962,10 +1125,10 @@ def test_sdpa_backward( bias: bias_gpu, seq_len_q: seq_len_q_gpu, seq_len_kv: 
seq_len_kv_gpu, - q_ragged_offset: q_ragged_offset_gpu, - k_ragged_offset: k_ragged_offset_gpu, - v_ragged_offset: v_ragged_offset_gpu, - o_ragged_offset: o_ragged_offset_gpu, + q_ragged_offset: q_ragged_offset_gpu if is_ragged else None, + k_ragged_offset: k_ragged_offset_gpu if is_ragged else None, + v_ragged_offset: v_ragged_offset_gpu if is_ragged else None, + o_ragged_offset: o_ragged_offset_gpu if is_ragged else None, o: o_gpu, stats: stats_gpu, rng_dump: rng_dump_gpu, @@ -988,7 +1151,7 @@ def test_sdpa_backward( stats_gpu[i, :, m:, :] = 0 handle = cudnn.create_handle() - stream = torch.cuda.Stream().cuda_stream + stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=handle, stream=stream) # backward cuDNN graph @@ -1049,7 +1212,10 @@ def test_sdpa_backward( seq_len_q=seq_len_q, seq_len_kv=seq_len_kv, use_causal_mask=is_causal, + use_causal_mask_bottom_right=is_causal_bottom_right, + sliding_window_length=sliding_window_length, dropout=dropout_tuple if is_dropout else None, + use_deterministic_algorithm=is_deterministic, ) dQ.set_output(True).set_dim(dQ_gpu.size()).set_stride(dQ_gpu.stride()) @@ -1060,7 +1226,15 @@ def test_sdpa_backward( dK.set_ragged_offset(k_ragged_offset) dV.set_ragged_offset(v_ragged_offset) - graph.validate() + try: + graph.validate() + except cudnn.cudnnGraphNotSupportedError as e: + cudnn.destroy_handle(handle) + pytest.xfail(repr(e)) + except Exception as e: + cudnn.destroy_handle(handle) + pytest.fail(repr(e)) + graph.build_operation_graph() graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) graph.check_support() @@ -1080,10 +1254,10 @@ def test_sdpa_backward( dBias: dBias_gpu, seq_len_q: seq_len_q_gpu, seq_len_kv: seq_len_kv_gpu, - q_ragged_offset: q_ragged_offset_gpu, - k_ragged_offset: k_ragged_offset_gpu, - v_ragged_offset: v_ragged_offset_gpu, - o_ragged_offset: o_ragged_offset_gpu, + q_ragged_offset: q_ragged_offset_gpu if is_ragged else None, + k_ragged_offset: k_ragged_offset_gpu if is_ragged else None, + v_ragged_offset: v_ragged_offset_gpu if is_ragged else None, + o_ragged_offset: o_ragged_offset_gpu if is_ragged else None, } if is_dropout: @@ -1097,23 +1271,19 @@ def test_sdpa_backward( torch.cuda.synchronize() # compare with torch autograd reference - q_ref = q_gpu.detach().float() - q_ref.requires_grad = True - k_ref = k_gpu.detach().float() - k_ref.requires_grad = True - v_ref = v_gpu.detach().float() - v_ref.requires_grad = True + q_ref = q_gpu.detach().float().requires_grad_() + k_ref = k_gpu.detach().float().requires_grad_() + v_ref = v_gpu.detach().float().requires_grad_() dO_ref = dO_gpu.detach().float() if is_ragged: - q_ref = convert_ragged_to_uniform(q_ref, q_ragged_offset_gpu.detach()) - k_ref = convert_ragged_to_uniform(k_ref, k_ragged_offset_gpu.detach()) - v_ref = convert_ragged_to_uniform(v_ref, v_ragged_offset_gpu.detach()) - dO_ref = convert_ragged_to_uniform(dO_ref, o_ragged_offset_gpu.detach()) + q_ref = convert_ragged_to_uniform(q_ref, seq_len_q_gpu.detach()) + k_ref = convert_ragged_to_uniform(k_ref, seq_len_kv_gpu.detach()) + v_ref = convert_ragged_to_uniform(v_ref, seq_len_kv_gpu.detach()) + dO_ref = convert_ragged_to_uniform(dO_ref, seq_len_q_gpu.detach()) if is_bias: - bias_ref = bias_gpu.detach().float() - bias_ref.requires_grad = True + bias_ref = bias_gpu.detach().float().requires_grad_() if is_padding: seq_len_q_ref = seq_len_q_gpu.detach().flatten() @@ -1131,6 +1301,8 @@ def test_sdpa_backward( is_alibi=is_alibi, padding=(seq_len_q_ref, seq_len_kv_ref) if is_padding else 
None, is_causal=is_causal, + is_causal_bottom_right=is_causal_bottom_right, + sliding_window_length=sliding_window_length, dropout_prob=dropout_prob, dropout_mask=rng_dump_ref if is_dropout else None, compute_stats=False, @@ -1150,9 +1322,9 @@ def test_sdpa_backward( dBias_ref = opt_refs.pop(0) if is_ragged: - dQ_gpu = convert_ragged_to_uniform(dQ_gpu, q_ragged_offset_gpu.detach()) - dK_gpu = convert_ragged_to_uniform(dK_gpu, k_ragged_offset_gpu.detach()) - dV_gpu = convert_ragged_to_uniform(dV_gpu, v_ragged_offset_gpu.detach()) + dQ_gpu = convert_ragged_to_uniform(dQ_gpu, seq_len_q_gpu.detach()) + dK_gpu = convert_ragged_to_uniform(dK_gpu, seq_len_kv_gpu.detach()) + dV_gpu = convert_ragged_to_uniform(dV_gpu, seq_len_kv_gpu.detach()) if is_padding: # zero out padded region of the output for comparison @@ -1167,6 +1339,8 @@ def test_sdpa_backward( dBias_ref[i, :, m:, :] = 0 dBias_ref[i, :, :, n:] = 0 + torch.cuda.synchronize() + torch.testing.assert_close(dQ_ref, dQ_gpu, check_dtype=False, atol=2e-2, rtol=2e-2) torch.testing.assert_close( dK_ref, @@ -1186,6 +1360,7 @@ def test_sdpa_backward( torch.testing.assert_close( dBias_ref, dBias_gpu, check_dtype=False, atol=2e-2, rtol=2e-2 ) + cudnn.destroy_handle(handle) if __name__ == "__main__": @@ -1193,7 +1368,7 @@ def test_sdpa_backward( # ================== forward ================== """ pytest \ - test/python_fe/test_mhas.py::test_sdpa[torch.float16-non_interleaved-group_query-bias0-alibi0-padding0-causal0-dropout0-ragged0-infer0] \ + test/python_fe/test_mhas.py::test_sdpa[torch.float16-bshd_bshd_bshd-group_query-bias0-alibi0-padding0-causal0-causal_bottom_right0-sliding_window0-dropout0-ragged0-infer0] \ -s \ --mha_b 3 \ --mha_s_q 256 \ @@ -1202,12 +1377,13 @@ def test_sdpa_backward( --mha_d_v 32 \ --mha_h_q 12 \ --mha_h_k 3 \ - --mha_h_v 4 + --mha_h_v 4 \ + --mha_deterministic 0 """ # ================== backward ================== """ pytest \ - test/python_fe/test_mhas.py::test_sdpa_backward[torch.float16-non_interleaved-group_query-bias0-alibi0-padding0-causal0-dropout0-ragged0] \ + test/python_fe/test_mhas.py::test_sdpa_backward[torch.float16-bshd_bshd_bshd-group_query-bias0-alibi0-padding0-causal0-causal_bottom_right0-sliding_window0-dropout0-ragged0] \ -s \ --mha_b 3 \ --mha_s_q 256 \ @@ -1216,7 +1392,8 @@ def test_sdpa_backward( --mha_d_v 32 \ --mha_h_q 12 \ --mha_h_k 3 \ - --mha_h_v 4 + --mha_h_v 4 \ + --mha_deterministic 0 """ pytest.main([__file__]) diff --git a/test/python_fe/test_rmsnorm.py b/test/python_fe/test_rmsnorm.py index 001a8f5..009c772 100644 --- a/test/python_fe/test_rmsnorm.py +++ b/test/python_fe/test_rmsnorm.py @@ -94,7 +94,7 @@ def test_rmsnorm(param_extract): print("Building cudnn graph") handle = cudnn.create_handle() - stream = torch.cuda.Stream().cuda_stream + stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=handle, stream=stream) graph = cudnn.pygraph( @@ -147,10 +147,12 @@ def test_rmsnorm(param_extract): handle=handle, ) + torch.cuda.synchronize() print("Comparing with reference") torch.testing.assert_close(Y_expected, Y_actual, atol=0.03125, rtol=0.03125) torch.testing.assert_close(inv_var_expected, inv_var_actual, atol=0.005, rtol=0.005) print("Success!!") + cudnn.destroy_handle(handle) target = torch.randn_like(Y_expected) criterion = nn.MSELoss() @@ -164,7 +166,7 @@ def test_rmsnorm(param_extract): loss.backward() handle = cudnn.create_handle() - stream = torch.cuda.Stream().cuda_stream + stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=handle, 
stream=stream) bwd_graph = cudnn.pygraph( @@ -223,12 +225,14 @@ def test_rmsnorm(param_extract): handle=handle, ) + torch.cuda.synchronize() print("Comparing with reference") torch.testing.assert_close(x_gpu.grad, DX_actual, atol=2e-4, rtol=2e-4) torch.testing.assert_close(scale_gpu.grad, DScale_actual, atol=5e-4, rtol=5e-4) if has_bias: torch.testing.assert_close(bias_gpu.grad, Dbias_actual, atol=5e-4, rtol=5e-4) print("Success!!") + cudnn.destroy_handle(handle) if __name__ == "__main__": diff --git a/test/python_fe/test_silu_and_mul.py b/test/python_fe/test_silu_and_mul.py new file mode 100644 index 0000000..1afd181 --- /dev/null +++ b/test/python_fe/test_silu_and_mul.py @@ -0,0 +1,238 @@ +import cudnn +from looseversion import LooseVersion +import pytest + +import torch +from torch.profiler import profile, record_function, ProfilerActivity + + +@pytest.mark.skipif( + LooseVersion(cudnn.backend_version_string()) < "9.3", + reason="Reduction mul is not supported below cudnn 9.3", +) +@pytest.mark.skipif( + hasattr(torch, "float8_e4m3fn") is False, + reason="torch does not have fp8 data types", +) +def test_gemm_silu_and_mul(): + + # setup + M = 64 + N = 64 + K = 64 + + # cudnn graph + handle = cudnn.create_handle() + graph = cudnn.pygraph( + handle=handle, + name="cudnn_graph_0", + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + + X_gpu = torch.randint(-8, 8, (1, M, K), requires_grad=False, device="cuda").to( + dtype=torch.float8_e4m3fn + ) + W_gpu = torch.randint(-8, 8, (2, K, N), requires_grad=False, device="cuda").to( + dtype=torch.float8_e4m3fn + ) + C_gpu = torch.zeros(1, M, N, requires_grad=False, device="cuda").to( + dtype=torch.float + ) + + scale = 0.5 + X_DQ_cpu = torch.full((1, 1, 1), scale, dtype=torch.float32, device="cpu") + W_DQ_cpu = torch.full((1, 1, 1), scale, dtype=torch.float32, device="cpu") + C_Q_cpu = torch.full((1, 1, 1), scale, dtype=torch.float32, device="cpu") + B_mask_gpu = torch.tensor([[[1]], [[0]]], dtype=torch.int32, device="cuda") + + X = graph.tensor( + name="X", + dim=X_gpu.size(), + stride=X_gpu.stride(), + data_type=cudnn.data_type.FP8_E4M3, + ) + W = graph.tensor( + name="W", + dim=W_gpu.size(), + stride=W_gpu.stride(), + data_type=cudnn.data_type.FP8_E4M3, + ) + C0 = graph.matmul(X, W) + + X_DQ = graph.tensor( + name="X_DQ", + dim=X_DQ_cpu.size(), + stride=X_DQ_cpu.stride(), + data_type=cudnn.data_type.FLOAT, + is_pass_by_value=True, + ) + C1 = graph.mul(C0, X_DQ) + + W_DQ = graph.tensor( + name="W_DQ", + dim=W_DQ_cpu.size(), + stride=W_DQ_cpu.stride(), + data_type=cudnn.data_type.FLOAT, + is_pass_by_value=True, + ) + C2 = graph.mul(C1, W_DQ) + + C3 = graph.mul(graph.sigmoid(C2), C2) + + B_mask = graph.tensor( + name="B_mask", + dim=B_mask_gpu.size(), + stride=B_mask_gpu.stride(), + data_type=cudnn.data_type.INT32, + ) + C_combined = graph.binary_select(C2, C3, B_mask) + + C = graph.reduction(C_combined, mode=cudnn.reduction_mode.MUL) + C.set_dim([1, M, N]).set_stride([M * N, N, 1]).set_output(True).set_data_type( + cudnn.data_type.FLOAT + ) + + # The output of reductino operation has to be fp32. + # Plus, the data is in global memory so its not possible to fuse anything now. 
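    # A hedged PyTorch reference (illustrative only) for what the graph above is meant
    # to compute; which matmul slice receives the SiLU depends on the binary_select
    # semantics together with the contents of B_mask:
    #   c = torch.matmul(X_gpu.float(), W_gpu.float()) * scale * scale   # (2, M, N)
    #   ref = torch.nn.functional.silu(c[0]) * c[1]                      # silu-and-mul
    # The MUL reduction over the leading dimension of size 2 then collapses the selected
    # pair into the single (1, M, N) output tensor C.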
+ # C_Q = graph.tensor( + # name="C_Q", + # dim=C_Q_cpu.size(), + # stride=C_Q_cpu.stride(), + # data_type=cudnn.data_type.FLOAT, + # is_pass_by_value=True, + # ) + # C_fp8 = graph.mul(C, C_Q) + # C_fp8.set_output(True) + + try: + graph.build([cudnn.heur_mode.A]) + except cudnn.cudnnGraphNotSupportedError as e: + cudnn.destroy_handle(handle) + pytest.xfail(repr(e)) + except Exception as e: + cudnn.destroy_handle(handle) + pytest.fail(repr(e)) + + workspace = torch.empty( + graph.get_workspace_size(), device="cuda", dtype=torch.uint8 + ) + + with profile(activities=[ProfilerActivity.CUDA]) as prof: + graph.execute( + { + X: X_gpu, + W: W_gpu, + X_DQ: X_DQ_cpu, + W_DQ: W_DQ_cpu, + B_mask: B_mask_gpu, + C: C_gpu, + }, + workspace, + handle=handle, + ) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + # Compare + torch.cuda.synchronize() + + cudnn.destroy_handle(handle) + + +@pytest.mark.skipif( + hasattr(torch, "float8_e4m3fn") is False, + reason="torch does not have fp8 data types", +) +def test_silu_and_mul_and_quantization(): + + # setup + M = 64 + N = 64 + + # cudnn graph + handle = cudnn.create_handle() + graph = cudnn.pygraph( + handle=handle, + name="cudnn_graph_0", + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + + C2a_gpu = torch.randint(-8, 8, (1, M, N), requires_grad=False, device="cuda").to( + dtype=torch.float8_e4m3fn + ) + C2b_gpu = torch.randint(-8, 8, (1, M, N), requires_grad=False, device="cuda").to( + dtype=torch.float8_e4m3fn + ) + C_gpu = torch.empty(1, M, N, requires_grad=False, device="cuda").to( + dtype=torch.float8_e4m3fn + ) + + scale = 0.5 + C2_DQ_cpu = torch.full((1, 1, 1), scale, dtype=torch.float32, device="cpu") + C_Q_cpu = torch.full((1, 1, 1), scale, dtype=torch.float32, device="cpu") + + C2a = graph.tensor( + name="C2a", + dim=C2a_gpu.size(), + stride=C2a_gpu.stride(), + data_type=cudnn.data_type.FP8_E4M3, + ) + C2b = graph.tensor( + name="C2b", + dim=C2b_gpu.size(), + stride=C2b_gpu.stride(), + data_type=cudnn.data_type.FP8_E4M3, + ) + + C2_DQ = graph.tensor( + name="C2_DQ", + dim=C2_DQ_cpu.size(), + stride=C2_DQ_cpu.stride(), + data_type=cudnn.data_type.FLOAT, + is_pass_by_value=True, + ) + C2a_fp32 = graph.mul(C2a, C2_DQ) + C2b_fp32 = graph.mul(C2b, C2_DQ) + + C3 = graph.mul(graph.sigmoid(C2b_fp32), C2b_fp32) + + C_fp32 = graph.mul(C2a_fp32, C3) + C_Q = graph.tensor( + name="C_Q", + dim=C_Q_cpu.size(), + stride=C_Q_cpu.stride(), + data_type=cudnn.data_type.FLOAT, + is_pass_by_value=True, + ) + C_fp8 = graph.mul(C_fp32, C_Q) + C_fp8.set_output(True).set_data_type(cudnn.data_type.FP8_E4M3) + + graph.build([cudnn.heur_mode.A]) + workspace = torch.empty( + graph.get_workspace_size(), device="cuda", dtype=torch.uint8 + ) + + with profile(activities=[ProfilerActivity.CUDA]) as prof: + graph.execute( + { + C2a: C2a_gpu, + C2b: C2b_gpu, + C2_DQ: C2_DQ_cpu, + C_Q: C_Q_cpu, + C_fp8: C_gpu, + }, + workspace, + handle=handle, + ) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + # Compare + torch.cuda.synchronize() + + cudnn.destroy_handle(handle) + + +if __name__ == "__main__": + test_silu_and_mul_and_quantization() + test_gemm_silu_and_mul() diff --git a/test/python_fe/test_utils.py b/test/python_fe/test_utils.py index 1d9e3b7..3af0975 100644 --- a/test/python_fe/test_utils.py +++ b/test/python_fe/test_utils.py @@ -7,7 +7,7 @@ def torch_fork_set_rng(seed=None): def decorator_(func): @functools.wraps(func) def wrapper_(*args, **kwargs): - with 
torch.random.fork_rng(): + with torch.random.fork_rng(devices=range(torch.cuda.device_count())): if seed is not None: torch.manual_seed(seed) return func(*args, **kwargs) diff --git a/test/python_fe/test_wgrads.py b/test/python_fe/test_wgrads.py index fa39ad0..e861439 100644 --- a/test/python_fe/test_wgrads.py +++ b/test/python_fe/test_wgrads.py @@ -61,7 +61,7 @@ def test_scale_bias_relu_wgrad(): ).to(memory_format=torch.channels_last) handle = cudnn.create_handle() - stream = torch.cuda.Stream().cuda_stream + stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=handle, stream=stream) graph = cudnn.pygraph( @@ -119,6 +119,9 @@ def test_scale_bias_relu_wgrad(): handle=handle, ) + torch.cuda.synchronize() + cudnn.destroy_handle(handle) + except cudnn.cudnnGraphNotSupportedError as ex: print(ex) diff --git a/test/unit_tests/serialize.cpp b/test/unit_tests/serialize.cpp index 6b1e515..8fedd33 100644 --- a/test/unit_tests/serialize.cpp +++ b/test/unit_tests/serialize.cpp @@ -69,7 +69,7 @@ TEST_CASE("Conv fprop attributes", "[conv_fprop][serialize]") { REQUIRE(conv_fprop_attributes_deserialized == conv_fprop_attributes); } -TEST_CASE("Graph key", "[serialize]") { +TEST_CASE("Graph key", "[graph][key]") { namespace fe = cudnn_frontend; fe::graph::Graph graph; @@ -117,6 +117,95 @@ TEST_CASE("Graph key", "[serialize]") { REQUIRE(key == graph.key()); } +TEST_CASE("Matmul fp8 fusion", "[graph][serialize]") { + namespace fe = cudnn_frontend; + // matmul problem size + int64_t const b = 16; + int64_t const m = 32; + int64_t const n = 64; + int64_t const k = 128; + + fe::graph::Graph graph{}; + + // Create the two non-virtual input tensors A and B. + // There are read from global memory. + auto A_attributes = fe::graph::Tensor_attributes() + .set_name("A") + .set_dim({b, m, k}) + .set_stride({m * k, k, 1}) + .set_data_type(fe::DataType_t::FP8_E4M3); + auto A = graph.tensor(A_attributes); + + auto B_attributes = fe::graph::Tensor_attributes() + .set_name("B") + .set_dim({b, k, n}) + .set_stride({k * n, 1, k}) + .set_data_type(fe::DataType_t::FP8_E4M3); + auto B = graph.tensor(B_attributes); + + auto A_descale_attributes = fe::graph::Tensor_attributes() + .set_name("descale0") + .set_dim({1, 1, 1}) + .set_stride({1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + auto B_descale_attributes = fe::graph::Tensor_attributes() + .set_name("descale1") + .set_dim({1, 1, 1}) + .set_stride({1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + + auto A_descale = graph.tensor(A_descale_attributes); + auto B_descale = graph.tensor(B_descale_attributes); + + auto matmul_attributes = + // fe::graph::Matmul_attributes().set_name("GEMM").set_compute_data_type(fe::DataType_t::FLOAT); + fe::graph::Matmul_attributes().set_name("GEMM").set_compute_data_type(fe::DataType_t::FLOAT); + auto C = graph.matmul(A, B, matmul_attributes); + C->set_data_type(fe::DataType_t::FLOAT); + + // Add scale_A operation + auto pw_0_attributes = fe::graph::Pointwise_attributes() + // .set_name("pw0_Mul") + .set_mode(fe::PointwiseMode_t::MUL) + .set_compute_data_type(fe::DataType_t::FLOAT); + auto C_after_pw_0 = graph.pointwise(C, A_descale, pw_0_attributes); + C_after_pw_0->set_data_type(fe::DataType_t::FLOAT); + + // Add descale_B operation + auto pw_1_attributes = fe::graph::Pointwise_attributes() + // .set_name("pw1_Mul") + .set_mode(fe::PointwiseMode_t::MUL) + .set_compute_data_type(fe::DataType_t::FLOAT); + auto C_after_pw_1 = graph.pointwise(C_after_pw_0, B_descale, pw_1_attributes); + 
C_after_pw_1->set_output(true).set_data_type(fe::DataType_t::BFLOAT16); + + json j = graph; + + std::cout << j << std::endl; + + fe::graph::Graph graph_deserialized; + + REQUIRE(graph_deserialized.deserialize(j).is_good()); + + json j2 = graph_deserialized; + + REQUIRE(j == j2); + + REQUIRE(graph.validate().is_good()); + + std::cout << "Validating deserialized graph" << std::endl; + + cudnnHandle_t handle; // Handle to use during deserialize and execute + + cudnnCreate(&handle); + + REQUIRE(graph_deserialized.validate().is_good()); + + REQUIRE(graph_deserialized.build_operation_graph(handle).is_good()); + + cudnnDestroy(handle); +} + TEST_CASE("conv graph serialization", "[graph][serialize]") { namespace fe = cudnn_frontend; @@ -173,14 +262,17 @@ TEST_CASE("conv graph serialization", "[graph][serialize]") { r->set_output(true).set_data_type(fe::DataType_t::HALF); - REQUIRE(graph.validate().is_good()); - json j = graph; + fe::graph::Graph graph_deserialized; + REQUIRE(graph_deserialized.deserialize(j).is_good()); + json j2 = graph_deserialized; REQUIRE(j == j2); + + REQUIRE(graph_deserialized.validate().is_good()); } TEST_CASE("sdpa graph serialization", "[graph][serialize]") { @@ -261,14 +353,15 @@ TEST_CASE("sdpa graph serialization", "[graph][serialize]") { O->set_output(true).set_dim({b, h, s_q, d}).set_stride({h * d, d, b * h * d, 1}); stats->set_output(true).set_data_type(fe::DataType_t::FLOAT); - REQUIRE(graph.validate().is_good()); - json j = graph; + fe::graph::Graph graph_deserialized; REQUIRE(graph_deserialized.deserialize(j).is_good()); json j2 = graph_deserialized; REQUIRE(j == j2); + + REQUIRE(graph_deserialized.validate().is_good()); } TEST_CASE("sdpa backward graph serialization", "[graph][serialize]") { @@ -335,7 +428,8 @@ TEST_CASE("sdpa backward graph serialization", "[graph][serialize]") { .set_causal_mask(true) .set_attn_scale(attn_scale) .set_bias(bias) - .set_dropout(0.1f, dropout_seed, dropout_offset); + .set_dropout(0.1f, dropout_seed, dropout_offset) + .set_deterministic_algorithm(true); auto [dQ, dK, dV] = graph.sdpa_backward(q, k, v, o, dO, stats, sdpa_backward_options); @@ -343,12 +437,12 @@ TEST_CASE("sdpa backward graph serialization", "[graph][serialize]") { dK->set_output(true).set_dim({b, h, s_kv, d}).set_stride({h * s_kv * d, s_kv * d, d, 1}); dV->set_output(true).set_dim({b, h, s_kv, d}).set_stride({h * s_kv * d, s_kv * d, d, 1}); - REQUIRE(graph.validate().is_good()); - json j = graph; fe::graph::Graph graph_deserialized; REQUIRE(graph_deserialized.deserialize(j).is_good()); json j2 = graph_deserialized; REQUIRE(j == j2); + + REQUIRE(graph_deserialized.validate().is_good()); } \ No newline at end of file diff --git a/test/unit_tests/validate.cpp b/test/unit_tests/validate.cpp index 050ae66..d2a094d 100644 --- a/test/unit_tests/validate.cpp +++ b/test/unit_tests/validate.cpp @@ -53,10 +53,8 @@ TEST_CASE("Validate conv node", "[conv][validate]") { } TEST_CASE("Move", "[move]") { - namespace fe = cudnn_frontend; - auto validate = [](fe::graph::Graph graph) { - REQUIRE(graph.validate().is_good()); - }; + namespace fe = cudnn_frontend; + auto validate = [](fe::graph::Graph graph) { REQUIRE(graph.validate().is_good()); }; auto construct = []() { fe::graph::Graph graph; REQUIRE(graph.validate().is_good()); diff --git a/tools/json_reproducer/README.md b/tools/json_reproducer/README.md new file mode 100644 index 0000000..2b02326 --- /dev/null +++ b/tools/json_reproducer/README.md @@ -0,0 +1,22 @@ +## Json reproducer + +### Usage + +``` +usage: json_parser.py [-h] 
-i INPUT_FILE [-v] + +optional arguments: + -h, --help show this help message and exit + -i INPUT_FILE, --input_file INPUT_FILE + Input file name + -v, --verbose Set logging level to max +``` + + +### Notes + +Input is a json representation of graph before validate is called. + +For c++ users, this json can be generated by calling `std::cout << graph << std::endl;` or `graph.print()`. + +For python users, this is called by calling the `print(graph)` method. \ No newline at end of file diff --git a/tools/json_reproducer/json_parser.py b/tools/json_reproducer/json_parser.py new file mode 100644 index 0000000..6df3d41 --- /dev/null +++ b/tools/json_reproducer/json_parser.py @@ -0,0 +1,42 @@ +import argparse +import os + +# Create an argument parser +parser = argparse.ArgumentParser() + +# Add arguments +parser.add_argument("-i", "--input_file", help="Input file name", required=True) + +parser.add_argument( + "-v", "--verbose", help="Set logging level to max", action="store_true" +) + +# Parse the arguments +args = parser.parse_args() + +with open(args.input_file) as f: + data = f.read().replace("\n", "").replace(" ", "") + +import cudnn + +if args.verbose: + os.environ["CUDNN_LOGLEVEL_DBG"] = "3" +else: + os.environ["CUDNN_LOGLEVEL_DBG"] = "2" + +try: + handle = cudnn.create_handle() + + graph = cudnn.pygraph(handle=handle) + + graph.deserialize(data) + + graph.build([cudnn.heur_mode.A]) + + print("Graph built successfully and can be executed.") + +except Exception as e: + print("[cudnn frontend error]") + print(e) + print("[cudnn backend error]") + print(cudnn.get_last_error_string()) diff --git a/tools/json_reproducer/jsons/graph0.json b/tools/json_reproducer/jsons/graph0.json new file mode 100644 index 0000000..df19a96 --- /dev/null +++ b/tools/json_reproducer/jsons/graph0.json @@ -0,0 +1,92 @@ +{ + "context": { + "compute_data_type": "FLOAT", + "intermediate_data_type": "FLOAT", + "io_data_type": "HALF", + "name": "" + }, + "nodes": [ + { + "compute_data_type": null, + "dilation": [1,1], + "inputs": { + "W": "W", + "X": "X" + }, + "name": "", + "outputs": { + "Y": "::Y" + }, + "post_padding": [0,1], + "pre_padding": [0,1], + "stride": [2,3], + "tag": "CONV_FPROP" + }, + { + "axis": null, + "compute_data_type": null, + "inputs": { + "IN_0": "::Y" + }, + "mode": "RELU_FWD", + "name": "relu", + "outputs": { + "OUT_0": "relu::OUT_0" + }, + "relu_lower_clip": "3F000000", + "relu_lower_clip_slope": null, + "relu_upper_clip": "3F0CCCCD", + "tag": "POINTWISE" + } + ], + "tensors": { + "::Y": { + "data_type": null, + "dim": [], + "is_pass_by_value": false, + "is_virtual": true, + "name": "::Y", + "pass_by_value": null, + "reordering_type": "NONE", + "stride": [], + "uid": 0, + "uid_assigned": false + }, + "W": { + "data_type": "HALF", + "dim": [54,40,3,4], + "is_pass_by_value": false, + "is_virtual": false, + "name": "W", + "pass_by_value": null, + "reordering_type": "NONE", + "stride": [480,1,160,40], + "uid": 0, + "uid_assigned": false + }, + "X": { + "data_type": "HALF", + "dim": [20,40,30,40], + "is_pass_by_value": false, + "is_virtual": false, + "name": "X", + "pass_by_value": null, + "reordering_type": "NONE", + "stride": [48000,1,1600,40], + "uid": 0, + "uid_assigned": false + }, + "relu::OUT_0": { + "data_type": null, + "dim": [], + "is_pass_by_value": false, + "is_virtual": false, + "name": "relu::OUT_0", + "pass_by_value": null, + "reordering_type": "NONE", + "stride": [], + "uid": 0, + "uid_assigned": false + } + } +}
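As a usage sketch (assuming the command is run from the repository root), the sample graph above can be replayed through the new JSON reproducer:

```
python tools/json_reproducer/json_parser.py -i tools/json_reproducer/jsons/graph0.json -v
```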