FlashAttention implem and dispatch (#360)
* FlashAttention implem WIP

* Fix flashattention forward+backward

* Fix forward/backward for FlashAttention

* Enable tests (more permissive) for f16 backward

* Fix CI

* flashattn only supports Sm75 and above

* Fix CI2

* Disable K=128 when below sm80 for flashattn

Co-authored-by: danthe3rd <danthe3rd>
danthe3rd authored and fmassa committed Aug 10, 2022
1 parent 71c2eab commit 573ed14
Showing 8 changed files with 434 additions and 61 deletions.
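The commit message above describes the new dispatch rules (FlashAttention only on Sm75 and newer, and K=128 disabled below Sm80). The sketch below shows how the new backend can be exercised through the public API, using only names that appear in the tests in this diff (MemoryEfficientAttentionFlashAttentionOp, AttentionOpDispatch.from_arguments, op.supports); the device, dtype and tensor shapes are illustrative assumptions, not part of the commit.

import torch
import xformers.ops

# Illustrative inputs: (batch, seq_len, K) with K <= 128, fp16 on CUDA.
q = torch.randn(1, 128, 64, device="cuda", dtype=torch.half)
k = torch.randn(1, 128, 64, device="cuda", dtype=torch.half)
v = torch.randn(1, 128, 64, device="cuda", dtype=torch.half)

flash_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
dispatch = xformers.ops.AttentionOpDispatch.from_arguments(query=q, key=k, value=v)

if flash_op.supports(dispatch):
    # Request the FlashAttention-backed kernel explicitly, as the tests do.
    out = xformers.ops.memory_efficient_attention(q, k, v, None, op=flash_op)
else:
    # Fall back to whichever memory-efficient attention op xformers picks.
    out = xformers.ops.memory_efficient_attention(q, k, v, None)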
4 changes: 4 additions & 0 deletions .circleci/config.yml
@@ -108,7 +108,9 @@ install_dep_exp: &install_dep_exp
install_repo: &install_repo
- run:
name: Install Repository
no_output_timeout: 30m
command: |
git submodule update --init third_party/flash-attention
$CONDA_PYTHON -m pip install -e .
# Test import.
@@ -117,7 +119,9 @@ install_repo: &install_repo
install_experimental_repo: &install_experimental_repo
- run:
name: Install Repository
no_output_timeout: 30m
command: |
git submodule update --init third_party/flash-attention
source $BASH_ENV
cd experimental
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "third_party/flash-attention"]
path = third_party/flash-attention
url = git@github.com:HazyResearch/flash-attention.git
93 changes: 91 additions & 2 deletions setup.py
@@ -10,7 +10,9 @@
import os
import re
import shutil
import subprocess
import sys
from pathlib import Path

import setuptools
import torch
@@ -44,6 +46,84 @@ def find_version(version_file_path):
raise RuntimeError("Unable to find version string.")


def get_cuda_version(cuda_dir) -> int:
nvcc_bin = "nvcc" if cuda_dir is None else cuda_dir + "/bin/nvcc"
raw_output = subprocess.check_output([nvcc_bin, "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = int(release[0])
bare_metal_minor = int(release[1][0])

assert bare_metal_minor < 100
return bare_metal_major * 100 + bare_metal_minor


def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
# Figure out default archs to target
DEFAULT_ARCHS_LIST = ""
if cuda_version > 1100:
DEFAULT_ARCHS_LIST = "7.5;8.0;8.6"
elif cuda_version >= 1100:
DEFAULT_ARCHS_LIST = "7.5;8.0"
else:
return []

archs_list = os.environ.get("TORCH_CUDA_ARCH_LIST", DEFAULT_ARCHS_LIST)
nvcc_archs_flags = []
for arch in archs_list.split(";"):
assert len(arch) >= 3, f"Invalid sm version: {arch}"

num = 10 * int(arch[0]) + int(arch[2])
# Need at least 7.5
if num < 75:
continue
nvcc_archs_flags.append(f"-gencode=arch=compute_{num},code=sm_{num}")
if arch.endswith("+PTX"):
nvcc_archs_flags.append(f"-gencode=arch=compute_{num},code=compute_{num}")
if not nvcc_archs_flags:
return []

this_dir = os.path.dirname(os.path.abspath(__file__))
flash_root = os.path.join(this_dir, "third_party", "flash-attention")
return [
CUDAExtension(
name="xformers._C_flashattention",
sources=[
os.path.join(this_dir, "third_party", "flash-attention", path)
for path in [
"csrc/flash_attn/fmha_api.cpp",
"csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu",
"csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu",
"csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu",
"csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu",
]
],
extra_compile_args={
**extra_compile_args,
"nvcc": extra_compile_args.get("nvcc", [])
+ [
"-O3",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
"--ptxas-options=-v",
"-lineinfo",
]
+ nvcc_archs_flags,
},
include_dirs=[
Path(flash_root) / "csrc" / "flash_attn",
Path(flash_root) / "csrc" / "flash_attn" / "src",
# Path(flash_root) / 'csrc' / 'flash_attn' / 'cutlass' / 'include',
Path(this_dir) / "third_party" / "cutlass" / "include",
],
)
]


def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(
@@ -57,6 +137,7 @@ def get_extensions():
)

sources = main_file + source_cpu

source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))

sputnik_dir = os.path.join(this_dir, "third_party", "sputnik")
@@ -74,6 +155,7 @@ def get_extensions():
extra_compile_args["cxx"].append("-fopenmp")

include_dirs = [extensions_dir]
ext_modules = []

if (torch.cuda.is_available() and ((CUDA_HOME is not None))) or os.getenv(
"FORCE_CUDA", "0"
@@ -86,19 +168,26 @@ def get_extensions():
nvcc_flags = []
else:
nvcc_flags = nvcc_flags.split(" ")
cuda_version = get_cuda_version(CUDA_HOME)
if cuda_version >= 1102:
nvcc_flags += ["--threads", "4"]
extra_compile_args["nvcc"] = nvcc_flags
if cuda_version >= 1100:
ext_modules += get_flash_attention_extensions(
cuda_version=cuda_version, extra_compile_args=extra_compile_args
)

sources = [os.path.join(extensions_dir, s) for s in sources]

ext_modules = [
ext_modules.append(
extension(
"xformers._C",
sorted(sources),
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
]
)

return ext_modules

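For reference, get_cuda_version above turns the nvcc release into an integer (CUDA 11.2 becomes 1102), and get_flash_attention_extensions only emits -gencode flags for architectures of Sm75 and newer. A standalone restatement of that arch-list handling, for illustration only (the helper name archs_to_nvcc_flags is hypothetical; the flag format is copied from the code above):

from typing import List

def archs_to_nvcc_flags(archs_list: str) -> List[str]:
    # Mirrors the loop in get_flash_attention_extensions above.
    flags = []
    for arch in archs_list.split(";"):
        num = 10 * int(arch[0]) + int(arch[2])  # "8.6" -> 86
        if num < 75:  # FlashAttention kernels need Sm75 or newer
            continue
        flags.append(f"-gencode=arch=compute_{num},code=sm_{num}")
        if arch.endswith("+PTX"):
            flags.append(f"-gencode=arch=compute_{num},code=compute_{num}")
    return flags

print(archs_to_nvcc_flags("7.0;7.5;8.0+PTX"))
# ['-gencode=arch=compute_75,code=sm_75',
#  '-gencode=arch=compute_80,code=sm_80',
#  '-gencode=arch=compute_80,code=compute_80']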
100 changes: 49 additions & 51 deletions tests/test_mem_eff_attention.py
@@ -16,6 +16,14 @@
_devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]


def assert_allclose(
out: torch.Tensor, ref: torch.Tensor, msg: str = "failed", **kwargs
) -> None:
assert torch.allclose(
out, ref, **kwargs
), f"{msg}: max_diff={(out - ref).abs().max()} / atol={kwargs.get('atol', 1e-8)}"


def ref_attention(q, k, v, attn_bias=None, drop_mask=None, p=0.0):
q = q.float()
k = k.float()
@@ -43,6 +51,7 @@ def ref_attention(q, k, v, attn_bias=None, drop_mask=None, p=0.0):
[
xformers.ops.MemoryEfficientAttentionOp,
xformers.ops.MemoryEfficientAttentionGenericForwardOp,
xformers.ops.MemoryEfficientAttentionFlashAttentionOp,
],
)
def test_memory_efficient_attention(
@@ -55,14 +64,6 @@ def test_memory_efficient_attention(
dtype,
op: xformers.ops.MemoryEfficientAttentionOp,
):
if (
device not in op.SUPPORTED_DEVICES
or k_len > op.SUPPORTED_MAX_K
or (use_attn_bias and not op.SUPPORTS_ATTN_BIAS)
or dtype not in op.SUPPORTED_DTYPES
):
return # Or `pytest.xfail` ?

scale = 3
query = torch.randn((batch_size, q_len, k_len), device=device, dtype=dtype) * scale
key = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale
@@ -74,12 +75,19 @@
)
attn_bias = attn_bias.expand(batch_size, q_len, kv_len)

if not op.supports(
xformers.ops.AttentionOpDispatch.from_arguments(
query=query, key=key, value=value, attn_bias=attn_bias
)
):
pytest.skip("unsupported configuration")

out = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias, op=op
).float()
ref = ref_attention(query, key, value, attn_bias)

assert torch.allclose(out, ref, atol=2e-4)
assert_allclose(out, ref, atol=op.FORWARD_ERROR_ATOL)


@pytest.mark.parametrize("k_len", [5, 6, 32])
@@ -97,7 +105,7 @@ def test_key_query_all_ones(device, q_len, kv_len, batch_size, k_len):
# this should be equivalent to the average over value
ref = value.mean(1, keepdim=True).expand_as(query)

assert torch.allclose(out, ref, atol=1e-5)
assert_allclose(out, ref, atol=1e-5)


@pytest.mark.parametrize("k_len", [5, 6, 32])
@@ -122,24 +130,24 @@ def test_logsumexp(
dtype,
op: xformers.ops.MemoryEfficientAttentionOp,
):
if (
device not in op.SUPPORTED_DEVICES
or k_len > op.SUPPORTED_MAX_K
or dtype not in op.SUPPORTED_DTYPES
):
return

scale = 3
query = torch.randn((batch_size, q_len, k_len), device=device, dtype=dtype) * scale
key = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale
value = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale

if not op.supports(
xformers.ops.AttentionOpDispatch.from_arguments(
query=query, key=key, value=value
)
):
pytest.skip("unsupported configuration")

_, lse, _, _ = op.FORWARD_OPERATOR(query, key, value, True, None, 0.0)
ref_lse = (
(query.float() / k_len**0.5) @ key.float().transpose(-2, -1)
).logsumexp(-1)

assert torch.allclose(lse, ref_lse, atol=2e-4)
assert_allclose(lse, ref_lse, atol=2e-4)


@pytest.mark.parametrize("use_attn_bias", [False, True])
@@ -149,12 +157,13 @@ def test_logsumexp(
@pytest.mark.parametrize("kv_len", [3, 15, 32, 33, 64, 128])
@pytest.mark.parametrize("q_len", [2, 3, 5, 32, 128])
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("dtype", [torch.float])
@pytest.mark.parametrize("dtype", [torch.float, torch.half])
@pytest.mark.parametrize(
"op",
[
xformers.ops.MemoryEfficientAttentionOp,
xformers.ops.MemoryEfficientAttentionGenericForwardOp,
xformers.ops.MemoryEfficientAttentionFlashAttentionOp,
],
)
def test_memory_efficient_attention_backward(
@@ -168,14 +177,6 @@ def test_memory_efficient_attention_backward(
dtype,
op: xformers.ops.MemoryEfficientAttentionOp,
):
if (
device not in op.SUPPORTED_DEVICES
or k_len > op.SUPPORTED_MAX_K
or (use_attn_bias and not op.SUPPORTS_ATTN_BIAS)
or dtype not in op.SUPPORTED_DTYPES
):
return

scale = 3
query = torch.randn((batch_size, q_len, k_len), device=device, dtype=dtype) * scale
key = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale
@@ -188,6 +189,13 @@
)
attn_bias = attn_bias.expand(batch_size, q_len, kv_len)

if not op.supports(
xformers.ops.AttentionOpDispatch.from_arguments(
query=query, key=key, value=value, attn_bias=attn_bias
)
):
pytest.skip("unsupported configuration")

query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
@@ -211,6 +219,10 @@
ref.backward(grad_out)

atol = 2e-4 + 2e-6 * k_len * kv_len * math.sqrt(batch_size) * math.sqrt(q_len)
rtol = 1e-8
if dtype is torch.half:
atol = 3e-2
rtol = 1e-2

# (for mypy)
assert isinstance(query.grad, torch.Tensor)
@@ -222,9 +234,7 @@
("key", grad_k, key.grad),
("value", grad_v, value.grad),
]:
assert torch.allclose(
calc_grad, ref_grad, atol=atol
), f"{name} doesn't match (max_diff={(calc_grad - ref_grad).abs().max()} > {atol}) - dtype={dtype}"
assert_allclose(calc_grad, ref_grad, name, atol=atol, rtol=rtol)


def _vec_binom_test(x, n, p):
@@ -276,15 +286,15 @@ def test_dropout(device, q_len, kv_len, batch_size, k_len, p, seed):
torch.manual_seed(seed)
out2 = xformers.ops.memory_efficient_attention(query, key, value, attn_bias, p)

assert torch.allclose(out, out2)
assert_allclose(out, out2)

mask = torch.empty((batch_size, q_len, kv_len), device=device)

torch.manual_seed(seed)
mask = torch.ops.xformers._temp_dropout(mask, p)

ref = ref_attention(query, key, value, attn_bias, mask, p)
assert torch.allclose(out, ref, atol=2e-4), f"{(out - ref).abs().max()}"
assert_allclose(out, ref, atol=2e-4)

num_trials = 1000
p_val_tol = 0.0001
@@ -348,15 +358,9 @@ def test_dropout_backward(device, q_len, kv_len, batch_size, k_len, p):
# extra accumulation step in grad_q, which is not present in the CUDA
# implementation
atol = 5e-4 if device == "cuda" else 6e-4
assert torch.allclose(
grad_q, query.grad, atol=atol
), f"grad_q doesn't match {(grad_q - query.grad).abs().max()}"
assert torch.allclose(
grad_k, key.grad, atol=atol
), f"grad_k doesn't match {(grad_k - key.grad).abs().max()}"
assert torch.allclose(
grad_v, value.grad, atol=atol
), f"grad_v doesn't match {(grad_v - value.grad).abs().max()}"
assert_allclose(grad_q, query.grad, "grad_q", atol=atol)
assert_allclose(grad_k, key.grad, "grad_k", atol=atol)
assert_allclose(grad_v, value.grad, "grad_v", atol=atol)


@pytest.mark.parametrize("k_len", [32])
@@ -380,7 +384,7 @@ def test_memory_efficient_attention_full_block_masked(
out = xformers.ops.memory_efficient_attention(query, key, value, attn_bias)
ref = ref_attention(query, key, value, attn_bias)

assert torch.allclose(out, ref, atol=2e-4)
assert_allclose(out, ref, atol=2e-4)

query.requires_grad_(True)
key.requires_grad_(True)
@@ -406,12 +410,6 @@ def test_memory_efficient_attention_full_block_masked(
# extra accumulation step in grad_q, which is not present in the CUDA
# implementation
atol = 5e-4 if device == "cuda" else 6e-4
assert torch.allclose(
grad_q, query.grad, atol=atol
), f"grad_q doesn't match {(grad_q - query.grad).abs().max()}"
assert torch.allclose(
grad_k, key.grad, atol=atol
), f"grad_k doesn't match {(grad_k - key.grad).abs().max()}"
assert torch.allclose(
grad_v, value.grad, atol=atol
), f"grad_v doesn't match {(grad_v - value.grad).abs().max()}"
assert_allclose(grad_q, query.grad, "grad_q", atol=atol)
assert_allclose(grad_k, key.grad, "grad_k", atol=atol)
assert_allclose(grad_v, value.grad, "grad_v", atol=atol)
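The tests above route all comparisons through the new assert_allclose helper, which reports the maximum difference and the tolerance on failure, and they loosen the backward tolerances for torch.half (atol=3e-2, rtol=1e-2). A small self-contained illustration of the helper's failure message, with made-up tensors:

import torch

def assert_allclose(out, ref, msg="failed", **kwargs):
    # Simplified copy of the helper added in tests/test_mem_eff_attention.py.
    assert torch.allclose(out, ref, **kwargs), (
        f"{msg}: max_diff={(out - ref).abs().max()} / atol={kwargs.get('atol', 1e-8)}"
    )

out = torch.tensor([1.0, 2.0, 3.0])
ref = torch.tensor([1.0, 2.0, 3.1])
assert_allclose(out, ref, "grad_q", atol=0.5)     # passes: max diff is about 0.1
# assert_allclose(out, ref, "grad_q", atol=1e-3)  # would raise AssertionError
# with a message like "grad_q: max_diff=0.1 / atol=0.001"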
1 change: 1 addition & 0 deletions third_party/flash-attention
Submodule flash-attention added at 5b838a
Empty file added xformers/_C_flashattention.pyi
Empty file.