diff --git a/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py b/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py
index 9e2fe13970731..e5aff51f990c2 100644
--- a/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py
+++ b/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py
@@ -15,8 +15,10 @@
   python -m iree.compiler.tools.import_onnx ...
 """
 import argparse
+import os
 from pathlib import Path
 import sys
+import tempfile
 
 try:
     import onnx
@@ -38,8 +40,8 @@
 )
 
 
-def main(args):
-    model_proto = load_onnx_model(args.input_file)
+def main(args: argparse.Namespace):
+    model_proto = load_onnx_model(args)
     context = Context()
     model_info = onnx_importer.ModelInfo(model_proto)
     m = model_info.create_module(context=context).operation
@@ -58,13 +60,56 @@ def main(args):
         print(m.get_asm(assume_verified=not args.no_verify))
 
 
-def load_onnx_model(file_path: Path) -> onnx.ModelProto:
-    raw_model = onnx.load(file_path)
-    inferred_model = onnx.shape_inference.infer_shapes(raw_model)
-    return inferred_model
+def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
+    input_dir = os.path.dirname(os.path.abspath(args.input_file))
 
-
-def parse_arguments(argv=None):
+    # Load the model, with possible external data coming from the default
+    # location, or the location specified on the command line.
+    if args.data_dir is None:
+        raw_model = onnx.load(args.input_file)
+    else:
+        raw_model = onnx.load(args.input_file, load_external_data=False)
+        onnx.load_external_data_for_model(raw_model, args.data_dir)
+
+    # Do shape inference two ways. First, attempt in-memory to avoid redundant
+    # loading and the need for writing a temporary file somewhere. If that
+    # fails, typically because of the 2 GB protobuf size limit, try again via
+    # files. See
+    # https://onnx.ai/onnx/repo-docs/PythonAPIOverview.html#shape-inference-a-large-onnx-model-2gb
+    # for details about the file-based technique.
+
+    # Run the checker to test whether the file is above the threshold for
+    # in-memory shape inference. If not, go ahead and do the shape inference.
+    try:
+        onnx.checker.check_model(raw_model)
+        inferred_model = onnx.shape_inference.infer_shapes(
+            raw_model, data_prop=args.data_prop
+        )
+        return inferred_model
+    except ValueError:
+        pass
+
+    # Model is too big for in-memory inference: do file-based shape inference
+    # to a temp file.
+    # Make a temp dir for all the temp files we'll be generating as a side
+    # effect of inferring shapes. For now, the only file is a new .onnx holding
+    # the revised model with shapes.
+    with tempfile.TemporaryDirectory(dir=input_dir) as temp_dir_name:
+        temp_dir_path = Path(temp_dir_name)
+        temp_inferred_file = temp_dir_path / "temp-inferred.onnx"
+        onnx.shape_inference.infer_shapes_path(
+            args.input_file, temp_inferred_file, data_prop=args.data_prop
+        )
+
+        # Load the temp file and the external data.
+        inferred_model = onnx.load(temp_inferred_file, load_external_data=False)
+        data_dir = Path(input_dir if args.data_dir is None else args.data_dir)
+        onnx.load_external_data_for_model(inferred_model, data_dir)
+
+        return inferred_model
+
+
+def parse_arguments(argv=None) -> argparse.Namespace:
     parser = argparse.ArgumentParser(description="IREE ONNX import tool")
     parser.add_argument("input_file", help="ONNX protobuf input", type=Path)
     parser.add_argument(
@@ -75,6 +120,18 @@ def parse_arguments(argv=None):
         action="store_true",
         help="Disable verification prior to printing",
    )
+    parser.add_argument(
+        "--data-prop",
+        default=True,
+        action=argparse.BooleanOptionalAction,
+        help="Toggle data propagation for onnx shape inference",
+    )
+    parser.add_argument(
+        "--data-dir",
+        help="Path to the base directory of the data."
+        " Defaults to the directory of the input file.",
+        type=Path,
+    )
     args = parser.parse_args(argv)
     return args
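
Since `parse_arguments` takes an explicit argv list, the new flags can also be exercised in-process, not just from the shell. A minimal sketch (the model path and data directory are hypothetical; `--no-data-prop` is the negated form that `argparse.BooleanOptionalAction` generates automatically on Python 3.9+):

```python
# Hypothetical in-process use; equivalent to:
#   python -m iree.compiler.tools.import_onnx model.onnx \
#       --data-dir /data/model_tensors --no-data-prop
from iree.compiler.tools.import_onnx.__main__ import main, parse_arguments

args = parse_arguments(
    ["model.onnx", "--data-dir", "/data/model_tensors", "--no-data-prop"]
)
main(args)  # prints the imported MLIR to stdout
```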
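
The in-memory-then-file-based fallback in `load_onnx_model` is the interesting part of the patch: `onnx.checker.check_model` raises `ValueError` when the serialized proto would exceed the 2 GB protobuf limit, which is the same condition under which in-memory `infer_shapes` fails, while `infer_shapes_path` works on file paths and supports models past that limit. A standalone sketch of the same pattern, assuming a hypothetical `model.onnx` with any external data stored alongside it:

```python
import tempfile
from pathlib import Path

import onnx


def infer_shapes_any_size(input_file: Path, data_prop: bool = True) -> onnx.ModelProto:
    """Hypothetical helper mirroring the fallback in this patch."""
    raw_model = onnx.load(str(input_file))
    try:
        # Raises ValueError if the model is too large (> 2 GB) to serialize,
        # in which case in-memory inference is off the table.
        onnx.checker.check_model(raw_model)
        return onnx.shape_inference.infer_shapes(raw_model, data_prop=data_prop)
    except ValueError:
        pass

    # File-based fallback: the path-based API writes the inferred model to
    # the output path rather than returning it in memory.
    with tempfile.TemporaryDirectory(dir=input_file.parent) as temp_dir:
        inferred_path = Path(temp_dir) / "inferred.onnx"
        onnx.shape_inference.infer_shapes_path(
            str(input_file), str(inferred_path), data_prop=data_prop
        )
        model = onnx.load(str(inferred_path), load_external_data=False)
        # External tensors still point at the original data directory.
        onnx.load_external_data_for_model(model, str(input_file.parent))
        return model
```

Note the design choice of creating the temp directory next to the input file (as the patch does with `dir=input_dir`): relative external-data paths recorded in the model stay resolvable, and the large temporary .onnx lands on the same filesystem as the input rather than in a possibly small system temp partition.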