Commit

Merge remote-tracking branch 'upstream/develop' into enhance/det-exp
goodsong81 committed Dec 13, 2023
2 parents ba782dc + e88bde1 commit 8d115ef
Showing 38 changed files with 2,013 additions and 100 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.

## [unreleased]

### New features

- Add zero-shot visual prompting (https://github.com/openvinotoolkit/training_extensions/pull/2616)

## [v1.5.0]

### New features
1 change: 1 addition & 0 deletions src/otx/algorithms/common/configs/training_base.py
@@ -31,6 +31,7 @@ class TrainType(ConfigurableEnum):
Semisupervised = "Semisupervised"
Selfsupervised = "Selfsupervised"
Incremental = "Incremental"
Zeroshot = "Zeroshot"
Futurework = "Futurework"


Changes to another file (file path not shown)
@@ -14,4 +14,4 @@
# See the License for the specific language governing permissions
# and limitations under the License.

-from .inference import InferenceCallback  # noqa: F401
+from .inference import InferenceCallback, ZeroShotInferenceCallback  # noqa: F401
Changes to another file (file path not shown)
@@ -17,6 +17,7 @@
from typing import Any, List

import numpy as np
import torch
from bson import ObjectId
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
@@ -25,6 +26,7 @@
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.id import ID
from otx.api.entities.image import Image
from otx.api.entities.label_schema import LabelSchemaEntity
from otx.api.entities.scored_label import ScoredLabel
from otx.api.utils.segmentation_utils import (
create_annotation_from_segmentation_map,
@@ -94,3 +96,39 @@ def on_predict_epoch_end(self, _trainer: Trainer, _pl_module: LightningModule, o
dataset_item.annotation_scene.append_annotations(annotations)
else:
dataset_item.append_annotations(annotations)


class ZeroShotInferenceCallback(Callback):
"""Callback that updates otx_dataset during zero-shot inference.
Args:
otx_dataset (DatasetEntity): Dataset to be updated with predictions.
label_schema (LabelSchemaEntity): Label schema information.
"""

def __init__(self, otx_dataset: DatasetEntity, label_schema: LabelSchemaEntity):
# TODO (sungchul): consider use_mask
self.otx_dataset = otx_dataset.with_empty_annotations()
self.label_schema = {int(label.id): label for label in label_schema.get_labels(include_empty=True)}

def on_predict_epoch_end(self, _trainer: Trainer, _pl_module: LightningModule, outputs: List[Any]) -> None:
"""Call when the predict epoch ends."""
for batch_output, dataset_item in zip(outputs[0], self.otx_dataset):
# TODO (sungchul): currently, single batch inference is only supported
output = batch_output[0]
annotations: List[Annotation] = []
for label, masks in output.items():
hard_prediction = torch.where(torch.stack(masks, dim=0).sum(dim=0) > 0, 1, 0)
hard_prediction = hard_prediction.numpy()

# TODO (sungchul): consider use_mask
# generate polygon annotations
annotation = create_annotation_from_segmentation_map(
hard_prediction=hard_prediction,
soft_prediction=hard_prediction,
label_map={1: self.label_schema.get(label)},
)
annotations.extend(annotation)

# TODO (sungchul): consider use_mask
dataset_item.append_annotations(annotations)
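
For orientation, here is a minimal usage sketch of the new callback (not part of this commit); `model`, `datamodule`, `dataset`, and `label_schema` are hypothetical placeholders:

# Hypothetical usage sketch (not part of the diff above).
from pytorch_lightning import Trainer

callback = ZeroShotInferenceCallback(otx_dataset=dataset, label_schema=label_schema)
trainer = Trainer(callbacks=[callback])
trainer.predict(model, datamodule=datamodule)
# After prediction, callback.otx_dataset holds polygon annotations generated from
# the per-label masks returned by the zero-shot visual prompting model.
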
Changes to another file (file path not shown)
@@ -96,7 +96,13 @@ def update_visual_prompting_config(
groups = getattr(otx_config, "groups", None)
if groups:
for group in groups:
if group in ["learning_parameters", "nncf_optimization", "pot_parameters", "postprocessing"]:
if group in [
"learning_parameters",
"nncf_optimization",
"pot_parameters",
"postprocessing",
"algo_backend",
]:
if group in ["nncf_optimization"]:
# TODO (sungchul): Consider nncf_optimization
logger.warning(f"{group} will be implemented.")
Changes to another file (file path not shown)
@@ -24,6 +24,7 @@
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from otx.algorithms.common.configs.training_base import TrainType
from otx.algorithms.visual_prompting.adapters.pytorch_lightning.datasets.pipelines import (
MultipleInputsCompose,
Pad,
@@ -129,6 +130,13 @@ def generate_bbox_from_mask(gt_mask: np.ndarray, width: int, height: int) -> Lis
return generate_bbox(x_min, y_min, x_max, y_max, width, height)


def generate_point_from_mask(gt_mask: np.ndarray) -> np.ndarray:
    """Randomly generate a point from the given mask."""
    candidates = np.argwhere(gt_mask == 1)  # (N, 2) array of (y, x) foreground coordinates
    index = np.random.permutation(len(candidates))[0]
    return candidates[index]
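
As a quick illustration (not part of the commit), sampling a point from a toy mask with the helper above, assuming the foreground coordinates are gathered as an (N, 2) array:

# Illustrative example only.
import numpy as np

gt_mask = np.zeros((4, 4), dtype=np.uint8)
gt_mask[1:3, 1:3] = 1
point = generate_point_from_mask(gt_mask)  # e.g. array([1, 2]), in (y, x) order
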


class OTXVisualPromptingDataset(Dataset):
"""Visual Prompting Dataset Adaptor.
@@ -236,6 +244,27 @@ def __getitem__(self, index: int) -> Dict[str, Union[int, List, Tensor]]:
return item


class OTXZeroShotVisualPromptingDataset(OTXVisualPromptingDataset):
"""Visual Prompting for Zero-shot learning Dataset Adaptor."""

def __getitem__(self, index: int) -> Dict[str, Union[int, List, Tensor]]:
"""Get dataset item.
Args:
index (int): Index of the dataset sample.
Returns:
Dict[str, Union[int, List, Tensor]]: Dataset item.
"""
dataset_item = self.dataset[index]
item: Dict[str, Union[int, Tensor]] = {"index": index, "images": dataset_item.numpy}

prompts = self.get_prompts(dataset_item, self.labels) # , self.generate_point, self.generate_bbox)
item.update({**prompts, "path": dataset_item.media.path})
item = self.transform(item)
return item


class OTXVisualPromptingDataModule(LightningDataModule):
"""Visual Prompting DataModule.
@@ -244,10 +273,39 @@ class OTXVisualPromptingDataModule(LightningDataModule):
dataset (DatasetEntity): Dataset entity.
"""

def __init__(self, config: Union[DictConfig, ListConfig], dataset: DatasetEntity) -> None:
DATASETS = {
TrainType.Incremental: OTXVisualPromptingDataset,
TrainType.Zeroshot: OTXZeroShotVisualPromptingDataset,
}

def __init__(
self,
config: Union[DictConfig, ListConfig],
dataset: DatasetEntity,
train_type: TrainType = TrainType.Incremental,
) -> None:
super().__init__()
self.config = config
self.dataset = dataset
self.train_type = train_type
# self.kwargs = {}
if self.train_type == TrainType.Zeroshot:
# check zero-shot configs
if self.config.get("train_batch_size", 1) != 1:
logger.warning(
(
f"Zero-shot learning only supports single batch, "
f"update {self.config.get('train_batch_size', 1)} to 1."
)
)
self.config["train_batch_size"] = 1

# self.kwargs.update(
# {
# "generate_point": self.config.get("generate_point", False),
# "generate_bbox": self.config.get("generate_bbox", False),
# }
# )

self.train_otx_dataset: DatasetEntity
self.val_otx_dataset: DatasetEntity
@@ -267,21 +325,34 @@ def setup(self, stage: Optional[str] = None) -> None:
mean = self.config.normalize.mean
std = self.config.normalize.std
if stage == "fit" or stage is None:
-train_otx_dataset = self.dataset.get_subset(Subset.TRAINING)
-val_otx_dataset = self.dataset.get_subset(Subset.VALIDATION)
-self.train_dataset = OTXVisualPromptingDataset(
-    train_otx_dataset, image_size, mean, std, offset_bbox=self.config.offset_bbox
+self.train_dataset = self.DATASETS[self.train_type](
+    dataset=self.dataset.get_subset(Subset.TRAINING),
+    image_size=image_size,
+    mean=mean,
+    std=std,
+    offset_bbox=self.config.offset_bbox,
+    # **self.kwargs,
 )
-self.val_dataset = OTXVisualPromptingDataset(val_otx_dataset, image_size, mean, std)
+
+# self.val_dataset = None
+if self.train_type == TrainType.Incremental:
+    self.val_dataset = self.DATASETS[self.train_type](
+        dataset=self.dataset.get_subset(Subset.VALIDATION), image_size=image_size, mean=mean, std=std
+    )

if stage == "test":
-test_otx_dataset = self.dataset.get_subset(Subset.TESTING)
-self.test_dataset = OTXVisualPromptingDataset(test_otx_dataset, image_size, mean, std)
+self.test_dataset = self.DATASETS[self.train_type](
+    dataset=self.dataset.get_subset(Subset.TESTING), image_size=image_size, mean=mean, std=std
+)

if stage == "predict":
-predict_otx_dataset = self.dataset
-self.predict_dataset = OTXVisualPromptingDataset(predict_otx_dataset, image_size, mean, std)
+self.predict_dataset = self.DATASETS[self.train_type](
+    dataset=self.dataset,
+    image_size=image_size,
+    mean=mean,
+    std=std,
+    # **self.kwargs
+)

def summary(self):
"""Print size of the dataset, number of images."""
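
A minimal sketch (not part of this commit) of how the extended datamodule might be driven for zero-shot learning; `config` and `dataset` stand in for an OmegaConf config and a DatasetEntity:

# Hypothetical sketch: select the zero-shot dataset path via the new train_type argument.
from otx.algorithms.common.configs.training_base import TrainType

datamodule = OTXVisualPromptingDataModule(
    config=config, dataset=dataset, train_type=TrainType.Zeroshot
)  # train_batch_size is forced to 1 for zero-shot learning
datamodule.setup(stage="fit")
train_loader = datamodule.train_dataloader()  # serves OTXZeroShotVisualPromptingDataset items
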
Changes to another file (file path not shown)
@@ -40,7 +40,7 @@ def __call__(self, item: Dict[str, Union[List, Tensor]]) -> Dict[str, Union[List
item["gt_masks"] = [torch.as_tensor(gt_mask) for gt_mask in item["gt_masks"]]
item["bboxes"] = self.apply_boxes(item["bboxes"], item["original_size"])
if item["points"]:
item["points"] = self.apply_coords(item["points"], item["original_size"])
item["points"] = self.apply_coords(item["points"], item["original_size"], self.target_length)
return item

@classmethod
@@ -57,21 +57,28 @@ def apply_image(cls, image: np.ndarray, target_length: int) -> np.ndarray:
target_size = cls.get_preprocess_shape(image.shape[0], image.shape[1], target_length)
return np.array(resize(to_pil_image(image), target_size))

-def apply_coords(self, coords: np.ndarray, original_size: Union[List[Any], Tensor]) -> np.ndarray:
+@classmethod
+def apply_coords(
+    cls, coords: Union[np.ndarray, Tensor], original_size: Union[List[Any], Tensor], target_length: int
+) -> np.ndarray:
"""Expects a numpy array of length 2 in the final dimension.
Requires the original image size in (H, W) format.
Args:
-coords (np.ndarray): Coordinates array.
+coords (Union[np.ndarray, Tensor]): Coordinates array.
original_size (Union[List[Any], Tensor]): Original size of image.
target_length (int): The length of the longest side of the image.
Returns:
np.ndarray: Resized coordinates.
"""
old_h, old_w = original_size
-new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
-coords = deepcopy(coords).astype(float)
+new_h, new_w = cls.get_preprocess_shape(original_size[0], original_size[1], target_length)
+if isinstance(coords, np.ndarray):
+    coords = deepcopy(coords).astype(np.float32)
+else:
+    coords = deepcopy(coords).to(torch.float32)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords
@@ -86,7 +93,7 @@ def apply_boxes(self, boxes: np.ndarray, original_size: Union[List[Any], Tensor]
Returns:
np.ndarray: Resized boxes.
"""
-boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
+boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size, self.target_length)
return boxes.reshape(-1, 4)

@staticmethod
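
To illustrate the reworked `apply_coords` (not part of this commit; the enclosing transform class is assumed to be OTX's ResizeLongestSide, whose name is not visible in this hunk):

# Hypothetical sketch: rescale prompt coordinates without instantiating the transform.
import numpy as np

points = np.array([[64.0, 128.0]])  # (x, y) in the original image
original_size = (256, 512)          # (H, W)
resized = ResizeLongestSide.apply_coords(points, original_size, target_length=1024)
# The longest side (512) maps to 1024, so both coordinates are doubled: [[128., 256.]]
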
Changes to another file (file path not shown)
@@ -3,4 +3,4 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

-from .visual_prompters import SegmentAnything  # noqa: F401
+from .visual_prompters import SegmentAnything, ZeroShotSegmentAnything  # noqa: F401
Changes to another file (file path not shown)
@@ -4,3 +4,4 @@
# SPDX-License-Identifier: Apache-2.0

from .segment_anything import SegmentAnything # noqa: F401
from .zero_shot_segment_anything import ZeroShotSegmentAnything # noqa: F401
Changes to another file (file path not shown)
@@ -552,8 +552,8 @@ def predict_step(self, batch, batch_idx) -> Dict[str, Tensor]:

return dict(masks=masks, iou_predictions=iou_predictions, path=batch["path"], labels=batch["labels"])

+@staticmethod
 def postprocess_masks(
-    self,
masks: Tensor,
input_size: Tuple[int, int],
padding: Tuple[int, ...],
(Diffs for the remaining changed files were not loaded.)
