Skip to content

Commit

Permalink
Autotuner for int mm Triton kernels (pytorch#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpuhrsch authored Mar 21, 2024
1 parent d0f1aa2 commit 530f71b
Show file tree
Hide file tree
Showing 17 changed files with 1,201 additions and 65 deletions.
59 changes: 59 additions & 0 deletions .lintrunner.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
merge_base_with = "origin/main"

[[linter]]
code = 'FLAKE8'
include_patterns = ['**/*.py']
exclude_patterns = [
'third-party/**',
'**/third-party/**',
]
command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'flake8_linter',
'--',
'@{{PATHSFILE}}'
]
init_command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'pip_init',
'--dry-run={{DRYRUN}}',
'--requirement=requirements-lintrunner.txt',
]

# Black + usort
[[linter]]
code = 'UFMT'
include_patterns = [
'**/*.py',
'**/*.pyi',
]
exclude_patterns = [
'third-party/**',
'**/third-party/**',
]
command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'ufmt_linter',
'--',
'@{{PATHSFILE}}'
]
init_command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'pip_init',
'--dry-run={{DRYRUN}}',
'--no-black-binary',
'--requirement=requirements-lintrunner.txt',
]
is_formatter = true
96 changes: 96 additions & 0 deletions benchmarks/intmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import argparse
import csv
import itertools
import math
import pathlib

import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
from torchao.kernel.intmm_triton import int_matmul, int_scaled_matmul

torch._dynamo.config.cache_size_limit = 128
torch._dynamo.config.accumulated_cache_size_limit = 128

dtype = torch.float16
device = "cuda"


def benchmark_in_ms(warmup, iters, f, *args, **kwargs):
for _ in range(warmup):
f(*args, **kwargs)
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()

for _ in range(iters):
f(*args, **kwargs)

end_event.record()
torch.cuda.synchronize()
return start_event.elapsed_time(end_event) / float(iters)


@torch.compile(mode="max-autotune")
def compiled_mm(x, w):
return torch.mm(x, w)


@torch.compile(mode="max-autotune")
def compiled_int_mm(x, w):
return torch._int_mm(x, w)


def run_int_mm_benchmark(x, w, b):
fp_time = benchmark_in_ms(10, 100, torch.mm, x, w)
x_int = x.to(dtype=torch.int8)
w_int = w.to(dtype=torch.int8)
int_mm_time = benchmark_in_ms(10, 100, int_matmul, x_int, w_int)
return fp_time, int_mm_time


def run_int_scaled_mm_benchmark(x, w, b):
scales = x.sum(-1, keepdim=True)
fp_time = benchmark_in_ms(10, 100, lambda x, w, s: torch.mm(x, w) * s, x, w, scales)
x_int = x.to(dtype=torch.int8)
w_int = w.to(dtype=torch.int8)
int_scaled_mm_time = benchmark_in_ms(
10, 100, int_scaled_matmul, x_int, w_int, scales
)
return fp_time, int_scaled_mm_time


def run_benchmarks(shapes):
print("fn,m,k,n,fp_time,int_mm_time,ratio")
positives = []
dtype = torch.bfloat16
device = "cuda"
for fn, (m, k, n) in itertools.product(
[run_int_mm_benchmark, run_int_scaled_mm_benchmark], shapes
):
x = torch.randn(m, k, dtype=dtype, device=device)
w = torch.randn(n, k, dtype=dtype, device=device).t()
b = torch.randn(m, n, dtype=dtype, device=device)

fp_time, int_mm_time = fn(x, w, b)
ratio = fp_time / int_mm_time
result = ",".join(map(str, [fn, m, k, n, fp_time, int_mm_time, ratio]))
print(result)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="integer matmul benchmarks")
parser.add_argument("file_path", type=str, help="Path to csv file with shapes")
args = parser.parse_args()
# Access the file path provided as an argument
file_path = args.file_path
file_path = pathlib.Path(file_path)
assert file_path.is_file()

# Format is (m, k, n)
shapes = list(csv.reader(open(file_path, "r")))[1:]
# Turn into list of int tuples
shapes = list(map(lambda x: tuple(map(int, x)), shapes))

