Skip to content

Commit

Permalink
feat: Add support for is_causal argument in attention (#2780)
Browse files Browse the repository at this point in the history
  • Loading branch information
gs-olive authored and laikhtewari committed May 24, 2024
1 parent 69ae2e9 commit 588ec96
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 8 deletions.
7 changes: 7 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2357,8 +2357,14 @@ def aten_ops_max_pool(
)


def attention_validator(node: Node) -> bool:
# Currently, `attn_mask` is not supported
return args_bounds_check(node.args, 3) is None


@dynamo_tensorrt_converter(
torch.nn.functional.scaled_dot_product_attention,
capability_validator=attention_validator,
)
def tensorrt_scaled_dot_product_attention(
ctx: ConversionContext,
Expand All @@ -2375,6 +2381,7 @@ def tensorrt_scaled_dot_product_attention(
args[0],
args[1],
args[2],
args_bounds_check(args, 5, False),
kwargs.get("scale", None),
)

Expand Down
18 changes: 17 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/attention.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import math
from typing import Optional, Union

import numpy as np
import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt._enums import dtype
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
from torch_tensorrt.fx.types import TRTTensor


Expand All @@ -17,8 +19,11 @@ def scaled_dot_product_attention(
query: TRTTensor,
key: TRTTensor,
value: TRTTensor,
is_causal: bool,
scale: Optional[float],
) -> TRTTensor:
L, S = query.shape[-2], key.shape[-2]

mm = impl.matmul.matrix_multiply(
ctx,
target,
Expand Down Expand Up @@ -46,6 +51,17 @@ def scaled_dot_product_attention(
mm,
scale,
)

if is_causal:
attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype))
temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")

scaled = impl.elementwise.add(
ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias
)

softmax = impl.normalization.softmax(
ctx, target, source_ir, name + "_softmax", scaled, -1
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Callable, Sequence, Tuple

import torch
from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)
Expand Down Expand Up @@ -34,6 +35,7 @@ def lower_scaled_dot_product_attention(

if replaced_nodes:
# Repair instances which use the kwargs field (specifically the "scale" kwarg)
# Also repair instances which specified the is_causal or attn_bias fields
for match in replaced_nodes:
attention_node_replaced = None
# Seek the attention operator being replaced
Expand All @@ -43,17 +45,52 @@ def lower_scaled_dot_product_attention(
break

assert attention_node_replaced is not None
assert len(match.replacements) == 1

new_attention_node = match.replacements[0]

assert (
new_attention_node.target
== torch.nn.functional.scaled_dot_product_attention
)

# If the attention operator had keyword-args, copy them to the new node
if attention_node_replaced.kwargs:
assert len(match.replacements) == 1
new_attention_node = match.replacements[0]
assert (
new_attention_node.target
== torch.nn.functional.scaled_dot_product_attention
)
new_attention_node.kwargs = {**attention_node_replaced.kwargs}

# Set default args in new node:
# Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False
new_attention_node.args = new_attention_node.args + (None, 0.0, False)

# The `is_causal` argument was specified
if (
(
attention_node_replaced.target
== torch.ops.aten._scaled_dot_product_flash_attention.default
)
and args_bounds_check(attention_node_replaced.args, 4, False)
) or (
(
attention_node_replaced.target
== torch.ops.aten._scaled_dot_product_efficient_attention.default
)
and args_bounds_check(attention_node_replaced.args, 6, False)
):
new_attention_node.args = (
new_attention_node.args[:5] + (True,) + new_attention_node.args[6:]
)

# The `attn_bias` argument was specified
if (
attention_node_replaced.target
== torch.ops.aten._scaled_dot_product_efficient_attention.default
) and args_bounds_check(attention_node_replaced.args, 3) is not None:
new_attention_node.args = (
new_attention_node.args[:3]
+ attention_node_replaced.args[3]
+ new_attention_node.args[4:]
)

gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}")

Expand Down
112 changes: 112 additions & 0 deletions tests/py/dynamo/conversion/test_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import unittest

import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from ..testing_utilities import DECIMALS_OF_AGREEMENT
from .harness import DispatchTestCase


class TestScaledDotProductAttention(DispatchTestCase):
@parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))])
def test_sdpa_no_causal(self, query_shape, key_shape):
class SDPA(nn.Module):
def forward(self, query, key, value):
return torch.nn.functional.scaled_dot_product_attention(
query, key, value, None, 0.0, False, scale=None
)

inputs = []
query = torch.randn(query_shape, dtype=torch.float16)
key = torch.rand(key_shape, dtype=torch.float16)
value = torch.rand(key_shape, dtype=torch.float16)
inputs.extend([query, key, value])
self.run_test(SDPA(), inputs, rtol=1e-2, atol=1e-2, precision=torch.float16)

@parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))])
def test_sdpa_causal(self, query_shape, key_shape):
class SDPA(nn.Module):
def forward(self, query, key, value):
return torch.nn.functional.scaled_dot_product_attention(
query, key, value, None, 0.0, True, scale=None
)

inputs = []
query = torch.randn(query_shape, dtype=torch.float16)
key = torch.rand(key_shape, dtype=torch.float16)
value = torch.rand(key_shape, dtype=torch.float16)
inputs.extend([query, key, value])
self.run_test(SDPA(), inputs, rtol=1e-2, atol=1e-2, precision=torch.float16)


@unittest.skipIf(
torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8,
"GPU compute capability is too low to run flash attention, need Ampere (8.0) or greater",
)
class TestFlashAttention(DispatchTestCase):
@parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))])
def test_sdpa_causal(self, query_shape, key_shape):
class SDPA(nn.Module):
def forward(self, query, key, value):
attn = torch.ops.aten._scaled_dot_product_flash_attention.default(
query,
key,
value,
0,
True, # is_causal
False,
scale=0.25,
)
return attn[0]

inputs = []
query = torch.randn(query_shape, dtype=torch.float16)
key = torch.rand(key_shape, dtype=torch.float16)
value = torch.rand(key_shape, dtype=torch.float16)
inputs.extend([query, key, value])
self.run_test(
SDPA(),
inputs,
rtol=1e-2,
atol=1e-2,
precision=torch.float16,
enable_passes=True,
)


class TestEfficientAttention(DispatchTestCase):
@parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))])
def test_sdpa_causal(self, query_shape, key_shape):
class SDPA(nn.Module):
def forward(self, query, key, value):
attn = torch.ops.aten._scaled_dot_product_efficient_attention.default(
query,
key,
value,
None,
False,
0,
True, # is_causal
scale=0.5,
)
return attn[0]

inputs = []
query = torch.randn(query_shape, dtype=torch.float16)
key = torch.rand(key_shape, dtype=torch.float16)
value = torch.rand(key_shape, dtype=torch.float16)
inputs.extend([query, key, value])
self.run_test(
SDPA(),
inputs,
rtol=1e-2,
atol=1e-2,
precision=torch.float16,
enable_passes=True,
)


if __name__ == "__main__":
run_tests()
3 changes: 2 additions & 1 deletion tests/py/dynamo/lowering/test_aten_lowering_passes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import unittest

import torch
import torch_tensorrt
from torch.testing._internal.common_utils import TestCase, run_tests

import torch_tensorrt

from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing


Expand Down

0 comments on commit 588ec96

Please sign in to comment.