diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py index 67d77d4b47..ba435e62bb 100644 --- a/backends/xnnpack/operators/node_visitor.py +++ b/backends/xnnpack/operators/node_visitor.py @@ -17,7 +17,7 @@ ) from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( - Buffer, + ConstantDataOffset, PerChannelQuant, PerTensorQuant, PerTokenDynamicQuant, @@ -43,6 +43,12 @@ torch.float32: XNNDatatype.xnn_datatype_fp32, } +from executorch.backends.xnnpack.serialization.xnnpack_graph_serialize import ( + _aligned_size, + _pad_to, + CONSTANT_TENSOR_ALIGNMENT, +) + class InputTypeToIndex: """ @@ -78,9 +84,11 @@ def __init__( self, exported_program: ExportedProgram, external_ids: Dict, + constant_data_bytes: bytearray, ) -> None: self._external_ids = external_ids or {} self._exported_program = exported_program or None + self._constant_data_bytes = constant_data_bytes @property def external_ids(self) -> Dict: @@ -317,7 +325,7 @@ def define_tensor( dims = [1] if len(dims) == 0 else dims # constant values serialize data - buffer_idx = self.get_serialized_buffer( + buffer_idx = self.get_serialized_buffer_index( tensor, xnn_graph, vals_to_ids, @@ -426,7 +434,7 @@ def convert_to_qc4w(inp: torch.Tensor) -> torch.Tensor: return result - def get_serialized_buffer( + def get_serialized_buffer_index( self, tensor: torch.fx.Node, xnn_graph: XNNGraph, @@ -469,11 +477,7 @@ def get_serialized_buffer( ) return 0 - check_or_raise( - len(xnn_graph.constant_buffer) == len(xnn_graph.mem_buffer_sizes), - "Internal Error: const_buffer and buffer_sizes length mismatch", - ) - buffer_idx = len(xnn_graph.constant_buffer) + buffer_idx = len(xnn_graph.constant_data) const_val = get_param_tensor(self.exported_program, get_attr_node) assert const_val is not None and isinstance(const_val, torch.Tensor) const_val = const_val.contiguous() @@ -501,9 +505,13 @@ def get_serialized_buffer( const_val.untyped_storage().data_ptr(), ctypes.POINTER(array_type), ).contents - buffer = Buffer(storage=bytes(array)) - xnn_graph.constant_buffer.append(buffer) - xnn_graph.mem_buffer_sizes.append(const_val.untyped_storage().nbytes()) + + offset = len(self._constant_data_bytes) + size = const_val.untyped_storage().nbytes() + xnn_graph.constant_data.append(ConstantDataOffset(offset=offset, size=size)) + self._constant_data_bytes.extend( + _pad_to(bytes(array), _aligned_size(size, CONSTANT_TENSOR_ALIGNMENT)) + ) return buffer_idx diff --git a/backends/xnnpack/serialization/schema.fbs b/backends/xnnpack/serialization/schema.fbs index 8a1b46fdd0..f43dd95048 100644 --- a/backends/xnnpack/serialization/schema.fbs +++ b/backends/xnnpack/serialization/schema.fbs @@ -36,10 +36,9 @@ union XNNQuantParams { PerTokenDynamicQuant, } -// taken from executorch -// Data buffer abstraction. +// Deprecated buffer abstraction, const data buffers do not belong in flatbuffer table Buffer { - storage:[ubyte] (force_align: 16); + storage:[ubyte] (deprecated, force_align: 16); } table PerChannelQuant { @@ -324,18 +323,14 @@ table XNNGraph { // Ids of external outputs output_ids:[uint]; - // Tables of constant data, used for constant Values (e.g. - // data field of weight tensors). Each constant is assigned an index into the table - // which are each individually aligned. 0 index is reserved to be pointed to by non-constant - // Tensors. Exactly one of constant_buffer and constant_data must be non-empty - constant_buffer:[Buffer]; + // Deprecated constant buffer storage in flatbuffer + constant_buffer:[Buffer] (deprecated); - // the list index is memory buffer id, the value is the memory buffer size. - mem_buffer_sizes: [uint]; + // Deprecated memory_buffer size tracking in flatbuffer + mem_buffer_sizes: [uint] (deprecated); // List of the constant data that follows the XNNGraph in this file. Each constant data is assigned an index into - // the table. 0 index is reserved to be pointed to by non-constant Tensor. Exactly one of constant_buffer and - // constant_data must be non-empty + // the table. 0 index is reserved to be pointed to by non-constant Tensor. constant_data:[ConstantDataOffset]; } diff --git a/backends/xnnpack/serialization/xnnpack_graph_schema.py b/backends/xnnpack/serialization/xnnpack_graph_schema.py index b164a96194..3222202ea8 100644 --- a/backends/xnnpack/serialization/xnnpack_graph_schema.py +++ b/backends/xnnpack/serialization/xnnpack_graph_schema.py @@ -431,11 +431,6 @@ class XValue: xvalue_union: "XValueUnion" -@dataclass -class Buffer: - storage: bytes - - @dataclass class ConstantDataOffset: offset: int @@ -452,7 +447,4 @@ class XNNGraph: input_ids: List[int] output_ids: List[int] - constant_buffer: List[Buffer] - mem_buffer_sizes: List[int] - constant_data: List[ConstantDataOffset] diff --git a/backends/xnnpack/serialization/xnnpack_graph_serialize.py b/backends/xnnpack/serialization/xnnpack_graph_serialize.py index f8a7ae77c6..160c926780 100644 --- a/backends/xnnpack/serialization/xnnpack_graph_serialize.py +++ b/backends/xnnpack/serialization/xnnpack_graph_serialize.py @@ -9,14 +9,10 @@ import tempfile from dataclasses import dataclass, fields, is_dataclass -from typing import ClassVar, List, Literal, Tuple +from typing import ClassVar, Literal import pkg_resources -from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( - Buffer, - ConstantDataOffset, - XNNGraph, -) +from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import XNNGraph from executorch.exir._serialize._dataclass import _DataclassEncoder from executorch.exir._serialize._flatbuffer import _flatc_compile @@ -26,6 +22,9 @@ # endian. _HEADER_BYTEORDER: Literal["little"] = "little" +# Constant Tensor alignment for serializaing XNNPACK payloads +CONSTANT_TENSOR_ALIGNMENT = 16 + def sanity_check_xnngraph_dataclass(table, name: str = ""): """ @@ -274,40 +273,6 @@ def _pad_to(data: bytes, length: int) -> bytes: return data -def _extract_constant_data( - constant_buffer: List[Buffer], - tensor_alignment: int = 16, -) -> Tuple[bytes, List[int]]: - """Copies the tensors from the provided list into a single buffer and tracks the offsets - of each tensor. - - constant_buffer: list of Buffers from which to extract constants from. Not modified. - tensor_alignment: Alignment in bytes. The starting offset of each tensor in the - constant segment will be aligned to this value. Default to 16. - - Returns: - A tuple of (constant segment, list of offsets for each tensor in the segment) - """ - constant_segment_data: bytearray = bytearray() - constant_segment_offsets: List[int] = [] - current_offset: int = 0 - for i in range(len(constant_buffer)): - buffer = constant_buffer[i] - buffer_length = len(buffer.storage) - pad_length = _padding_required(buffer_length, tensor_alignment) - - # Append each constant buffer to the constant segment. - constant_segment_data += buffer.storage - # Add padding for all but the last tensor. - if i < len(constant_buffer) - 1: - constant_segment_data += b"\x00" * pad_length - - # Append constant data offset. - constant_segment_offsets.append(current_offset) - current_offset += buffer_length + pad_length - return bytes(constant_segment_data), constant_segment_offsets - - def pretty_print_xnngraph(xnnpack_graph_json: str): """ Pretty print the XNNGraph @@ -335,7 +300,9 @@ def convert_to_flatbuffer(xnnpack_graph: XNNGraph) -> bytes: return output_file.read() -def serialize_xnnpack_binary(xnnpack_graph: XNNGraph) -> bytes: +def serialize_xnnpack_binary( + xnnpack_graph: XNNGraph, constant_data_bytes: bytearray +) -> bytes: """Returns the runtime binary representation of the given XNNGraph. Args: @@ -344,26 +311,6 @@ def serialize_xnnpack_binary(xnnpack_graph: XNNGraph) -> bytes: Returns: The serialized form of the XNNGraph, ready for execution by XNNPACK Backend """ - constant_tensor_alignment = 16 - - # Extract constant data from the graph - constant_data, constant_data_offsets = _extract_constant_data( - xnnpack_graph.constant_buffer, constant_tensor_alignment - ) - - assert len(constant_data_offsets) == len(xnnpack_graph.mem_buffer_sizes) - - for offset_idx in range(len(constant_data_offsets)): - constant_data_offset = constant_data_offsets[offset_idx] - constant_data_size = xnnpack_graph.mem_buffer_sizes[offset_idx] - xnnpack_graph.constant_data.append( - ConstantDataOffset(constant_data_offset, constant_data_size) - ) - - # We are moving all constant data from the graph to the constant data section. - # So we remove all constant buffers - xnnpack_graph.constant_buffer = [] - xnnpack_graph.mem_buffer_sizes = [] # Convert the XNNGraph to a flatbuffer flatbuffer_payload = convert_to_flatbuffer(xnnpack_graph) @@ -371,12 +318,11 @@ def serialize_xnnpack_binary(xnnpack_graph: XNNGraph) -> bytes: # size of flatbuffer data, padded to be `constant_tensor_alignment` byte aligned padded_flatbuffer_length: int = _aligned_size( input_size=len(flatbuffer_payload), - alignment=constant_tensor_alignment, + alignment=CONSTANT_TENSOR_ALIGNMENT, ) # size of header to insert, padded to be `constant_tensor_alignment` byte aligned padded_header_length: int = _aligned_size( - input_size=XNNHeader.EXPECTED_LENGTH, - alignment=constant_tensor_alignment, + input_size=XNNHeader.EXPECTED_LENGTH, alignment=CONSTANT_TENSOR_ALIGNMENT ) # Create the XNNPACK Header @@ -384,16 +330,13 @@ def serialize_xnnpack_binary(xnnpack_graph: XNNGraph) -> bytes: flatbuffer_offset=padded_header_length, flatbuffer_size=len(flatbuffer_payload), constant_data_offset=padded_header_length + padded_flatbuffer_length, - constant_data_size=len(constant_data), + constant_data_size=len(constant_data_bytes), ).to_bytes() - # Concatenate the header, flatbuffer data, and constant data - # Constant data does not need to be padded to alignment because nothing follows it - return b"".join( [ _pad_to(header, padded_header_length), _pad_to(flatbuffer_payload, padded_flatbuffer_length), - constant_data, + constant_data_bytes, ] ) diff --git a/backends/xnnpack/test/serialization/test_serialization.py b/backends/xnnpack/test/serialization/test_serialization.py index 0ced62e6ca..c2376cc057 100644 --- a/backends/xnnpack/test/serialization/test_serialization.py +++ b/backends/xnnpack/test/serialization/test_serialization.py @@ -4,13 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import os -import random import unittest -from typing import List, Tuple from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( - Buffer, + ConstantDataOffset, XNNGraph, ) @@ -22,23 +19,6 @@ class TestSerialization(unittest.TestCase): - def _generate_random_const_buffers( - self, num_tensors: int - ) -> Tuple[List[Buffer], List[int]]: - """ - Helper function to generate `num_tensor` buffers of random sizes and random contents, - we return a tuple of (list_of_buffers, list_of_mem_sizes), - """ - buffers = [] - mem_sizes = [] - for _ in range(num_tensors): - buffer_size = random.randint(1, 1000) - buffer = bytearray(os.urandom(buffer_size)) - buffers.append(Buffer(storage=bytes(buffer))) - mem_sizes.append(buffer_size) - - return buffers, mem_sizes - def test_serialize_xnnpack_binary(self): xnn_graph = XNNGraph( version="0", @@ -47,25 +27,18 @@ def test_serialize_xnnpack_binary(self): num_externs=0, input_ids=[], output_ids=[], - constant_buffer=[Buffer(storage=b"")], - mem_buffer_sizes=[0], - constant_data=[], + constant_data=[ConstantDataOffset(0, 0)], ) - buffers, sizes = self._generate_random_const_buffers(5) - xnn_graph.constant_buffer.extend(buffers) - xnn_graph.mem_buffer_sizes.extend(sizes) - buffers = xnn_graph.constant_buffer - serialized_binary = serialize_xnnpack_binary(xnn_graph) - offsets = xnn_graph.constant_data + constant_data_bytes = b"\x00" * 24 + serialized_binary = serialize_xnnpack_binary( + xnn_graph, bytearray(constant_data_bytes) + ) # Check header self.assertEqual(serialized_binary[0:4], b"\x00\x00\x00\x00") self.assertEqual(serialized_binary[XNNHeader.MAGIC_OFFSET], b"XH00") flatbuffer_offset_bytes = serialized_binary[XNNHeader.FLATBUFFER_OFFSET_OFFSET] - constant_data_offset_bytes = serialized_binary[ - XNNHeader.CONSTANT_DATA_OFFSET_OFFSET - ] # Check flatbuffer is at flatbuffer offset flatbuffer_offset = int.from_bytes( @@ -75,24 +48,3 @@ def test_serialize_xnnpack_binary(self): self.assertEqual( serialized_binary[flatbuffer_offset:][XNNHeader.MAGIC_OFFSET], b"XN01" ) - - # Check constant data - # Check that constant buffers have been moved to constant data - self.assertEqual(len(offsets), len(buffers)) - self.assertEqual(len(xnn_graph.constant_buffer), 0) - - constant_data_offset = int.from_bytes( - constant_data_offset_bytes, byteorder=_HEADER_BYTEORDER - ) - constant_data_payload = serialized_binary[constant_data_offset:] - - # We check that constant data indexes stored in the xnn_graph correctly index - # into the correct buffer in the constant data section - for idx in range(1, len(offsets)): - offset = offsets[idx].offset - size = offsets[idx].size - - constant_data_bytes = constant_data_payload[offset : offset + size] - constant_buffer_bytes = buffers[idx].storage - - self.assertEqual(constant_data_bytes, constant_buffer_bytes) diff --git a/backends/xnnpack/xnnpack_preprocess.py b/backends/xnnpack/xnnpack_preprocess.py index 1238bddfa8..d852fa604d 100644 --- a/backends/xnnpack/xnnpack_preprocess.py +++ b/backends/xnnpack/xnnpack_preprocess.py @@ -18,7 +18,7 @@ from executorch.backends.xnnpack.passes.tag_implicit_q_dq_pass import TagImplicitQDqPass from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( - Buffer, + ConstantDataOffset, XNNGraph, ) from executorch.backends.xnnpack.serialization.xnnpack_graph_serialize import ( @@ -134,12 +134,11 @@ def preprocess( num_externs=len(node_to_external_map), input_ids=[], output_ids=[], - constant_buffer=[Buffer(storage=b"")], - mem_buffer_sizes=[0], - constant_data=[], + constant_data=[ConstantDataOffset(0, 0)], ) - node_visitors = get_node_visitors(ep, node_to_external_map) + constant_data_bytes = bytearray() + node_visitors = get_node_visitors(ep, node_to_external_map, constant_data_bytes) for node in graph_module.graph.nodes: if node.op == "call_function": @@ -164,5 +163,8 @@ def preprocess( else: raise RuntimeError(f"{node.op} is not supported in XNNPACK") return PreprocessResult( - processed_bytes=serialize_xnnpack_binary(xnnpack_graph), debug_handle_map={} + processed_bytes=serialize_xnnpack_binary( + xnnpack_graph, constant_data_bytes + ), + debug_handle_map={}, )