Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Multi-scale testing #804

Merged
merged 3 commits into from
May 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions configs/test_time_aug/e2e_mask_rcnn_R_50_FPN_1x.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
BACKBONE:
CONV_BODY: "R-50-FPN"
RESNETS:
BACKBONE_OUT_CHANNELS: 256
RPN:
USE_FPN: True
ANCHOR_STRIDE: (4, 8, 16, 32, 64)
PRE_NMS_TOP_N_TRAIN: 2000
PRE_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 1000
FPN_POST_NMS_TOP_N_TEST: 1000
ROI_HEADS:
USE_FPN: True
ROI_BOX_HEAD:
POOLER_RESOLUTION: 7
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
POOLER_SAMPLING_RATIO: 2
FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
PREDICTOR: "FPNPredictor"
ROI_MASK_HEAD:
POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor"
PREDICTOR: "MaskRCNNC4Predictor"
POOLER_RESOLUTION: 14
POOLER_SAMPLING_RATIO: 2
RESOLUTION: 28
SHARE_BOX_FEATURE_EXTRACTOR: False
MASK_ON: True
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival",)
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
BASE_LR: 0.02
WEIGHT_DECAY: 0.0001
STEPS: (60000, 80000)
MAX_ITER: 90000
TEST:
BBOX_AUG:
ENABLED: True
H_FLIP: True
SCALES: (400, 500, 600, 700, 900, 1000, 1100, 1200)
MAX_SIZE: 2000
SCALE_H_FLIP: True
21 changes: 21 additions & 0 deletions maskrcnn_benchmark/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,27 @@
# Number of detections per image
_C.TEST.DETECTIONS_PER_IMG = 100

# ---------------------------------------------------------------------------- #
# Test-time augmentations for bounding box detection
# See configs/test_time_aug/e2e_mask_rcnn_R-50-FPN_1x.yaml for an example
# ---------------------------------------------------------------------------- #
_C.TEST.BBOX_AUG = CN()

# Enable test-time augmentation for bounding box detection if True
_C.TEST.BBOX_AUG.ENABLED = False

# Horizontal flip at the original scale (id transform)
_C.TEST.BBOX_AUG.H_FLIP = False

# Each scale is the pixel size of an image's shortest side
_C.TEST.BBOX_AUG.SCALES = ()

# Max pixel size of the longer side
_C.TEST.BBOX_AUG.MAX_SIZE = 4000

# Horizontal flip at each scale
_C.TEST.BBOX_AUG.SCALE_H_FLIP = False


# ---------------------------------------------------------------------------- #
# Misc options
Expand Down
8 changes: 5 additions & 3 deletions maskrcnn_benchmark/data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from . import datasets as D
from . import samplers

from .collate_batch import BatchCollator
from .collate_batch import BatchCollator, BBoxAugCollator
from .transforms import build_transforms


Expand Down Expand Up @@ -150,7 +150,8 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
DatasetCatalog = paths_catalog.DatasetCatalog
dataset_list = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST

transforms = build_transforms(cfg, is_train)
# If bbox aug is enabled in testing, simply set transforms to None and we will apply transforms later
transforms = None if not is_train and cfg.TEST.BBOX_AUG.ENABLED else build_transforms(cfg, is_train)
datasets = build_dataset(dataset_list, transforms, DatasetCatalog, is_train)

