Skip to content

Commit

Permalink
update getting orig graph
Browse the repository at this point in the history
  • Loading branch information
zewenli98 committed Nov 2, 2023
1 parent 30aaa8c commit 1d5b9d0
Showing 1 changed file with 6 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
get_tensor_placeholders,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -36,34 +35,13 @@ def efficient_attention_replacement() -> (
):
"""Constructs the original and replacement functions for efficient attention"""

# Original graph pattern to be matched by the subgraph rewriter.
# NOTE(review): this span of the scrape interleaves the deleted pre-commit
# code (symbolic-traced boilerplate plus manual graph surgery) with the added
# post-commit code; the body below is the reconstructed post-commit version.
def orig(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Return the attention result of the raw aten efficient-attention op.

    ``_scaled_dot_product_efficient_attention`` returns a tuple of outputs;
    only element 0 (the attention tensor) is part of the pattern, so it is
    extracted via ``operator.getitem`` — matching how traced graphs index
    multi-output ops.

    Args:
        q: query tensor
        k: key tensor
        v: value tensor

    Returns:
        The first output of the efficient-attention op.
    """
    # attn_bias=None, compute_log_sumexp=False — fixed args of the pattern
    outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
        q, k, v, None, False
    )
    out = operator.getitem(outputs, 0)
    return out

# Replacement graph consists of the functional version of scaled_dot_product_attention
def replacement(
Expand Down

0 comments on commit 1d5b9d0

Please sign in to comment.