Skip to content

Commit

Permalink
Add index.Tensor and aten.logical_not (#3221)
Browse files Browse the repository at this point in the history
Summary:
Add missing llama ops for MPS delegate:
- `index.Tensor`
- `logical_not`

`index.put` works correctly when generating the 1st token, but gives incorrect results from the 2nd token onward. It therefore remains disabled.

Summary of changes:
- Adds missing llama2 ops
- Adds support for launching hand-written Metal kernels instead of MPSGraph ops, for ops that MPSGraph does not support

cc cccclai , shoumikhin

Pull Request resolved: #3221

Reviewed By: shoumikhin

Differential Revision: D56447710

Pulled By: cccclai

fbshipit-source-id: 778a485df5e67d1afd006b42f07b69c8a3961223
  • Loading branch information
DenisVieriu97 authored and facebook-github-bot committed Apr 24, 2024
1 parent e9d7868 commit 02a6b66
Show file tree
Hide file tree
Showing 15 changed files with 446 additions and 10 deletions.
12 changes: 12 additions & 0 deletions backends/apple/mps/mps_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
MPSGraph,
MPSTensor,
OpType,
)

from executorch.backends.apple.mps.serialization.mps_graph_serialize import (
Expand Down Expand Up @@ -65,6 +66,7 @@ def preprocess(
input_ids=[],
output_ids=[],
constant_ids=[],
graph_type=OpType.mps_graph,
)

convert_model_to_fp16 = True
Expand Down Expand Up @@ -111,6 +113,16 @@ def handle_call_function(
mps_graph: MPSGraph,
) -> None:
logging.info(f"Visiting: {node}, {node.target.__name__}")

if (
"delegation_tag" in node.meta
and "metal_kernel" in node.meta["delegation_tag"]
):
logging.info(
f"Node '{node.target.__name__}' was marked as a Metal kernel by the MPSPartitioner!"
)
mps_graph.graph_type = OpType.metal_kernel

if node.target.__name__ in node_visitors:
node_visitors[node.target.__name__].define_node(node, mps_graph)
else:
Expand Down
77 changes: 76 additions & 1 deletion backends/apple/mps/operators/indexing_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Provided subject to the LICENSE file in the top level directory.
#

from typing import cast
from typing import cast, List

import torch
from executorch.backends.apple.mps.operators.node_visitor import (
Expand All @@ -13,9 +13,12 @@
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
MPSEmbedding,
MPSGraph,
MPSIndexPut,
MPSIndexSelect,
MPSIndexTensor,
)
from executorch.backends.apple.mps.utils.mps_utils import get_input_node
from executorch.backends.transforms import get_shape
from executorch.exir.sym_util import eval_expr


Expand All @@ -40,6 +43,78 @@ def define_node(
mps_graph.mps_nodes.append(mps_node)


@register_node_visitor
class IndexTensorVisitor(NodeVisitor):
    """Serializes `aten.index.Tensor` (advanced indexing) into the MPS graph."""

    target = "aten.index.Tensor"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        # Build the input/output skeleton first, then attach one serialized
        # tensor id per index tensor found in args[1].
        mps_node = self.create_unary_node(node, mps_graph, MPSIndexTensor)
        for index_node in cast(List[torch.fx.Node], node.args[1]):
            serialized_id = self.define_tensor(index_node, mps_graph)
            mps_node.mpsnode_union.indices_id.append(serialized_id)
        mps_graph.mps_nodes.append(mps_node)


# [MPS TODO]: Works on a single iteration of llama2, but subsequent tokens
# are wrong when using Index put. Disabling it for now.
@register_node_visitor
class IndexPutVisitor(NodeVisitor):
    # Visitor for `aten.index_put.default`. `target` is set to "disabled" so
    # the registry never dispatches here (see TODO above); restore the real
    # target string below to re-enable once the multi-token bug is fixed.
    # target = "aten.index_put.default"
    target = "disabled"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def infer_sizes(self, a: List[int], b: List[int]) -> List[int]:
        # Compute the broadcast result shape of shapes `a` and `b`: shapes are
        # right-aligned and, per dimension, b's size wins unless b has no size
        # there (the -1 sentinel), in which case a's size is used.
        # NOTE(review): no compatibility checking is done (mismatched non-1
        # sizes are not rejected) — assumes callers pass broadcastable shapes.
        dimsA = len(a)
        dimsB = len(b)
        ndim = dimsA if dimsA > dimsB else dimsB
        expandedSizes = [0] * ndim
        for i in range(ndim - 1, -1, -1):
            offset = ndim - 1 - i
            dimA = dimsA - 1 - offset
            dimB = dimsB - 1 - offset
            sizeA = a[dimA] if dimA >= 0 else -1
            sizeB = b[dimB] if dimB >= 0 else -1
            expandedSizes[i] = sizeA if sizeB == -1 else sizeB

        return expandedSizes

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        # Serialize an index_put node: input/output via the unary skeleton,
        # plus the index tensors (args[1]) and the values tensor (args[2]).
        mps_node = self.create_unary_node(node, mps_graph, MPSIndexPut)
        updates_shape = get_shape(node.args[2])
        input_shape = get_shape(node.args[0])
        new_shape = []
        # When the updates tensor's rank differs from the input's (and is not
        # 1-D), record the inferred broadcast shape so the runtime can expand
        # the values tensor before scattering.
        if len(updates_shape) != 1 and len(updates_shape) != len(input_shape):
            new_shape = self.infer_sizes(input_shape, updates_shape)
            mps_node.mpsnode_union.values_shape = new_shape

        tensors = cast(List[torch.fx.Node], node.args[1])
        for tensor in tensors:
            mps_node.mpsnode_union.indices_id.append(
                self.define_tensor(tensor, mps_graph)
            )

        mps_node.mpsnode_union.values_id = self.define_tensor(
            get_input_node(node, 2), mps_graph
        )
        mps_graph.mps_nodes.append(mps_node)


@register_node_visitor
class EmbeddingVisitor(NodeVisitor):
target = "aten.embedding.default"
Expand Down
3 changes: 3 additions & 0 deletions backends/apple/mps/operators/unary_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
MPSLog,
MPSLog10,
MPSLog2,
MPSLogicalNot,
MPSNeg,
MPSReciprocal,
MPSRound,
Expand Down Expand Up @@ -79,6 +80,7 @@ class UnaryOpVisitor(NodeVisitor):
"aten.isnan.default",
"aten.isinf.default",
"aten.round.default",
"aten.logical_not.default",
]

def __init__(self, *args) -> None:
Expand Down Expand Up @@ -115,6 +117,7 @@ def __init__(self, *args) -> None:
exir_ops.edge.aten.isnan.default: MPSIsnan,
exir_ops.edge.aten.isinf.default: MPSIsinf,
exir_ops.edge.aten.round.default: MPSRound,
exir_ops.edge.aten.logical_not.default: MPSLogicalNot,
}

