diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py index 944b0788b0..6984a70254 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py @@ -5,7 +5,6 @@ import torch from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( clean_up_graph_after_modifications, - get_tensor_placeholders, ) logger = logging.getLogger(__name__) @@ -36,34 +35,13 @@ def efficient_attention_replacement() -> ( ): """Constructs the original and replacement functions for efficient attention""" - # Empty boilerplate function taking in three Tensors and returning one - def boilerplate( - query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> torch.Tensor: - ... - - # Trace boilerplate function and extract placeholder and output nodes - orig = torch.fx.symbolic_trace(boilerplate) - q, k, v = get_tensor_placeholders(orig) - output = [node for node in orig.graph.nodes if node.op == "output"][0] - - # Graph types to replace are those which use the _scaled_dot_product_efficient_attention - # function and extract only the first element - with orig.graph.inserting_before(output): - att = orig.graph.call_function( - torch.ops.aten._scaled_dot_product_efficient_attention.default, - args=(q, k, v, None, False), + # Original graph + def orig(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default( + q, k, v, None, False ) - out = orig.graph.call_function( - operator.getitem, - args=(att, 0), - ) - - # Assign the output of the graph to be the single getitem output - output.args = (out,) - - orig.graph.lint() - orig.recompile() + out = operator.getitem(outputs, 0) + return out # Replacement graph consists of the functional version of scaled_dot_product_attention def replacement(