Skip to content

Commit

Permalink
Update onnx.py
Browse files Browse the repository at this point in the history
Fixed switched token_type_ids and attention_mask and made use of kwargs to make code less error prone in case something about the order of the arguments changes.
  • Loading branch information
rolshoven authored Aug 25, 2023
1 parent 4ebee43 commit 937c408
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/setfit/exporters/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ def __init__(
self.pooler = pooler
self.model_head = model_head

def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor):
hidden_states = self.model_body(input_ids, attention_mask, token_type_ids)
def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, attention_mask: torch.Tensor):
hidden_states = self.model_body(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
hidden_states = {"token_embeddings": hidden_states[0], "attention_mask": attention_mask}

embeddings = self.pooler(hidden_states)

# If the model_head is none we are using a sklearn head and only output
Expand Down

0 comments on commit 937c408

Please sign in to comment.