def define_node(
Expand Down
50 changes: 48 additions & 2 deletions backends/apple/mps/partition/mps_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
#

import logging
from typing import Any, Dict, List, Union
from typing import Any, cast, Dict, List, Union

import torch
from executorch.backends.apple.mps.mps_preprocess import MPSBackend
from executorch.backends.apple.mps.operators.node_visitor import get_node_visitors
from executorch.backends.apple.mps.utils.mps_utils import is_parameter
from executorch.backends.transforms import get_shape
from executorch.exir.backend.backend_details import CompileSpec
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
generate_partitions_from_list_of_nodes,
Expand All @@ -20,6 +21,7 @@
PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import Partition
from torch.fx.passes.operator_support import OperatorSupportBase
Expand All @@ -28,6 +30,13 @@
logging.basicConfig(level=logging.DEBUG, format=FORMAT)


# ops implemented as Metal kernels.
METAL_KERNELS = [
exir_ops.edge.aten.index.Tensor,
exir_ops.edge.aten.index_put.default,
]


class MPSOperatorSupport(OperatorSupportBase):
def __init__(self, edge_program: torch.export.ExportedProgram, compiler_specs):
self.node_visitors = get_node_visitors(edge_program)
Expand Down Expand Up @@ -65,10 +74,47 @@ def generate_partitions(self, edge_program: ExportedProgram) -> List[Any]:
op_support=self.supported_ops,
)

