diff --git a/tools/pytorch/pnnx_package/pnnx/wrapper.py b/tools/pytorch/pnnx_package/pnnx/wrapper.py index c8c18a3d2e9..dfc275082c4 100644 --- a/tools/pytorch/pnnx_package/pnnx/wrapper.py +++ b/tools/pytorch/pnnx_package/pnnx/wrapper.py @@ -114,11 +114,15 @@ def export(model, inputshape=None, inputshape2=None, **kwargs): current_frame = inspect.currentframe() previous_frame = inspect.getouterframes(current_frame)[1] call_func_code = inspect.getframeinfo(previous_frame[0]).code_context[0].strip() + model_name = get_model_name(call_func_code) - if inputshape is not None and inputshape2 is None: + if inputshape is not None: model_torchscript_path = trace_model(model, model_name, inputshape) - run(model_torchscript_path, inputshape=inputshape, **kwargs) + run(model_torchscript_path, inputshape=inputshape, inputshape2=inputshape2, **kwargs) + else: + print("inputshape is None, which is required") + exit(1) print("[-] leaving export") \ No newline at end of file