run_benchmarks(shapes)
127 changes: 127 additions & 0 deletions benchmarks/intmm_shapes.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
m,k,n
1024,1024,2304
1024,1024,4608
1024,8192,2304
1024,8192,4608
1152,1024,2048
1152,2048,16384
1152,2048,2048
1152,3072,2048
1152,4096,2048
1152,8192,2048
1,2048,1024
1,2048,2048
1,2048,4096
144,2048,16384
144,2048,2048
144,4096,2048
144,8192,2048
1472,1024,154
1472,1024,308
1472,2048,154
1472,2048,308
1472,512,154
1472,512,308
1,512,2048
154,1472,1024
154,1472,2048
154,1472,512
18432,1024,512
18432,1536,512
18432,2048,512
18432,512,4096
18432,512,512
2048,1024,1
2048,1024,2
2048,16384,1152
2048,16384,144
2048,16384,288
2048,16384,576
2048,2048,1
2048,2048,1152
2048,2048,144
2048,2048,2
2048,2048,288
2048,2048,576
2048,4096,1
2048,4096,2
2048,512,18432
2048,512,9216
2,2048,1024
2,2048,2048
2,2048,4096
2304,1024,1024
2304,1024,8192
2304,1536,1024
2304,2048,1024
2304,3072,1024
2304,4096,1024
2304,512,1024
231,4096,1024
231,4096,2048
231,4096,512
231,768,1024
231,768,2048
231,768,512
2,512,2048
288,2048,16384
288,2048,2048
288,4096,2048
288,8192,2048
308,1472,1024
308,1472,2048
308,1472,512
4096,1024,2304
4096,1024,231
4096,1024,4608
4096,1024,462
4096,2048,231
4096,2048,462
4096,512,231
4096,512,462
4608,1024,1024
4608,1024,8192
4608,1536,1024
4608,2048,1024
4608,3072,1024
4608,4096,1024
4608,512,1024
462,4096,1024
462,4096,2048
462,4096,512
462,768,1024
462,768,2048
462,768,512
512,2048,1
512,2048,2
512,4096,18432
512,4096,9216
512,512,18432
512,512,9216
576,1024,2048
576,2048,16384
576,2048,2048
576,3072,2048
576,4096,2048
576,8192,2048
768,1024,231
768,1024,462
768,2048,231
768,2048,462
768,512,231
768,512,462
8192,2048,1152
8192,2048,144
8192,2048,288
8192,2048,576
9216,1024,512
9216,1536,512
9216,2048,512
9216,512,4096
9216,512,512
32768,3072,768
32768,768,2304
32768,768,3072
32768,768,768
39200,768,2304
39200,768,768
16 changes: 16 additions & 0 deletions benchmarks/print_config_shapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torchao

from torchao.kernel import autotuner

configs = autotuner._load_best_configs()

print("m,k,n")
for k, v in configs.items():
a_shape = k[1]
b_shape = k[4]
M, K0 = a_shape
K1, N = b_shape

assert K0 == K1

print(f"{M},{K0},{N}")
7 changes: 7 additions & 0 deletions benchmarks/sam_vit_b_shapes.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
m,k,n
32768,3072,768
32768,768,2304
32768,768,3072
32768,768,768
39200,768,2304
39200,768,768
3 changes: 2 additions & 1 deletion dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pytest
expecttest
packaging
parameterized
packaging
22 changes: 22 additions & 0 deletions requirements-lintrunner.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Lintrunner itself
lintrunner==0.11.0
lintrunner-adapters==0.11.0

# Flake 8 and its dependencies
flake8==6.0.0
flake8-breakpoint==1.1.0
flake8-bugbear==23.6.5
flake8-comprehensions==3.12.0
flake8-pyi==23.5.0
mccabe==0.7.0
pycodestyle==2.10.0
torchfix==0.1.1

# UFMT
black==24.2.0
ufmt==2.5.1
usort==1.0.5

# Other linters
clang-format==12.0.1
cmakelint==1.4.1
27 changes: 17 additions & 10 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,35 @@

import os
from datetime import datetime
from setuptools import setup, find_packages
current_date = datetime.now().strftime('%Y.%m.%d')

from setuptools import find_packages, setup

current_date = datetime.now().strftime("%Y.%m.%d")


def read_requirements(file_path):
with open(file_path, 'r') as file:
with open(file_path, "r") as file:
return file.read().splitlines()


# Determine the package name based on the presence of an environment variable
package_name = 'torchao-nightly' if os.environ.get('TORCHAO_NIGHTLY') else 'torchao'
package_name = "torchao-nightly" if os.environ.get("TORCHAO_NIGHTLY") else "torchao"

# Version is year.month.date if using nightlies
version = current_date if package_name == 'torchao-nightly' else '0.0.3'
version = current_date if package_name == "torchao-nightly" else "0.0.3"


setup(
name=package_name,
version=version,
packages=find_packages(),
install_requires=read_requirements('requirements.txt'),
description='Package for applying ao techniques to GPU models',
long_description=open('README.md').read(),
long_description_content_type='text/markdown',
url='https://github.com/pytorch-labs/ao',
include_package_data=True,
package_data={
"torchao.kernel.configs": ["*.pkl"],
},
install_requires=read_requirements("requirements.txt"),
description="Package for applying ao techniques to GPU models",
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
url="https://github.com/pytorch-labs/ao",
)
Loading

0 comments on commit 530f71b

Please sign in to comment.