Skip to content

Commit

Permalink
Update RemoveLocalScalarDenseOpsTransform to tag scalar tensors as …
Browse files Browse the repository at this point in the history
…well (#6069)

Summary:
Pull Request resolved: #6069

## Context

See the new docstrings added to `remove_local_scalar_dense_ops` for more details on what the pass is trying to achieve.

The goal is to mark tensors that are consumed as scalars via `tensor[0].item()` as "scalar tensors" that will be represented as a `SymInt` object in the vulkan delegate instead of a regular `Tensor` object.

This diff also adds an `__init__.py` file to the `_passes` folder to make it easier to include Vulkan passes from one place.
ghstack-source-id: 247163956
exported-using-ghexport

Reviewed By: jorgep31415

Differential Revision: D64139867

fbshipit-source-id: 88ba420e107654d7eadb2cbca78a3750a51f74b0
  • Loading branch information
SS-JIA authored and facebook-github-bot committed Oct 10, 2024
1 parent 7bfab21 commit 1a0c2c7
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 20 deletions.
2 changes: 1 addition & 1 deletion backends/vulkan/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ runtime.python_library(
"//executorch/backends/transforms:fuse_view_copy",
"//executorch/backends/transforms:mean_to_sum_div",
"//executorch/backends/transforms:remove_clone_ops",
"//executorch/backends/vulkan/_passes:remove_local_scalar_dense",
"//executorch/backends/vulkan/_passes:vulkan_passes",
"//executorch/exir:graph_module",
"//executorch/exir/_serialize:_bindings",
"//executorch/exir/_serialize:lib",
Expand Down
13 changes: 13 additions & 0 deletions backends/vulkan/_passes/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,16 @@ runtime.python_library(
"//executorch/exir/dialects:lib",
],
)

runtime.python_library(
name = "vulkan_passes",
srcs = [
"__init__.py",
],
visibility = [
"//executorch/backends/...",
],
deps = [
":remove_local_scalar_dense",
]
)
7 changes: 7 additions & 0 deletions backends/vulkan/_passes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
RemoveLocalScalarDenseOpsTransform,
)

__all__ = [
"RemoveLocalScalarDenseOpsTransform",
]
98 changes: 82 additions & 16 deletions backends/vulkan/_passes/remove_local_scalar_dense_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,95 @@
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from torch._subclasses.fake_tensor import FakeTensor


def node_is_local_scalar_dense_chain(node: torch.fx.Node) -> bool:
"""
Converting a tensor to a scalar via tensor[0].item() creates a index_select +
local_scalar_dense pattern in the graph. Check if a node is the start of this pattern.
"""
if (
node.op == "call_function"
and node.target == exir_ops.edge.aten.select_copy.int
and len(node.users) == 1
):
user = list(node.users.keys())[0]
return user.target == torch.ops.aten._local_scalar_dense.default

return False


def tag_node_if_scalar_tensor(node: torch.fx.Node) -> None:
"""
A scalar tensor in the Vulkan backend is a tensor that can be represented as a scalar
value instead of a Tensor object. The criteria for identifying a tensor as a scalar
tensor are as follows:
1. The tensor has only 1 element
2. One of the node's uses is converting it to a scalar via `tensor[0].item()`, which
creates a index_select + local_scalar_dense pattern in the graph
If any of these criteria are fulfilled, then tag the node for the tensor to mark it
so that it is added as a scalar value during serialization.
"""
tensor_val = node.meta["val"]
if not isinstance(tensor_val, FakeTensor):
return

# Scalar tensors must have only one element
if tensor_val.numel() != 1:
return

for user in node.users:
if node_is_local_scalar_dense_chain(user):
node.meta["vkdg_is_scalar_tensor"] = True


def remove_local_scalar_dense_chain(graph: torch.fx.Graph, node: torch.fx.Node) -> None:
"""
Remove the index_select + local_scalar_dense pattern in the graph in favor of passing
the original scalar tensor directly.
"""
replace_node = node.args[0]
assert isinstance(replace_node, torch.fx.Node)
# If the argument to the local_scalar_dense op is a select op with only
# one user, and the argument to the select op is a tensor with only one
# element (i.e. a scalar tensor), then replace the entire pattern with the
# scalar tensor.
if (
replace_node.op == "call_function"
and replace_node.target == exir_ops.edge.aten.select_copy.int
):
# pyre-ignore
if replace_node.args[0].meta["val"].numel() == 1:
replace_node = replace_node.args[0]
assert isinstance(replace_node, torch.fx.Node)
assert replace_node.meta.get("vkdg_is_scalar_tensor", True)

with graph.inserting_after(node):
node.replace_all_uses_with(replace_node)


def remove_local_scalar_dense_ops(graph: torch.fx.Graph) -> torch.fx.Graph:
"""
Remove local_scalar_dense op nodes and replace uses with parent node, or the
original scalar tensor.
The purpose of this pass is twofold:
1. Tag scalar tensors (see `tag_node_if_scalar_tensor()` for the criteria)
2. Remove the index_select + local_scalar_dense pattern in the graph in favor of
passing the original scalar tensor directly (see `remove_local_scalar_dense_chain()`)
This makes it easier to deal with scalar tensors in the Vulkan backend. In particular,
it allows serializing scalar tensors as SymInt objects instead of Tensor objects.
Because scalar tensors are often used to inform tensor shapes, their values need to
be easily accessed by the CPU during resizing logic, while also being able to reflect
updates to their value in any GPU shaders that reference them.
"""
target_op = torch.ops.aten._local_scalar_dense.default
for node in graph.nodes:
tag_node_if_scalar_tensor(node)

if node.op == "call_function" and node.target == target_op:
replace_node = node.args[0]
# If the argument to the local_scalar_dense op is a select op with only
# one user, and the argument to the select op is a tensor with only one
# element (i.e. a scalar tensor), then replace the entire pattern with the
# scalar tensor.
if (
replace_node.op == "call_function"
and replace_node.target == exir_ops.edge.aten.select_copy.int
):
if replace_node.args[0].meta["val"].numel() == 1:
replace_node = replace_node.args[0]

with graph.inserting_after(node):
node.replace_all_uses_with(replace_node)
remove_local_scalar_dense_chain(graph, node)

graph.eliminate_dead_code()
return graph
Expand Down
4 changes: 1 addition & 3 deletions backends/vulkan/vulkan_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
from executorch.backends.transforms.mean_to_sum_div import MeanToSumDiv
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform

from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
RemoveLocalScalarDenseOpsTransform,
)
from executorch.backends.vulkan._passes import RemoveLocalScalarDenseOpsTransform

from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
Expand Down

0 comments on commit 1a0c2c7

Please sign in to comment.