Skip to content

Commit

Permalink
Fixed order of input parameters for onnx export
Browse files Browse the repository at this point in the history
  • Loading branch information
rolshoven authored and Luca Rolshoven committed Sep 18, 2023
1 parent 4ebee43 commit 70610f5
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/setfit/exporters/onnx.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import warnings
from typing import Callable, Optional, Union
from inspect import signature

import numpy as np
import onnx
Expand Down Expand Up @@ -87,9 +88,12 @@ def export_onnx_setfit_model(setfit_model: OnnxSetFitModel, inputs, output_path,
for output_name in output_names:
dynamic_axes_output[output_name] = {0: "batch_size"}

# Move inputs to the right device
# Move inputs to the right device and put them in the right order
forward_params = tuple(signature(setfit_model.model_body.forward).parameters.keys()) # keys of ordered dict are ordered
ordered_kwargs = sorted(inputs.items(), key=lambda param: forward_params.index(param[0]))
ordered_params = [param_value for (_, param_value) in ordered_kwargs]
target = setfit_model.model_body.device
args = tuple(value.to(target) for value in inputs.values())
args = tuple(value.to(target) for value in ordered_params)

setfit_model.eval()
with torch.no_grad():
Expand Down

0 comments on commit 70610f5

Please sign in to comment.