-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add functions that prepares tensors as input args for IREE tools
One drawback of using npy files is that they don't support some datatypes. This change adds functionality to prepare arguments from Torch tensors in the form --input=1x2xbf16=@arg0.bin These can then be passed to tools like iree-run-module and iree-benchmark-module. Signed-off-by: Boian Petkantchin <boian.petkantchin@amd.com>
- Loading branch information
Showing
4 changed files
with
139 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Copyright 2025 Advanced Micro Devices, Inc | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import torch | ||
|
||
from .conversions import torch_dtyped_shape_to_iree_format | ||
|
||
|
||
def iree_tool_format_cli_input_arg(arg: torch.Tensor, file_path: str) -> str: | ||
"""Format the CLI value for an input argument. | ||
Example: | ||
iree_tool_format_cli_input_arg(torch.empty([1,2], dtype=torch.float32), "arg0.bin") | ||
Returns: | ||
"1x2xf32=@arg0.bin" | ||
""" | ||
return f"{torch_dtyped_shape_to_iree_format(arg)}=@{file_path}" | ||
|
||
|
||
def write_raw_tensor(tensor: torch.Tensor, file_path: str): | ||
"""Write the contents of the tensor as they are in memory without any metadata.""" | ||
with open(file_path, "wb") as f: | ||
f.write(tensor.cpu().view(dtype=torch.int8).numpy().data) | ||
|
||
|
||
def iree_tool_prepare_input_args( | ||
args: tuple[torch.Tensor], | ||
/, | ||
*, | ||
file_paths: tuple[str] | None = None, | ||
file_path_prefix: str | None = None, | ||
) -> tuple[str]: | ||
"""Write the raw contents of tensors to files without any metadata. | ||
Returns the CLI input args description. | ||
If file_path_prefix will chose a default naming for argument files. | ||
Example: | ||
file_path_prefix="/some/path_arg" | ||
returns | ||
[ | ||
"1x2x3xf32=@/some/path_arg0.bin", | ||
"4x5xi8=@/some/path_arg1.bin" | ||
] | ||
This results can be prefixed with "--input=" to arrive at the final CLI flags | ||
expected by IREE tools. | ||
""" | ||
if file_paths is not None and file_path_prefix is not None: | ||
raise ValueError( | ||
"file_paths and file_path_prefix are mutually exclusive arguments." | ||
) | ||
|
||
if file_path_prefix is not None: | ||
file_paths = [f"{file_path_prefix}{i}.bin" for i in range(len(args))] | ||
for tensor, file_path in zip(args, file_paths): | ||
write_raw_tensor(tensor, file_path) | ||
return [ | ||
iree_tool_format_cli_input_arg(tensor, file_path) | ||
for tensor, file_path in zip(args, file_paths) | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Copyright 2025 Advanced Micro Devices, Inc | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import torch | ||
from iree.turbine.support.conversions import torch_dtyped_shape_to_iree_format | ||
|
||
|
||
def test_torch_dtyped_shape_to_iree_format(): | ||
iree_format = torch_dtyped_shape_to_iree_format([1, 2, 3], dtype=torch.bfloat16) | ||
assert iree_format == "1x2x3xbf16" | ||
|
||
|
||
def test_torch_dtyped_zero_rank_shape_to_iree_format(): | ||
iree_format = torch_dtyped_shape_to_iree_format([], dtype=torch.float8_e4m3fn) | ||
assert iree_format == "f8E4M3FN" | ||
|
||
|
||
def test_torch_dtyped_shape_to_iree_format_from_tensor(): | ||
tensor = torch.empty([1, 2, 3], dtype=torch.float32) | ||
iree_format = torch_dtyped_shape_to_iree_format(tensor) | ||
assert iree_format == "1x2x3xf32" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Copyright 2025 Advanced Micro Devices, Inc | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import torch | ||
from iree.turbine.support.tools import iree_tool_prepare_input_args | ||
import tempfile | ||
from pathlib import Path | ||
|
||
|
||
def test_iree_tool_prepare_input_args(): | ||
arg0 = torch.tensor([1.1, 2.2, 3.3, 4.4], dtype=torch.bfloat16) | ||
arg1 = torch.tensor([[4, 5], [6, 7]], dtype=torch.int8) | ||
args = [arg0, arg1] | ||
with tempfile.TemporaryDirectory() as tmp_dir: | ||
cli_arg_values = iree_tool_prepare_input_args( | ||
args, file_path_prefix=str(Path(tmp_dir) / "arg") | ||
) | ||
|
||
expected_arg_file_paths = [ | ||
str(Path(tmp_dir) / "arg0.bin"), | ||
str(Path(tmp_dir) / "arg1.bin"), | ||
] | ||
|
||
assert cli_arg_values[0] == f"4xbf16=@{expected_arg_file_paths[0]}" | ||
assert cli_arg_values[1] == f"2x2xi8=@{expected_arg_file_paths[1]}" | ||
|
||
for arg, file_path in zip(args, expected_arg_file_paths): | ||
with open(file_path, "rb") as f: | ||
actual_bytes = f.read() | ||
assert ( | ||
arg.cpu().view(dtype=torch.int8).numpy().tobytes() == actual_bytes | ||
) |