diff --git a/src/otx/core/model/detection.py b/src/otx/core/model/detection.py index 76b7c2d1538..064811189ce 100644 --- a/src/otx/core/model/detection.py +++ b/src/otx/core/model/detection.py @@ -237,6 +237,7 @@ def forward_tiles(self, inputs: OTXTileBatchDataEntity[DetBatchDataEntity]) -> D inputs.imgs_info, self.num_classes, self.tile_config, + self.explain_mode, ) for batch_tile_attrs, batch_tile_input in inputs.unbind(): output = self.forward_explain(batch_tile_input) if self.explain_mode else self.forward(batch_tile_input) diff --git a/src/otx/core/model/instance_segmentation.py b/src/otx/core/model/instance_segmentation.py index 2a26b688920..3a58ba00715 100644 --- a/src/otx/core/model/instance_segmentation.py +++ b/src/otx/core/model/instance_segmentation.py @@ -215,6 +215,7 @@ def forward_tiles(self, inputs: OTXTileBatchDataEntity[InstanceSegBatchDataEntit inputs.imgs_info, self.num_classes, self.tile_config, + self.explain_mode, ) for batch_tile_attrs, batch_tile_input in inputs.unbind(): output = self.forward_explain(batch_tile_input) if self.explain_mode else self.forward(batch_tile_input) diff --git a/src/otx/core/utils/tile_merge.py b/src/otx/core/utils/tile_merge.py index 99149e4e1f8..02457522055 100644 --- a/src/otx/core/utils/tile_merge.py +++ b/src/otx/core/utils/tile_merge.py @@ -27,9 +27,9 @@ class TileMerge(Generic[T_OTXDataEntity, T_OTXBatchPredEntity]): Args: img_infos (list[ImageInfo]): Original image information before tiling. - iou_threshold (float, optional): IoU threshold for non-maximum suppression. Defaults to 0.45. - max_num_instances (int, optional): Maximum number of instances to keep. Defaults to 500. - + num_classes (int): Number of classes. + tile_config (TileConfig): Tile configuration. + explain_mode (bool): Whether or not tiles have explain features. Default: False. """ def __init__( @@ -37,6 +37,7 @@ def __init__( img_infos: list[ImageInfo], num_classes: int, tile_config: TileConfig, + explain_mode: bool = False, ) -> None: self.img_infos = img_infos self.num_classes = num_classes @@ -44,6 +45,7 @@ def __init__( self.iou_threshold = tile_config.iou_threshold self.max_num_instances = tile_config.max_num_instances self.with_full_img = tile_config.with_full_img + self.explain_mode = explain_mode @abstractmethod def _merge_entities( @@ -115,7 +117,7 @@ def merge( """ entities_to_merge = defaultdict(list) img_ids = [] - explain_mode = len(batch_tile_preds[0].feature_vector) > 0 + explain_mode = self.explain_mode for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs): batch_size = tile_preds.batch_size @@ -315,7 +317,7 @@ def merge( """ entities_to_merge = defaultdict(list) img_ids = [] - explain_mode = len(batch_tile_preds[0].feature_vector) > 0 + explain_mode = self.explain_mode for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs): feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(tile_preds.batch_size)]