diff --git a/src/setfit/exporters/onnx.py b/src/setfit/exporters/onnx.py index 51e4b41f..dc27bfa9 100644 --- a/src/setfit/exporters/onnx.py +++ b/src/setfit/exporters/onnx.py @@ -1,6 +1,7 @@ import copy import warnings from typing import Callable, Optional, Union +from inspect import signature import numpy as np import onnx @@ -87,9 +88,12 @@ def export_onnx_setfit_model(setfit_model: OnnxSetFitModel, inputs, output_path, for output_name in output_names: dynamic_axes_output[output_name] = {0: "batch_size"} - # Move inputs to the right device + # Move inputs to the right device and put them in the right order + forward_params = tuple(signature(setfit_model.model_body.forward).parameters.keys()) # keys of ordered dict are ordered + ordered_kwargs = sorted(inputs.items(), key=lambda param: forward_params.index(param[0])) + ordered_params = [param_value for (_, param_value) in ordered_kwargs] target = setfit_model.model_body.device - args = tuple(value.to(target) for value in inputs.values()) + args = tuple(value.to(target) for value in ordered_params) setfit_model.eval() with torch.no_grad():