Skip to content

Commit

Permalink
Add functions that prepares tensors as input args for IREE tools
Browse files Browse the repository at this point in the history
One drawback of using npy files is that they don't support some
datatypes.

This change adds functionality to prepare arguments from Torch tensors
in the form
--input=1x2xbf16=@arg0.bin
These can then be passed to tools like iree-run-module and
iree-benchmark-module.

Signed-off-by: Boian Petkantchin <boian.petkantchin@amd.com>
  • Loading branch information
sogartar committed Jan 27, 2025
1 parent 1847c33 commit 2a159db
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 0 deletions.
18 changes: 18 additions & 0 deletions iree/turbine/support/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,21 @@ def torch_dtype_to_numpy(torch_dtype: torch.dtype) -> Any:
return TORCH_DTYPE_TO_NUMPY[torch_dtype]
except KeyError:
raise UnknownDTypeError(torch_dtype)


def torch_dtyped_shape_to_iree_format(
shape_or_tensor: tuple[int] | torch.Tensor, /, dtype: torch.dtype | None = None
) -> str:
"""Example:
shape = [1, 2, 3]
dtype = torch.bfloat16
Returns
"1x2x3xbf16"
"""
if isinstance(shape_or_tensor, torch.Tensor):
dtype = dtype or shape_or_tensor.dtype
return torch_dtyped_shape_to_iree_format(shape_or_tensor.shape, dtype)
shape_str = "x".join([str(d) for d in shape_or_tensor])
shape_dtype_delimiter = "x" if len(shape_or_tensor) > 0 else ""
dtype_str = TORCH_DTYPE_TO_IREE_TYPE_ASM[dtype]
return f"{shape_str}{shape_dtype_delimiter}{dtype_str}"
62 changes: 62 additions & 0 deletions iree/turbine/support/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2025 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch

from .conversions import torch_dtyped_shape_to_iree_format


def iree_tool_format_cli_input_arg(arg: torch.Tensor, file_path: str) -> str:
"""Format the CLI value for an input argument.
Example:
iree_tool_format_cli_input_arg(torch.empty([1,2], dtype=torch.float32), "arg0.bin")
Returns:
"1x2xf32=@arg0.bin"
"""
return f"{torch_dtyped_shape_to_iree_format(arg)}=@{file_path}"


def write_raw_tensor(tensor: torch.Tensor, file_path: str):
"""Write the contents of the tensor as they are in memory without any metadata."""
with open(file_path, "wb") as f:
f.write(tensor.cpu().view(dtype=torch.int8).numpy().data)


def iree_tool_prepare_input_args(
args: tuple[torch.Tensor],
/,
*,
file_paths: tuple[str] | None = None,
file_path_prefix: str | None = None,
) -> tuple[str]:
"""Write the raw contents of tensors to files without any metadata.
Returns the CLI input args description.
If file_path_prefix will chose a default naming for argument files.
Example:
file_path_prefix="/some/path_arg"
returns
[
"1x2x3xf32=@/some/path_arg0.bin",
"4x5xi8=@/some/path_arg1.bin"
]
This results can be prefixed with "--input=" to arrive at the final CLI flags
expected by IREE tools.
"""
if file_paths is not None and file_path_prefix is not None:
raise ValueError(
"file_paths and file_path_prefix are mutually exclusive arguments."
)

if file_path_prefix is not None:
file_paths = [f"{file_path_prefix}{i}.bin" for i in range(len(args))]
for tensor, file_path in zip(args, file_paths):
write_raw_tensor(tensor, file_path)
return [
iree_tool_format_cli_input_arg(tensor, file_path)
for tensor, file_path in zip(args, file_paths)
]
24 changes: 24 additions & 0 deletions tests/support/conversions_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright 2025 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch
from iree.turbine.support.conversions import torch_dtyped_shape_to_iree_format


def test_torch_dtyped_shape_to_iree_format():
iree_format = torch_dtyped_shape_to_iree_format([1, 2, 3], dtype=torch.bfloat16)
assert iree_format == "1x2x3xbf16"


def test_torch_dtyped_zero_rank_shape_to_iree_format():
iree_format = torch_dtyped_shape_to_iree_format([], dtype=torch.float8_e4m3fn)
assert iree_format == "f8E4M3FN"


def test_torch_dtyped_shape_to_iree_format_from_tensor():
tensor = torch.empty([1, 2, 3], dtype=torch.float32)
iree_format = torch_dtyped_shape_to_iree_format(tensor)
assert iree_format == "1x2x3xf32"
35 changes: 35 additions & 0 deletions tests/support/tools_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2025 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch
from iree.turbine.support.tools import iree_tool_prepare_input_args
import tempfile
from pathlib import Path


def test_iree_tool_prepare_input_args():
arg0 = torch.tensor([1.1, 2.2, 3.3, 4.4], dtype=torch.bfloat16)
arg1 = torch.tensor([[4, 5], [6, 7]], dtype=torch.int8)
args = [arg0, arg1]
with tempfile.TemporaryDirectory() as tmp_dir:
cli_arg_values = iree_tool_prepare_input_args(
args, file_path_prefix=str(Path(tmp_dir) / "arg")
)

expected_arg_file_paths = [
str(Path(tmp_dir) / "arg0.bin"),
str(Path(tmp_dir) / "arg1.bin"),
]

assert cli_arg_values[0] == f"4xbf16=@{expected_arg_file_paths[0]}"
assert cli_arg_values[1] == f"2x2xi8=@{expected_arg_file_paths[1]}"

for arg, file_path in zip(args, expected_arg_file_paths):
with open(file_path, "rb") as f:
actual_bytes = f.read()
assert (
arg.cpu().view(dtype=torch.int8).numpy().tobytes() == actual_bytes
)

0 comments on commit 2a159db

Please sign in to comment.