def mps_graph_advanced_indexing_support(self, node: torch.fx.Node):
num_indices = 0
tensors = cast(List[torch.fx.Node], node.args[1])
input = cast(torch.fx.Node, node.args[0])
for t in tensors:
if t is not None:
num_indices += 1
# Can dispatch to MPSGraph if the length of the slices is equal
# to the number of dimensions of the sliced tensors, or only one
# slice is present. All other cases will fallback to a Metal kernel.
if num_indices == len(get_shape(input)) or num_indices == 1:
return True

return False

def use_metal_kernel(self, node: torch.fx.Node):
if node.target in METAL_KERNELS:
if (
node.target == exir_ops.edge.aten.index.Tensor
or node.target == exir_ops.edge.aten.index_put.default
):
if not self.mps_graph_advanced_indexing_support(node):
return True
return False

def tag_nodes(self, partitions: List[Partition]) -> None:
for partition in partitions:
for node in partition.nodes:
crt_partition_counter = 0
for node in sorted(partition.nodes):
delegation_tag = f"mps_{partition.id}"
if self.use_metal_kernel(node):
logging.warning(f"[WARNING] Using Metal kernel for op {node.name}!")
# Partition the Metal kernel into a separate partition
crt_partition_counter += 1
delegation_tag = (
f"{delegation_tag}_metal_kernel_{crt_partition_counter}"
)
crt_partition_counter += 1
else:
delegation_tag = f"{delegation_tag}_{crt_partition_counter}"

node.meta["delegation_tag"] = delegation_tag
self.partition_tags[delegation_tag] = self.delegation_spec

Expand Down
23 changes: 23 additions & 0 deletions backends/apple/mps/runtime/MPSDevice.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,19 @@

#pragma once

// Obj-C headers
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>

// Runtime headers
#include <executorch/runtime/backend/interface.h>

// MPS headers
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>

#include <unordered_map>
#include <vector>

#define MB(x) (x * 1048576UL)

namespace torch {
Expand All @@ -25,6 +34,11 @@ enum class MacOSVersion : uint32_t {
MACOS_VER_14_0_PLUS,
};

// Identifies a family of custom Metal kernels shipped with the MPS delegate.
// MAX tracks the last entry so callers can size per-library arrays/caches.
enum class LibraryType : uint32_t {
  INDEXING_KERNELS = 0,
  MAX = INDEXING_KERNELS,
};

class MPSDevice {
public:
/**
Expand Down Expand Up @@ -53,9 +67,18 @@ class MPSDevice {

~MPSDevice();

/**
* Compile a PSO for a given library type.
* Once compiled, the library and PSOs are cached.
*/
Error compilePSO(LibraryType libraryType, const char* kernelName);
Error compileLibrary(LibraryType);

private:
static MPSDevice* _device;
id<MTLDevice> _mtl_device;
std::unordered_map<LibraryType, id<MTLLibrary>> _m_library_cache;
std::unordered_map<std::string, id<MTLComputePipelineState>> _m_pso_cache;
MPSDevice();
};

Expand Down
65 changes: 65 additions & 0 deletions backends/apple/mps/runtime/MPSDevice.mm
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,20 @@
static std::unique_ptr<MPSDevice> mps_device;
static std::once_flag mpsdev_init;

// Pick the Metal Shading Language version to compile the delegate's custom
// kernels with, based on the SDK and the runtime OS version.
static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& device, bool macOS13Plus) {
  // MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants)
  // host_name attribute needs at least Metal 2.2 and ulong needs Metal 2.3 (supported on MacOS 11+)
  MTLLanguageVersion languageVersion = MTLLanguageVersion2_3;
#if defined(__MAC_13_0)
  // Prefer Metal 3.0 when both the SDK (__MAC_13_0) and the runtime OS allow it.
  if (macOS13Plus) {
    languageVersion = MTLLanguageVersion3_0;
  }
#endif

  // Hard requirement: the kernels assume at least the Mac2 GPU family.
  ET_CHECK_MSG([device supportsFamily:MTLGPUFamilyMac2], "Missing Metal support for MTLGPUFamilyMac2");
  return languageVersion;
}

