Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python unit tests for bf16 GEMM on PVC #217

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 53 additions & 47 deletions python/cutlass/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def _encode(self):

return opts


def get_str(self):
opts = self._encode()

Expand Down Expand Up @@ -212,7 +211,8 @@ def load_operation(self, op_key, extra_funcs):
op_attr = json.loads(op_attr)
if self._is_sycl():
q = dpctl.SyclQueue(cutlass.sycl_device())
module = dpctl.program.create_program_from_spirv(q, cubin_image)
module = dpctl.program.create_program_from_spirv(
q, cubin_image)
kernel = module.get_sycl_kernel(operation_name)
else:
err, module = cuda.cuModuleLoadData(cubin_image)
Expand Down Expand Up @@ -264,9 +264,10 @@ def emit_compile_(self, operation_list, compilation_options, host_compilation_op

includes_host = includes
if self._is_sycl():
includes_host.extend(["stddef"])
includes_host.extend(["stddef.h"])
else:
includes_host.extend(["builtin_types.h", "device_launch_parameters.h", "cstddef"])
includes_host.extend(
["builtin_types.h", "device_launch_parameters.h", "cstddef"])

for incl in includes:
source_buffer_device += SubstituteTemplate(
Expand Down Expand Up @@ -338,46 +339,50 @@ def emit_compile_(self, operation_list, compilation_options, host_compilation_op
elif self.backend == "dpcpp":
# Emit code to file
tempfile.tempdir = "./"
temp_cpp = tempfile.NamedTemporaryFile(
prefix="kernel_", suffix=".cpp", delete=True)
temp_dump_dir = tempfile.TemporaryDirectory(
prefix="kernel_", suffix="_dpcpp")
ignore_out = tempfile.NamedTemporaryFile(
prefix="kernel_", suffix=".o", delete=True)
with open(temp_cpp.name, "w") as file:
file.write(source_buffer_device)

# Compile with DPC++
cmd_template = "clang++ ${options} ${srcfile} -o ${outfile} -fsycl-dump-device-code=${tmpdir}"
values = {
"options": compilation_options.get_str(),
"srcfile": temp_cpp.name,
"outfile": ignore_out.name,
"tmpdir": temp_dump_dir.name
}
cmd = SubstituteTemplate(cmd_template, values)
compile_with_nvcc(cmd.split(" "), source_buffer_device,
"./cutlass_python_compilation_device_error.txt")

# Find SPIR-V device code in temporary directory
spv_files = list(pathlib.Path(temp_dump_dir.name).glob("*.spv"))

# When specifying a specific subgroup size, DPC++ currently
# generates multiple SPIR-V files. We create a program from each of
# them to find the one containing the kernel with the correct
# subgroup size.
q = dpctl.SyclQueue(cutlass.sycl_device())
op_name = f"__sycl_kernel_{operation_list[0].name()}"
for f in spv_files:
with open(f, "rb") as spirv_file:
spirv_image = spirv_file.read()
program = dpctl.program.create_program_from_spirv(q, spirv_image)
if not program.has_sycl_kernel(op_name):
continue
spirv_kernel = program.get_sycl_kernel(op_name)
if spirv_kernel.max_sub_group_size == 16:
cubin_image = spirv_image
break
with (
tempfile.NamedTemporaryFile(
prefix="kernel_", suffix=".cpp", delete=True) as temp_cpp,
tempfile.TemporaryDirectory(
prefix="kernel_", suffix="_dpcpp") as temp_dump_dir,
tempfile.NamedTemporaryFile(
prefix="kernel_", suffix=".o", delete=True) as ignore_out
):
with open(temp_cpp.name, "w") as file:
file.write(source_buffer_device)

# Compile with DPC++
cmd_template = "clang++ ${options} ${srcfile} -o ${outfile} -fsycl-dump-device-code=${tmpdir}"
values = {
"options": compilation_options.get_str(),
"srcfile": temp_cpp.name,
"outfile": ignore_out.name,
"tmpdir": temp_dump_dir
}
cmd = SubstituteTemplate(cmd_template, values)
compile_with_nvcc(cmd.split(" "), source_buffer_device,
"./cutlass_python_compilation_device_error.txt")

# Find SPIR-V device code in temporary directory
spv_files = list(pathlib.Path(
temp_dump_dir).glob("*.spv"))

# When specifying a specific subgroup size, DPC++ currently
# generates multiple SPIR-V files. We create a program from each of
# them to find the one containing the kernel with the correct
# subgroup size.
q = dpctl.SyclQueue(cutlass.sycl_device())
op_name = f"__sycl_kernel_{operation_list[0].name()}"
for f in spv_files:
with open(f, "rb") as spirv_file:
spirv_image = spirv_file.read()
program = dpctl.program.create_program_from_spirv(
q, spirv_image)
if not program.has_sycl_kernel(op_name):
continue
spirv_kernel = program.get_sycl_kernel(op_name)
if spirv_kernel.max_sub_group_size == 16:
cubin_image = spirv_image
break

else: # with nvcc backend
# emit code
Expand Down Expand Up @@ -461,8 +466,8 @@ def add_module(self, operations, compile_options=None, bypass_cache=False):
cutlass.initialize_sycl_context()
arch = "spir64"
host_compile_options = CompilationOptions(
["-std=c++17", "-DCUTLASS_ENABLE_SYCL", "-DSYCL_INTEL_TARGET"],
arch, include_paths, True)
["-std=c++17", "-DCUTLASS_ENABLE_SYCL", "-DSYCL_INTEL_TARGET"],
arch, include_paths, True)

if compile_options is None:
compile_options = CompilationOptions(
Expand Down Expand Up @@ -500,7 +505,8 @@ def add_module(self, operations, compile_options=None, bypass_cache=False):

if self._is_sycl():
q = dpctl.SyclQueue(cutlass.sycl_device())
program = dpctl.program.create_program_from_spirv(q, cubin_image)
program = dpctl.program.create_program_from_spirv(
q, cubin_image)
else:
err, module = cuda.cuModuleLoadData(cubin_image)
if err != cuda.CUresult.CUDA_SUCCESS:
Expand Down
2 changes: 1 addition & 1 deletion python/cutlass/backend/gemm_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1715,7 +1715,7 @@ def epilogue_schedule_name_3x(self):
def procedural_name(self):
"""The full procedural name indicates architecture, extended name, tile size, and layout."""
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
if self.api == ApiVersion.v3x and self.arch >= 90:
if self.api == ApiVersion.v3x and (self.arch >= 90 or self.arch == 11):
kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}{e}"
return kernel_name_template.format(
p=self.prefix,
Expand Down
31 changes: 18 additions & 13 deletions python/cutlass/library_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

from itertools import combinations_with_replacement
import logging
import os

from cuda import __version__
import cutlass_library
Expand All @@ -52,19 +53,23 @@
# Strip any additional information from the CUDA version
_cuda_version = __version__.split("rc")[0]

# Check that Python CUDA version exceeds NVCC version
_nvcc_version = cutlass.nvcc_version()
_cuda_list = _cuda_version.split('.')
_nvcc_list = _nvcc_version.split('.')
for val_cuda, val_nvcc in zip(_cuda_list, _nvcc_list):
if int(val_cuda) < int(val_nvcc):
raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {_nvcc_version}")

if len(_nvcc_list) > len(_cuda_list):
if len(_nvcc_list) != len(_cuda_list) + 1:
raise Exception(f"Malformatted NVCC version of {_nvcc_version}")
if _nvcc_list[:-1] == _cuda_list and int(_nvcc_list[-1]) != 0:
raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {_nvcc_version}")
if not os.getenv("CUTLASS_USE_SYCL"):
FMarno marked this conversation as resolved.
Show resolved Hide resolved
# Check that Python CUDA version exceeds NVCC version
_nvcc_version = cutlass.nvcc_version()
_cuda_list = _cuda_version.split('.')
_nvcc_list = _nvcc_version.split('.')
for val_cuda, val_nvcc in zip(_cuda_list, _nvcc_list):
if int(val_cuda) < int(val_nvcc):
raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {_nvcc_version}")

if len(_nvcc_list) > len(_cuda_list):
if len(_nvcc_list) != len(_cuda_list) + 1:
raise Exception(f"Malformatted NVCC version of {_nvcc_version}")
if _nvcc_list[:-1] == _cuda_list and int(_nvcc_list[-1]) != 0:
raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {_nvcc_version}")

else:
_nvcc_version = "2025.0"


class KernelsForDataType:
Expand Down
2 changes: 1 addition & 1 deletion python/cutlass/utils/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def valid_schedule(
kernel_auto = (kernel_schedule == cutlass.KernelScheduleType.ScheduleAuto)
epilogue_auto = (epilogue_schedule == cutlass.EpilogueScheduleType.ScheduleAuto)
tile_scheduler_default = (tile_scheduler == cutlass.TileSchedulerType.Default)
if cc < 90 and not (kernel_auto and epilogue_auto and tile_scheduler_default):
if 11 < cc < 90 and not (kernel_auto and epilogue_auto and tile_scheduler_default):
return (False, "Non-default schedules are only supported on SM90 and beyond")

if (kernel_auto and not epilogue_auto) or (not kernel_auto and epilogue_auto):
Expand Down
20 changes: 15 additions & 5 deletions python/cutlass_library/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8977,9 +8977,11 @@ def GenerateSM90(manifest, cuda_version):
###################################################################################################

def GeneratePVC_TensorOp_16b_gemm(manifest, cuda_version):
# TODO: Add remaining supported configurations
layouts = [
[[LayoutType.RowMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.RowMajor, 4]]
[[LayoutType.RowMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.RowMajor, 4]],
[[LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.RowMajor, 4]],
[[LayoutType.ColumnMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.RowMajor, 4]],
[[LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.RowMajor, 4]],
]

math_instructions = [
Expand All @@ -8995,10 +8997,18 @@ def GeneratePVC_TensorOp_16b_gemm(manifest, cuda_version):

for math_inst in math_instructions:
tile_descriptions = [
TileDescription([math_inst.instruction_shape[0] * 32, math_inst.instruction_shape[1] * 16, math_inst.instruction_shape[2] * 2],
0, [4, 1, 1], math_inst, min_cc, max_cc, [1, 1, 1])
TileDescription([256, 256, 32],
0, [8, 4, 1], math_inst, min_cc, max_cc, [1, 1, 1]),
TileDescription([128, 512, 32],
0, [4, 8, 1], math_inst, min_cc, max_cc, [1, 1, 1]),
TileDescription([256, 128, 32],
0, [8, 4, 1], math_inst, min_cc, max_cc, [1, 1, 1]),
TileDescription([128, 256, 16],
0, [4, 8, 1], math_inst, min_cc, max_cc, [1, 1, 1]),
TileDescription([8, 128, 32],
0, [1, 4, 1], math_inst, min_cc, max_cc, [1, 1, 1]),
]

data_type = {
"a_type" : math_inst.element_a,
"b_type" : math_inst.element_b,
Expand Down
1 change: 1 addition & 0 deletions test/python/cutlass/gemm/gemm_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def initialize(rows, cols, batch):
return tensor.reshape(rows, cols)


@unittest.skipIf(device_cc() == 11, "Batched GEMM test not supported on PVC")
class GemmF16Batched(unittest.TestCase):
def run_batched(self, batch_count: tuple, batch_A: bool, batch_B: bool, batch_C: bool):
M = 512
Expand Down
90 changes: 90 additions & 0 deletions test/python/cutlass/gemm/gemm_bf16_pvc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#################################################################################################
#
# Copyright (c) 2023 - 2025 Codeplay Software Limited. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Low-level functionality tests for GEMM with bf16 operands on PVC
"""

from functools import partial
import logging
import unittest

import cutlass
from cutlass.backend.utils.device import device_cc

from utils import LayoutCombination, add_test_gemm


# Silence info/debug output from the CUTLASS Python interface during tests.
cutlass.set_log_level(logging.WARNING)
# Compute-capability value the CUTLASS Python API uses for Intel PVC
# (see the device_cc() skip condition below).
cc = 11
# Operand data type under test: bf16 inputs (accumulation/output are f32,
# fixed in the add_test_pvc_bf16 defaults below).
dtype = cutlass.DataType.bf16


# Run only on PVC devices (cc == 11) and only when the installed torch
# provides a matching dtype for bf16.
@unittest.skipIf(device_cc() != cc, 'Device compute capability is insufficient for PVC tests.')
@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
class GemmBF16PVC(unittest.TestCase):
    """
    Empty wrapper class; bf16 GEMM test methods are attached to it
    dynamically at import time by the add_test_pvc_bf16 calls in this module.
    """
    pass


def add_test_pvc_bf16(**kwargs):
    """Register one bf16 PVC GEMM test case on GemmBF16PVC.

    Forwards the per-case keyword arguments (layouts, alignments,
    threadblock_shape, stages, warp_count) to ``add_test_gemm`` with the
    PVC-specific settings fixed: bf16 operands, f32 C/output/accumulator,
    TensorOp opcode class, DPC++ compilation, and a 1x1x1 cluster shape.
    """
    add_test_gemm(cls=GemmBF16PVC,
                  cc=11,
                  element=dtype,
                  element_C=cutlass.DataType.f32,
                  element_output=cutlass.DataType.f32,
                  element_accumulator=cutlass.DataType.f32,
                  compilation_modes=["dpcpp"],
                  opclass=cutlass.OpcodeClass.TensorOp,
                  cluster_shape=[1, 1, 1],
                  **kwargs)

# (threadblock_shape, warp_count) configurations exercised on PVC; all
# currently use row-major A/B/C (TTT) layouts with [2, 2, 4] alignments.
_pvc_tile_configs = [
    ([256, 256, 32], [8, 4, 1]),
    ([128, 512, 32], [4, 8, 1]),
    ([256, 128, 32], [8, 4, 1]),
    ([128, 256, 16], [4, 8, 1]),
]

for _tb_shape, _warp_count in _pvc_tile_configs:
    add_test_pvc_bf16(layouts=LayoutCombination.TTT,
                      alignments=[2, 2, 4],
                      threadblock_shape=_tb_shape,
                      stages=0,
                      warp_count=_warp_count)

# TODO: Test more configurations as soon as they're supported by the
# CollectiveBuilder

if __name__ == '__main__':
    unittest.main()
Loading