data_loaders = []
Expand All @@ -159,7 +160,8 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
batch_sampler = make_batch_data_sampler(
dataset, sampler, aspect_grouping, images_per_gpu, num_iters, start_iter
)
collator = BatchCollator(cfg.DATALOADER.SIZE_DIVISIBILITY)
collator = BBoxAugCollator() if not is_train and cfg.TEST.BBOX_AUG.ENABLED else \
BatchCollator(cfg.DATALOADER.SIZE_DIVISIBILITY)
num_workers = cfg.DATALOADER.NUM_WORKERS
data_loader = torch.utils.data.DataLoader(
dataset,
Expand Down
12 changes: 12 additions & 0 deletions maskrcnn_benchmark/data/collate_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,15 @@ def __call__(self, batch):
targets = transposed_batch[1]
img_ids = transposed_batch[2]
return images, targets, img_ids


class BBoxAugCollator(object):
"""
From a list of samples from the dataset,
returns the images and targets.
Images should be converted to batched images in `im_detect_bbox_aug`
"""

def __call__(self, batch):
return list(zip(*batch))

8 changes: 6 additions & 2 deletions maskrcnn_benchmark/data/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@ def get_size(self, image_size):

return (oh, ow)

def __call__(self, image, target):
def __call__(self, image, target=None):
size = self.get_size(image.size)
image = F.resize(image, size)
if target is None:
return image
target = target.resize(image.size)
return image, target

Expand Down Expand Up @@ -101,8 +103,10 @@ def __init__(self, mean, std, to_bgr255=True):
self.std = std
self.to_bgr255 = to_bgr255

def __call__(self, image, target):
def __call__(self, image, target=None):
if self.to_bgr255:
image = image[[2, 1, 0]] * 255
image = F.normalize(image, mean=self.mean, std=self.std)
if target is None:
return image
return image, target
118 changes: 118 additions & 0 deletions maskrcnn_benchmark/engine/bbox_aug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import torch
import torchvision.transforms as TT

from maskrcnn_benchmark.config import cfg
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using the config globally in the library is something I've tried to avoid, but let's merge this as is and maybe modify this later on

from maskrcnn_benchmark.data import transforms as T
from maskrcnn_benchmark.structures.image_list import to_image_list
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.modeling.roi_heads.box_head.inference import make_roi_box_post_processor


def im_detect_bbox_aug(model, images, device):
# Collect detections computed under different transformations
boxlists_ts = []
for _ in range(len(images)):
boxlists_ts.append([])

def add_preds_t(boxlists_t):
for i, boxlist_t in enumerate(boxlists_t):
if len(boxlists_ts[i]) == 0:
# The first one is identity transform, no need to resize the boxlist
boxlists_ts[i].append(boxlist_t)
else:
# Resize the boxlist as the first one
boxlists_ts[i].append(boxlist_t.resize(boxlists_ts[i][0].size))

# Compute detections for the original image (identity transform)
boxlists_i = im_detect_bbox(
model, images, cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MAX_SIZE_TEST, device
)
add_preds_t(boxlists_i)

# Perform detection on the horizontally flipped image
if cfg.TEST.BBOX_AUG.H_FLIP:
boxlists_hf = im_detect_bbox_hflip(
model, images, cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MAX_SIZE_TEST, device
)
add_preds_t(boxlists_hf)

# Compute detections at different scales
for scale in cfg.TEST.BBOX_AUG.SCALES:
max_size = cfg.TEST.BBOX_AUG.MAX_SIZE
boxlists_scl = im_detect_bbox_scale(
model, images, scale, max_size, device
)
add_preds_t(boxlists_scl)

if cfg.TEST.BBOX_AUG.SCALE_H_FLIP:
boxlists_scl_hf = im_detect_bbox_scale(
model, images, scale, max_size, device, hflip=True
)
add_preds_t(boxlists_scl_hf)

# Merge boxlists detected by different bbox aug params
boxlists = []
for i, boxlist_ts in enumerate(boxlists_ts):
bbox = torch.cat([boxlist_t.bbox for boxlist_t in boxlist_ts])
scores = torch.cat([boxlist_t.get_field('scores') for boxlist_t in boxlist_ts])
boxlist = BoxList(bbox, boxlist_ts[0].size, boxlist_ts[0].mode)
boxlist.add_field('scores', scores)
boxlists.append(boxlist)

# Apply NMS and limit the final detections
results = []
post_processor = make_roi_box_post_processor(cfg)
for boxlist in boxlists:
results.append(post_processor.filter_results(boxlist, cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES))

return results


def im_detect_bbox(model, images, target_scale, target_max_size, device):
"""
Performs bbox detection on the original image.
"""
transform = TT.Compose([
T.Resize(target_scale, target_max_size),
TT.ToTensor(),
T.Normalize(
mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD, to_bgr255=cfg.INPUT.TO_BGR255
)
])
images = [transform(image) for image in images]
images = to_image_list(images, cfg.DATALOADER.SIZE_DIVISIBILITY)
return model(images.to(device))


def im_detect_bbox_hflip(model, images, target_scale, target_max_size, device):
"""
Performs bbox detection on the horizontally flipped image.
Function signature is the same as for im_detect_bbox.
"""
transform = TT.Compose([
T.Resize(target_scale, target_max_size),
TT.RandomHorizontalFlip(1.0),
TT.ToTensor(),
T.Normalize(
mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD, to_bgr255=cfg.INPUT.TO_BGR255
)
])
images = [transform(image) for image in images]
images = to_image_list(images, cfg.DATALOADER.SIZE_DIVISIBILITY)
boxlists = model(images.to(device))

# Invert the detections computed on the flipped image
boxlists_inv = [boxlist.transpose(0) for boxlist in boxlists]
return boxlists_inv


def im_detect_bbox_scale(model, images, target_scale, target_max_size, device, hflip=False):
"""
Computes bbox detections at the given scale.
Returns predictions in the scaled image space.
"""
if hflip:
boxlists_scl = im_detect_bbox_hflip(model, images, target_scale, target_max_size, device)
else:
boxlists_scl = im_detect_bbox(model, images, target_scale, target_max_size, device)
return boxlists_scl
8 changes: 6 additions & 2 deletions maskrcnn_benchmark/engine/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
import torch
from tqdm import tqdm

from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.data.datasets.evaluation import evaluate
from ..utils.comm import is_main_process, get_world_size
from ..utils.comm import all_gather
from ..utils.comm import synchronize
from ..utils.timer import Timer, get_time_str
from .bbox_aug import im_detect_bbox_aug


def compute_on_dataset(model, data_loader, device, timer=None):
Expand All @@ -19,11 +21,13 @@ def compute_on_dataset(model, data_loader, device, timer=None):
cpu_device = torch.device("cpu")
for _, batch in enumerate(tqdm(data_loader)):
images, targets, image_ids = batch
images = images.to(device)
with torch.no_grad():
if timer:
timer.tic()
output = model(images)
if cfg.TEST.BBOX_AUG.ENABLED:
output = im_detect_bbox_aug(model, images, device)
else:
output = model(images.to(device))
if timer:
torch.cuda.synchronize()
timer.toc()
Expand Down
11 changes: 8 additions & 3 deletions maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def __init__(
nms=0.5,
detections_per_img=100,
box_coder=None,
cls_agnostic_bbox_reg=False
cls_agnostic_bbox_reg=False,
bbox_aug_enabled=False
):
"""
Arguments:
Expand All @@ -39,6 +40,7 @@ def __init__(
box_coder = BoxCoder(weights=(10., 10., 5., 5.))
self.box_coder = box_coder
self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg
self.bbox_aug_enabled = bbox_aug_enabled

def forward(self, x, boxes):
"""
Expand Down Expand Up @@ -79,7 +81,8 @@ def forward(self, x, boxes):
):
boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape)
boxlist = boxlist.clip_to_image(remove_empty=False)
boxlist = self.filter_results(boxlist, num_classes)
if not self.bbox_aug_enabled: # If bbox aug is enabled, we will do it later
boxlist = self.filter_results(boxlist, num_classes)
results.append(boxlist)
return results

Expand Down Expand Up @@ -156,12 +159,14 @@ def make_roi_box_post_processor(cfg):
nms_thresh = cfg.MODEL.ROI_HEADS.NMS
detections_per_img = cfg.MODEL.ROI_HEADS.DETECTIONS_PER_IMG
cls_agnostic_bbox_reg = cfg.MODEL.CLS_AGNOSTIC_BBOX_REG
bbox_aug_enabled = cfg.TEST.BBOX_AUG.ENABLED

postprocessor = PostProcessor(
score_thresh,
nms_thresh,
detections_per_img,
box_coder,
cls_agnostic_bbox_reg
cls_agnostic_bbox_reg,
bbox_aug_enabled
)
return postprocessor