Skip to content

Commit

Permalink
Fix soft predictions for Semantic Segmentation (#3934)
Browse files Browse the repository at this point in the history
fix soft preds
kprokofi authored Sep 6, 2024
1 parent 98a9cac commit d8e6454
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/otx/core/model/segmentation.py
Original file line number Diff line number Diff line change
@@ -219,7 +219,8 @@ def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo:

def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
"""Model forward function used for the model tracing during model exportation."""
return self.model(inputs=image, mode="tensor")
raw_outputs = self.model(inputs=image, mode="tensor")
return torch.softmax(raw_outputs, dim=1)

def get_dummy_input(self, batch_size: int = 1) -> SegBatchDataEntity:
"""Returns a dummy input for semantic segmentation model."""

0 comments on commit d8e6454

Please sign in to comment.