MPSDevice::~MPSDevice() {
[_mtl_device release];
_mtl_device = nil;
Expand Down Expand Up @@ -79,6 +93,57 @@
}
}

// Return the Metal shader source string for the requested library type.
// NOTE(review): INDEXING_KERNELS currently returns a "TODO" placeholder —
// compiling it cannot yield usable kernels until real MSL source is added.
const char* getLibraryCString(LibraryType libraryType) {
  switch (libraryType) {
    case LibraryType::INDEXING_KERNELS:
      return "TODO";
    default:
      ET_CHECK_MSG(false, "Unhandled library type!");
  }
}

// Compile the Metal source for `libraryType` and cache the resulting
// MTLLibrary in _m_library_cache. Returns Error::Ok on success, or an
// Internal error if the source fails to compile.
Error
MPSDevice::compileLibrary(LibraryType libraryType) {
  Error err = Error::Ok;
  NSError* error = nil;
  MTLCompileOptions* options = [[MTLCompileOptions alloc] init];
  [options setLanguageVersion:getMetalLanguageVersion(_mtl_device, isMacOS13Plus(MacOSVersion::MACOS_VER_13_0_PLUS))];
  [options setFastMathEnabled:YES];
  id<MTLLibrary> lib =
      [_mtl_device newLibraryWithSource:[NSString stringWithCString:getLibraryCString(libraryType)
                                                           encoding:NSASCIIStringEncoding]
                                options:options
                                  error:&error];
  // This file is compiled without ARC (see the explicit releases in
  // ~MPSDevice), so the options object must be released manually; the
  // original code leaked it. Release before the error check so the failure
  // path does not leak either.
  [options release];

  ET_CHECK_OR_RETURN_ERROR(
    lib != nil,
    Internal,
    "Failed to create indexing library, error: %s",
    // Messaging nil returns NULL; never hand NULL to %s.
    error != nil ? [[error description] UTF8String] : "unknown error"
  );

  _m_library_cache[libraryType] = lib;
  return err;
}

// Ensure the library for `libraryType` is compiled (compiling and caching it
// on first use), then compile and cache the pipeline state for `kernelName`.
// NOTE(review): PSO creation itself is not implemented yet — only the library
// is compiled; _m_pso_cache is never populated here.
Error
MPSDevice::compilePSO(LibraryType libraryType, const char* kernelName) {
  Error err = Error::Ok;
  if (_m_library_cache.find(libraryType) == _m_library_cache.end()) {
    // Cast the scoped enum explicitly: passing it straight through varargs
    // with %d is a type mismatch.
    ET_LOG(Debug, "Compiling library type: %d", static_cast<int>(libraryType));
    err = compileLibrary(libraryType);
    ET_CHECK_OR_RETURN_ERROR(
      err == Error::Ok,
      Internal,
      "An error occurred while compiling library %d",
      static_cast<int>(libraryType)
    );
  }
  if (_m_pso_cache.find(kernelName) == _m_pso_cache.end()) {
    ET_LOG(Debug, "Compiling kernel: %s", kernelName);
    // TODO(mps): build the MTLComputePipelineState for `kernelName` from the
    // cached library and insert it into _m_pso_cache. (The previously
    // commented-out line here was a self-recursive call and would not have
    // created a PSO.)
  }
  return err;
}

// Free-function convenience wrapper over the MPSDevice singleton's
// isMacOS13Plus() OS-version check.
bool isMacOS13OrNewer(MacOSVersion version) {
  return MPSDevice::getInstance()->isMacOS13Plus(version);
}
Expand Down
Loading

0 comments on commit 02a6b66

Please sign in to comment.