From 02a6b66c2cfea96787e41aba8b12f343bd970322 Mon Sep 17 00:00:00 2001 From: Denis Vieriu Date: Tue, 23 Apr 2024 18:21:53 -0700 Subject: [PATCH] Add index.Tensor and aten.logical_not (#3221) Summary: Add missing llama ops for MPS delegate: - `index.Tensor` - `logical_not` `index.put` works correctly for generating 1 token, but gives incorrect results on 2nd token. This remains disabled. Summary of changes: - Adds missing llama2 ops - Adds support for launching Metal kernels instead of MPSGraph ops (if MPSGraph doesn't have the support) cc cccclai , shoumikhin Pull Request resolved: https://github.com/pytorch/executorch/pull/3221 Reviewed By: shoumikhin Differential Revision: D56447710 Pulled By: cccclai fbshipit-source-id: 778a485df5e67d1afd006b42f07b69c8a3961223 --- backends/apple/mps/mps_preprocess.py | 12 +++ backends/apple/mps/operators/indexing_ops.py | 77 ++++++++++++++- backends/apple/mps/operators/unary_ops.py | 3 + .../apple/mps/partition/mps_partitioner.py | 50 +++++++++- backends/apple/mps/runtime/MPSDevice.h | 23 +++++ backends/apple/mps/runtime/MPSDevice.mm | 65 +++++++++++++ backends/apple/mps/runtime/MPSGraphBuilder.h | 7 ++ backends/apple/mps/runtime/MPSGraphBuilder.mm | 53 +++++++++- .../mps/runtime/operations/IndexingOps.mm | 96 +++++++++++++++++++ .../mps/runtime/operations/OperationUtils.mm | 8 ++ .../apple/mps/runtime/operations/ShapeOps.mm | 8 +- .../apple/mps/runtime/operations/UnaryOps.mm | 1 + .../mps/serialization/mps_graph_schema.py | 26 +++++ backends/apple/mps/serialization/schema.fbs | 26 +++++ backends/apple/mps/targets.bzl | 1 + 15 files changed, 446 insertions(+), 10 deletions(-) diff --git a/backends/apple/mps/mps_preprocess.py b/backends/apple/mps/mps_preprocess.py index 0e543d7e07..bb828ed0f9 100644 --- a/backends/apple/mps/mps_preprocess.py +++ b/backends/apple/mps/mps_preprocess.py @@ -18,6 +18,7 @@ from executorch.backends.apple.mps.serialization.mps_graph_schema import ( MPSGraph, MPSTensor, + OpType, ) from 
executorch.backends.apple.mps.serialization.mps_graph_serialize import ( @@ -65,6 +66,7 @@ def preprocess( input_ids=[], output_ids=[], constant_ids=[], + graph_type=OpType.mps_graph, ) convert_model_to_fp16 = True @@ -111,6 +113,16 @@ def handle_call_function( mps_graph: MPSGraph, ) -> None: logging.info(f"Visiting: {node}, {node.target.__name__}") + + if ( + "delegation_tag" in node.meta + and "metal_kernel" in node.meta["delegation_tag"] + ): + logging.info( + f"Node '{node.target.__name__}' was marked as a Metal kernel by the MPSPartitioner!" + ) + mps_graph.graph_type = OpType.metal_kernel + if node.target.__name__ in node_visitors: node_visitors[node.target.__name__].define_node(node, mps_graph) else: diff --git a/backends/apple/mps/operators/indexing_ops.py b/backends/apple/mps/operators/indexing_ops.py index f2c9dc6aea..690549973a 100644 --- a/backends/apple/mps/operators/indexing_ops.py +++ b/backends/apple/mps/operators/indexing_ops.py @@ -3,7 +3,7 @@ # Provided subject to the LICENSE file in the top level directory. 
# -from typing import cast +from typing import cast, List import torch from executorch.backends.apple.mps.operators.node_visitor import ( @@ -13,9 +13,12 @@ from executorch.backends.apple.mps.serialization.mps_graph_schema import ( MPSEmbedding, MPSGraph, + MPSIndexPut, MPSIndexSelect, + MPSIndexTensor, ) from executorch.backends.apple.mps.utils.mps_utils import get_input_node +from executorch.backends.transforms import get_shape from executorch.exir.sym_util import eval_expr @@ -40,6 +43,78 @@ def define_node( mps_graph.mps_nodes.append(mps_node) +@register_node_visitor +class IndexTensorVisitor(NodeVisitor): + target = "aten.index.Tensor" + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + mps_node = self.create_unary_node(node, mps_graph, MPSIndexTensor) + tensors = cast(List[torch.fx.Node], node.args[1]) + for tensor in tensors: + mps_node.mpsnode_union.indices_id.append( + self.define_tensor(tensor, mps_graph) + ) + + mps_graph.mps_nodes.append(mps_node) + + +# [MPS TODO]: Works on a single iteration of llama2, but subsequent tokens +# are wrong when using Index put. Disabling it for now. 
+@register_node_visitor +class IndexPutVisitor(NodeVisitor): + # target = "aten.index_put.default" + target = "disabled" + + def __init__(self, *args) -> None: + super().__init__(*args) + + def infer_sizes(self, a: List[int], b: List[int]): + dimsA = len(a) + dimsB = len(b) + ndim = dimsA if dimsA > dimsB else dimsB + expandedSizes = [0] * ndim + for i in range(ndim - 1, -1, -1): + offset = ndim - 1 - i + dimA = dimsA - 1 - offset + dimB = dimsB - 1 - offset + sizeA = a[dimA] if dimA >= 0 else -1 + sizeB = b[dimB] if dimB >= 0 else -1 + expandedSizes[i] = sizeA if sizeB == -1 else sizeB + + return expandedSizes + + def define_node( + self, + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + mps_node = self.create_unary_node(node, mps_graph, MPSIndexPut) + updates_shape = get_shape(node.args[2]) + input_shape = get_shape(node.args[0]) + new_shape = [] + if len(updates_shape) != 1 and len(updates_shape) != len(input_shape): + new_shape = self.infer_sizes(input_shape, updates_shape) + mps_node.mpsnode_union.values_shape = new_shape + + tensors = cast(List[torch.fx.Node], node.args[1]) + for tensor in tensors: + mps_node.mpsnode_union.indices_id.append( + self.define_tensor(tensor, mps_graph) + ) + + mps_node.mpsnode_union.values_id = self.define_tensor( + get_input_node(node, 2), mps_graph + ) + mps_graph.mps_nodes.append(mps_node) + + @register_node_visitor class EmbeddingVisitor(NodeVisitor): target = "aten.embedding.default" diff --git a/backends/apple/mps/operators/unary_ops.py b/backends/apple/mps/operators/unary_ops.py index 411924d040..8b67d7dfba 100644 --- a/backends/apple/mps/operators/unary_ops.py +++ b/backends/apple/mps/operators/unary_ops.py @@ -30,6 +30,7 @@ MPSLog, MPSLog10, MPSLog2, + MPSLogicalNot, MPSNeg, MPSReciprocal, MPSRound, @@ -79,6 +80,7 @@ class UnaryOpVisitor(NodeVisitor): "aten.isnan.default", "aten.isinf.default", "aten.round.default", + "aten.logical_not.default", ] def __init__(self, *args) -> None: @@ -115,6 +117,7 @@ def 
__init__(self, *args) -> None: exir_ops.edge.aten.isnan.default: MPSIsnan, exir_ops.edge.aten.isinf.default: MPSIsinf, exir_ops.edge.aten.round.default: MPSRound, + exir_ops.edge.aten.logical_not.default: MPSLogicalNot, } def define_node( diff --git a/backends/apple/mps/partition/mps_partitioner.py b/backends/apple/mps/partition/mps_partitioner.py index a06677a59a..3dfc73cdd9 100644 --- a/backends/apple/mps/partition/mps_partitioner.py +++ b/backends/apple/mps/partition/mps_partitioner.py @@ -4,12 +4,13 @@ # import logging -from typing import Any, Dict, List, Union +from typing import Any, cast, Dict, List, Union import torch from executorch.backends.apple.mps.mps_preprocess import MPSBackend from executorch.backends.apple.mps.operators.node_visitor import get_node_visitors from executorch.backends.apple.mps.utils.mps_utils import is_parameter +from executorch.backends.transforms import get_shape from executorch.exir.backend.backend_details import CompileSpec from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import ( generate_partitions_from_list_of_nodes, @@ -20,6 +21,7 @@ PartitionResult, ) from executorch.exir.backend.utils import tag_constant_data +from executorch.exir.dialects._ops import ops as exir_ops from torch.export.exported_program import ExportedProgram from torch.fx.passes.infra.partitioner import Partition from torch.fx.passes.operator_support import OperatorSupportBase @@ -28,6 +30,13 @@ logging.basicConfig(level=logging.DEBUG, format=FORMAT) +# ops implemented as Metal kernels. 
+METAL_KERNELS = [ + exir_ops.edge.aten.index.Tensor, + exir_ops.edge.aten.index_put.default, +] + + class MPSOperatorSupport(OperatorSupportBase): def __init__(self, edge_program: torch.export.ExportedProgram, compiler_specs): self.node_visitors = get_node_visitors(edge_program) @@ -65,10 +74,47 @@ def generate_partitions(self, edge_program: ExportedProgram) -> List[Any]: op_support=self.supported_ops, ) + def mps_graph_advanced_indexing_support(self, node: torch.fx.Node): + num_indices = 0 + tensors = cast(List[torch.fx.Node], node.args[1]) + input = cast(torch.fx.Node, node.args[0]) + for t in tensors: + if t is not None: + num_indices += 1 + # Can dispatch to MPSGraph if the length of the slices is equal + # to the number of dimensions of the sliced tensors, or only one + # slice is present. All other cases will fallback to a Metal kernel. + if num_indices == len(get_shape(input)) or num_indices == 1: + return True + + return False + + def use_metal_kernel(self, node: torch.fx.Node): + if node.target in METAL_KERNELS: + if ( + node.target == exir_ops.edge.aten.index.Tensor + or node.target == exir_ops.edge.aten.index_put.default + ): + if not self.mps_graph_advanced_indexing_support(node): + return True + return False + def tag_nodes(self, partitions: List[Partition]) -> None: for partition in partitions: - for node in partition.nodes: + crt_partition_counter = 0 + for node in sorted(partition.nodes): delegation_tag = f"mps_{partition.id}" + if self.use_metal_kernel(node): + logging.warning(f"[WARNING] Using Metal kernel for op {node.name}!") + # Partition the Metal kernel into a separate partition + crt_partition_counter += 1 + delegation_tag = ( + f"{delegation_tag}_metal_kernel_{crt_partition_counter}" + ) + crt_partition_counter += 1 + else: + delegation_tag = f"{delegation_tag}_{crt_partition_counter}" + node.meta["delegation_tag"] = delegation_tag self.partition_tags[delegation_tag] = self.delegation_spec diff --git 
a/backends/apple/mps/runtime/MPSDevice.h b/backends/apple/mps/runtime/MPSDevice.h index d9ab403e80..a8b5dbe2b8 100644 --- a/backends/apple/mps/runtime/MPSDevice.h +++ b/backends/apple/mps/runtime/MPSDevice.h @@ -5,10 +5,19 @@ #pragma once +// Obj-C headers #include #include + +// Runtime headers +#include + +// MPS headers #include +#include +#include + #define MB(x) (x * 1048576UL) namespace torch { @@ -25,6 +34,11 @@ enum class MacOSVersion : uint32_t { MACOS_VER_14_0_PLUS, }; +enum class LibraryType : uint32_t { + INDEXING_KERNELS = 0, + MAX = INDEXING_KERNELS, +}; + class MPSDevice { public: /** @@ -53,9 +67,18 @@ class MPSDevice { ~MPSDevice(); + /** + * Compile a PSO for a given library type. + * Once compiled, the library and PSOs are cached. + */ + Error compilePSO(LibraryType libraryType, const char* kernelName); + Error compileLibrary(LibraryType); + private: static MPSDevice* _device; id _mtl_device; + std::unordered_map> _m_library_cache; + std::unordered_map> _m_pso_cache; MPSDevice(); }; diff --git a/backends/apple/mps/runtime/MPSDevice.mm b/backends/apple/mps/runtime/MPSDevice.mm index 86518fd002..f51851c379 100644 --- a/backends/apple/mps/runtime/MPSDevice.mm +++ b/backends/apple/mps/runtime/MPSDevice.mm @@ -16,6 +16,20 @@ static std::unique_ptr mps_device; static std::once_flag mpsdev_init; +static inline MTLLanguageVersion getMetalLanguageVersion(const id& device, bool macOS13Plus) { + // MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants) + // host_name attribute needs at least Metal 2.2 and ulong needs Metal 2.3 (supported on MacOS 11+) + MTLLanguageVersion languageVersion = MTLLanguageVersion2_3; +#if defined(__MAC_13_0) + if (macOS13Plus) { + languageVersion = MTLLanguageVersion3_0; + } +#endif + + ET_CHECK_MSG([device supportsFamily:MTLGPUFamilyMac2], "Missing Metal support for MTLGPUFamilyMac2"); + return languageVersion; +} + MPSDevice::~MPSDevice() { [_mtl_device release]; _mtl_device = nil; 
@@ -79,6 +93,57 @@ } } +const char* getLibraryCString(LibraryType libraryType) { + switch (libraryType) { + case LibraryType::INDEXING_KERNELS: + return "TODO"; + default: + ET_CHECK_MSG(false, "Unhandled library type!"); + } +} + +Error +MPSDevice::compileLibrary(LibraryType libraryType) { + Error err = Error::Ok; + NSError* error = nil; + MTLCompileOptions* options = [MTLCompileOptions new]; + [options setLanguageVersion:getMetalLanguageVersion(_mtl_device, isMacOS13Plus(MacOSVersion::MACOS_VER_13_0_PLUS))]; + [options setFastMathEnabled:YES]; + id lib = + [_mtl_device newLibraryWithSource:[NSString stringWithCString:getLibraryCString(libraryType) + encoding:NSASCIIStringEncoding] + options:options + error:&error]; + + ET_CHECK_OR_RETURN_ERROR( + lib != nil, + Internal, + "Failed to create indexing library, error: %s", [[error description] UTF8String] + ); + + _m_library_cache[libraryType] = lib; + return err; +} + +Error +MPSDevice::compilePSO(LibraryType libraryType, const char* kernelName) { + Error err = Error::Ok; + if (_m_library_cache.find(libraryType) == _m_library_cache.end()) { + ET_LOG(Debug, "Compiling library type: %d", libraryType); + err = compileLibrary(libraryType); + ET_CHECK_OR_RETURN_ERROR( + err == Error::Ok, + Internal, + "An error occured occured while compiling library %d", libraryType + ); + } + if (_m_pso_cache.find(kernelName) == _m_pso_cache.end()) { + ET_LOG(Debug, "Compiling kernel: %s", kernelName); + // err = compilePSO(libraryType, kernelName); + } + return err; +} + bool isMacOS13OrNewer(MacOSVersion version) { return MPSDevice::getInstance()->isMacOS13Plus(version); } diff --git a/backends/apple/mps/runtime/MPSGraphBuilder.h b/backends/apple/mps/runtime/MPSGraphBuilder.h index 0a7bf835a7..e4e89d6869 100644 --- a/backends/apple/mps/runtime/MPSGraphBuilder.h +++ b/backends/apple/mps/runtime/MPSGraphBuilder.h @@ -109,6 +109,7 @@ class MPSGraphBuilder { _DEFINE_MPS_OP(Isnan); _DEFINE_MPS_OP(Isinf); _DEFINE_MPS_OP(Round); + 
_DEFINE_MPS_OP(LogicalNot); _DEFINE_MPS_OP(NormCdf); // Clamp ops _DEFINE_MPS_OP(Clamp); @@ -120,6 +121,8 @@ class MPSGraphBuilder { // Indexing ops _DEFINE_MPS_OP(IndexSelect); _DEFINE_MPS_OP(Embedding); + _DEFINE_MPS_OP(IndexTensor); + _DEFINE_MPS_OP(IndexPut); // Linear algebra ops _DEFINE_MPS_OP(MatMul); _DEFINE_MPS_OP(Addmm); @@ -153,6 +156,7 @@ class MPSGraphBuilder { // Helper functions Error addNodeToMPSGraph(NodePtr nodePtr); + Error compileMetalKernel(NodePtr nodePtr); MPSShape *getMPSShape(int32_t id); MPSShape *getMPSShape(const flatbuffers::Vector *shape); int64_t numel(const flatbuffers::Vector *shape); @@ -161,6 +165,8 @@ class MPSGraphBuilder { MPSGraphTensor *getMPSGraphTensor(int32_t id); NSData *getConstantData(int32_t id); std::pair getMinMaxValues(NodePtr nodePtr); + Error compileMPSGraph(); + Error compileMetalKernel(); // Each MPSGraph op result in at least MPSGraphTensor being // produced, which will be stored in this structure. Other ops @@ -172,6 +178,7 @@ class MPSGraphBuilder { // FlatBuffer raw bytes of the serialized MPS model. 
const void *_buffer_pointer; + bool _metal_kernel; MPSGraph *_mpsGraph; MPSGraphExecutable *_mpsGraphExecutable; NSMutableDictionary *_feeds; diff --git a/backends/apple/mps/runtime/MPSGraphBuilder.mm b/backends/apple/mps/runtime/MPSGraphBuilder.mm index d82b677066..8b571001d4 100644 --- a/backends/apple/mps/runtime/MPSGraphBuilder.mm +++ b/backends/apple/mps/runtime/MPSGraphBuilder.mm @@ -17,6 +17,7 @@ _targetTensors = [NSMutableArray new]; _mpsGraphExecutable = nil; + _metal_kernel = false; } Error @@ -32,8 +33,34 @@ mpsgraph::MPSGraphIdentifier()); _flatBufferGraph = mpsgraph::GetMPSGraph(_buffer_pointer); - _idToMPSGraphTensor.resize(_flatBufferGraph->mps_values()->size(), nullptr); + switch (_flatBufferGraph->graph_type()) { + case mpsgraph::OpType::metal_kernel: + { + _metal_kernel = true; + err = compileMetalKernel(); + break; + } + case mpsgraph::OpType::mps_graph: + { + err = compileMPSGraph(); + break; + } + default: + ET_CHECK_OR_RETURN_ERROR( + false, + DelegateInvalidCompatibility, + "Received an invalid operation type: expected MPSGraph or metal kernel, but got: %s", + EnumNameOpType(_flatBufferGraph->graph_type())); + } + + return err; +} +Error +MPSGraphBuilder::compileMPSGraph() { + Error err = Error::Ok; + + _idToMPSGraphTensor.resize(_flatBufferGraph->mps_values()->size(), nullptr); // Add the placeholder nodes to the graph. 
for (auto in_id : *_flatBufferGraph->input_ids()) { err = mpsGraphRankedPlaceholder(in_id); @@ -71,6 +98,30 @@ return err; } +Error +MPSGraphBuilder::compileMetalKernel() { + Error err = Error::Ok; + + ET_CHECK_OR_RETURN_ERROR( + _flatBufferGraph->mps_nodes()->size() == 1, + DelegateInvalidCompatibility, + "Currently supporting dispatching a single Metal kernel."); + ET_CHECK_OR_RETURN_ERROR( + _flatBufferGraph->constant_ids()->size() == 0, + DelegateInvalidCompatibility, + "Currently not supporting dispatching Metal kernels with constants."); + + // Compile the corresponding Metal kernel + for (auto node : *_flatBufferGraph->mps_nodes()) { + err = compileMetalKernel(node); + if (err != Error::Ok) { + return err; + } + } + + return err; +} + Error MPSGraphBuilder::mpsGraphRankedPlaceholder(int32_t id) { ET_LOG(Debug, "%s: %d", __FUNCTION__, id); diff --git a/backends/apple/mps/runtime/operations/IndexingOps.mm b/backends/apple/mps/runtime/operations/IndexingOps.mm index 1c02cbea5c..b4dcf192b4 100644 --- a/backends/apple/mps/runtime/operations/IndexingOps.mm +++ b/backends/apple/mps/runtime/operations/IndexingOps.mm @@ -108,6 +108,102 @@ return Error::Ok; } +Error +MPSGraphBuilder::mpsIndexTensorOp(NodePtr nodePtr) { + Error err = Error::Ok; + auto graphNode = nodePtr->mpsnode_union_as_MPSIndexTensor(); + ET_LOG( + Debug, "%s: %d -> %d", + __FUNCTION__, graphNode->input1_id(), graphNode->output_id() + ); + + if (_metal_kernel) { + err = MPSDevice::getInstance()->compilePSO(LibraryType::INDEXING_KERNELS, "index_select"); + ET_CHECK_MSG(false, "Metal kernel path not yet implemented\n"); + } else { + int validIndices = 0; + int numIndices = graphNode->indices_id()->size(); + int axis = -1; + int indexId = -1; + for (int i = 0; i < numIndices; i++) { + int32_t index_id = graphNode->indices_id()->Get(i); + if (index_id == -1) { + continue; + } + validIndices++; + axis = i; + indexId = index_id; + } + ET_LOG(Debug, "index.Tensor with %d indices (axis = %d)", validIndices, 
axis); + ET_CHECK(validIndices > 0); + + if (validIndices == 1) { + MPSGraphTensor* updatesTensor = getMPSGraphTensor(graphNode->input1_id()); + MPSGraphTensor* indexTensor = getMPSGraphTensor(indexId); + _idToMPSGraphTensor[graphNode->output_id()] = + [_mpsGraph gatherWithUpdatesTensor:updatesTensor indicesTensor:indexTensor axis:axis batchDimensions:0 name:nil]; + } else { + ET_CHECK_MSG(false, "Not yet implemented"); + } + } + + return err; +} + +Error +MPSGraphBuilder::mpsIndexPutOp(NodePtr nodePtr) { + Error err = Error::Ok; + auto graphNode = nodePtr->mpsnode_union_as_MPSIndexPut(); + ET_LOG( + Debug, "%s: %d -> %d", + __FUNCTION__, graphNode->input1_id(), graphNode->output_id() + ); + + if (_metal_kernel) { + err = MPSDevice::getInstance()->compilePSO(LibraryType::INDEXING_KERNELS, "index_put"); + ET_CHECK_MSG(false, "Metal kernel path not yet implemented\n"); + } else { + int validIndices = 0; + int numIndices = graphNode->indices_id()->size(); + int axis = -1; + int indexId = -1; + for (int i = 0; i < numIndices; i++) { + int32_t index_id = graphNode->indices_id()->Get(i); + if (index_id == -1) { + continue; + } + validIndices++; + axis = i; + indexId = index_id; + } + ET_LOG(Debug, "index_put with %d indices (axis = %d)", validIndices, axis); + ET_CHECK(validIndices > 0); + + if (validIndices == 1) { + MPSGraphTensor* dataTensor = getMPSGraphTensor(graphNode->input1_id()); + MPSGraphTensor* updatesTensor = getMPSGraphTensor(graphNode->values_id()); + MPSGraphTensor* indicesTensor = getMPSGraphTensor(indexId); + if (graphNode->values_shape()->size() != 0) { + updatesTensor = [_mpsGraph broadcastTensor:updatesTensor + toShape:getMPSShape(graphNode->values_shape()) + name:nil]; + } + + _idToMPSGraphTensor[graphNode->output_id()] = + [_mpsGraph scatterWithDataTensor:dataTensor + updatesTensor:updatesTensor + indicesTensor:indicesTensor + axis:axis + mode:MPSGraphScatterModeSet + name:nil]; + } else { + ET_CHECK_MSG(false, "Not yet implemented"); + } + } + + 
return err; +} + } // namespace delegate } // namespace mps } // namespace executor diff --git a/backends/apple/mps/runtime/operations/OperationUtils.mm b/backends/apple/mps/runtime/operations/OperationUtils.mm index 71c36c967e..648421ee2c 100644 --- a/backends/apple/mps/runtime/operations/OperationUtils.mm +++ b/backends/apple/mps/runtime/operations/OperationUtils.mm @@ -166,6 +166,7 @@ _DEFINE_MPS_NODE(Isnan); _DEFINE_MPS_NODE(Isinf); _DEFINE_MPS_NODE(Round); + _DEFINE_MPS_NODE(LogicalNot); // Clamp ops _DEFINE_MPS_NODE(Clamp); _DEFINE_MPS_NODE(Where); @@ -178,6 +179,8 @@ //Indexing ops _DEFINE_MPS_NODE(IndexSelect); _DEFINE_MPS_NODE(Embedding); + _DEFINE_MPS_NODE(IndexTensor); + _DEFINE_MPS_NODE(IndexPut); // Reduce ops _DEFINE_MPS_NODE(Mean); // Shape ops @@ -223,6 +226,11 @@ } } +Error +MPSGraphBuilder::compileMetalKernel(NodePtr nodePtr) { + return addNodeToMPSGraph(nodePtr); +} + #undef _DEFINE_MPS_NODE MPSGraphTensor* diff --git a/backends/apple/mps/runtime/operations/ShapeOps.mm b/backends/apple/mps/runtime/operations/ShapeOps.mm index 720161b955..75de566e4a 100644 --- a/backends/apple/mps/runtime/operations/ShapeOps.mm +++ b/backends/apple/mps/runtime/operations/ShapeOps.mm @@ -42,13 +42,9 @@ __FUNCTION__, graphNode->input1_id(), graphNode->output_id() ); - NSMutableArray* shape = [NSMutableArray array]; - for (int32_t i = 0; i < graphNode->num_dims(); i++) { - [shape addObject:[NSNumber numberWithInteger:graphNode->shape()->Get(i)]]; - } _idToMPSGraphTensor[graphNode->output_id()] = [_mpsGraph reshapeTensor:getMPSGraphTensor(graphNode->input1_id()) - withShape:shape + withShape:getMPSShape(graphNode->shape()) name:@"view_copy"]; return Error::Ok; @@ -91,7 +87,7 @@ __FUNCTION__, graphNode->output_id() ); - NSMutableArray* inputTensors = [NSMutableArray array]; + NSMutableArray* inputTensors = [NSMutableArray arrayWithCapacity:graphNode->input_ids()->size()];; for (auto id : *graphNode->input_ids()) { MPSGraphTensor* catTensor = getMPSGraphTensor(id); if 
(catTensor != nil) diff --git a/backends/apple/mps/runtime/operations/UnaryOps.mm b/backends/apple/mps/runtime/operations/UnaryOps.mm index 31246bd44f..ed06584b27 100644 --- a/backends/apple/mps/runtime/operations/UnaryOps.mm +++ b/backends/apple/mps/runtime/operations/UnaryOps.mm @@ -92,6 +92,7 @@ REGISTER_UNARY_OP(Isnan, isNaN) REGISTER_UNARY_OP(Isinf, isInfinite) REGISTER_UNARY_OP(Round, round) +REGISTER_UNARY_OP(LogicalNot, not) Error diff --git a/backends/apple/mps/serialization/mps_graph_schema.py b/backends/apple/mps/serialization/mps_graph_schema.py index 66697b04b7..8134091a01 100644 --- a/backends/apple/mps/serialization/mps_graph_schema.py +++ b/backends/apple/mps/serialization/mps_graph_schema.py @@ -27,6 +27,11 @@ class MPSDataType(IntEnum): mps_data_type_complex_float32 = 11 +class OpType(IntEnum): + mps_graph = 0 + metal_kernel = 1 + + @dataclass class MPSNode1x1: input1_id: int @@ -359,6 +364,11 @@ class MPSRound(MPSNode1x1): pass +@dataclass +class MPSLogicalNot(MPSNode1x1): + pass + + @dataclass class MPSBitwise(MPSNode1x1): pass @@ -434,6 +444,18 @@ class MPSEmbedding(MPSNode2x1): sparse: bool = False +@dataclass +class MPSIndexTensor(MPSNode1x1): + indices_id: List[int] = field(default_factory=list) + + +@dataclass +class MPSIndexPut(MPSNode1x1): + indices_id: List[int] = field(default_factory=list) + values_shape: List[int] = field(default_factory=list) + values_id: int = -1 + + ## ## Shape ops ## @@ -664,6 +686,7 @@ class MPSArange: MPSIsnan, MPSIsinf, MPSRound, + MPSLogicalNot, # Linear algebra ops MPSMatMul, MPSAddmm, @@ -678,6 +701,8 @@ class MPSArange: # Indexing ops MPSIndexSelect, MPSEmbedding, + MPSIndexTensor, + MPSIndexPut, # Shape ops MPSPermute, MPSView, @@ -741,3 +766,4 @@ class MPSGraph: input_ids: List[int] output_ids: List[int] constant_ids: List[int] + graph_type: OpType diff --git a/backends/apple/mps/serialization/schema.fbs b/backends/apple/mps/serialization/schema.fbs index c3e3eaa4fa..6ba2c937f3 100644 --- 
a/backends/apple/mps/serialization/schema.fbs +++ b/backends/apple/mps/serialization/schema.fbs @@ -24,6 +24,13 @@ enum MPSDataType : short { mps_data_type_complex_float32 = 11, } +// ops like index.Tensor and index.put are currently implemented as +// Metal kernels for unsupported MPSGraph cases. +enum OpType : short { + mps_graph, + metal_kernel +} + // Helper classes to define the number of input and output tensors for a node. // Not meant to be used directly. @@ -145,6 +152,20 @@ table MPSEmbedding { sparse:bool; } +table MPSIndexTensor { + input1_id:int; + indices_id:[int]; + output_id:int; +} + +table MPSIndexPut { + input1_id:int; + indices_id:[int]; + values_shape:[int]; + values_id:int; + output_id:int; +} + // Shape ops. table MPSPermute { input1_id:int; @@ -350,6 +371,7 @@ union MPSNodeUnion { MPSIsnan: _MPSNode1x1, MPSIsinf: _MPSNode1x1, MPSRound: _MPSNode1x1, + MPSLogicalNot: _MPSNode1x1, // Linear algebra ops MPSMatMul: _MPSNode2x1, @@ -366,6 +388,8 @@ union MPSNodeUnion { // Indexing ops MPSIndexSelect, MPSEmbedding, + MPSIndexTensor, + MPSIndexPut, // Reduce ops MPSMean, @@ -438,6 +462,8 @@ table MPSGraph { input_ids:[int]; output_ids:[int]; constant_ids:[int]; + + graph_type:OpType; } root_type MPSGraph; diff --git a/backends/apple/mps/targets.bzl b/backends/apple/mps/targets.bzl index 94f030310d..4d2862eb72 100644 --- a/backends/apple/mps/targets.bzl +++ b/backends/apple/mps/targets.bzl @@ -22,6 +22,7 @@ def define_common_targets(is_xplat = False, platforms = []): "-Wno-unused-const-variable", "-Wno-unused-variable", "-fno-objc-arc", + "-std=c++17", ], "deps": [ "//executorch/runtime/core:core",