diff --git a/src/otx/core/model/segmentation.py b/src/otx/core/model/segmentation.py index a2972d7eb86..85182944474 100644 --- a/src/otx/core/model/segmentation.py +++ b/src/otx/core/model/segmentation.py @@ -219,7 +219,8 @@ def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo: def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]: """Model forward function used for the model tracing during model exportation.""" - return self.model(inputs=image, mode="tensor") + raw_outputs = self.model(inputs=image, mode="tensor") + return torch.softmax(raw_outputs, dim=1) def get_dummy_input(self, batch_size: int = 1) -> SegBatchDataEntity: """Returns a dummy input for semantic segmentation model."""