From 9ce2fd0fe288934cecfdaa314f1c1a83255a613f Mon Sep 17 00:00:00 2001 From: hcx1231 <32597852+hcx1231@users.noreply.github.com> Date: Wed, 29 May 2019 03:56:36 +0800 Subject: [PATCH] add vertical flip (#818) * keep the resize function the same in test time the same with training time * add vertical flip --- maskrcnn_benchmark/config/defaults.py | 1 + maskrcnn_benchmark/data/transforms/build.py | 9 ++++++--- maskrcnn_benchmark/data/transforms/transforms.py | 9 +++++++++ 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/maskrcnn_benchmark/config/defaults.py b/maskrcnn_benchmark/config/defaults.py index beae4070a..65fbdaddd 100644 --- a/maskrcnn_benchmark/config/defaults.py +++ b/maskrcnn_benchmark/config/defaults.py @@ -60,6 +60,7 @@ _C.INPUT.SATURATION = 0.0 _C.INPUT.HUE = 0.0 +_C.INPUT.VERTICAL_FLIP_PROB_TRAIN = 0.0 # ----------------------------------------------------------------------------- # Dataset diff --git a/maskrcnn_benchmark/data/transforms/build.py b/maskrcnn_benchmark/data/transforms/build.py index 88aa975b6..52385ea7d 100644 --- a/maskrcnn_benchmark/data/transforms/build.py +++ b/maskrcnn_benchmark/data/transforms/build.py @@ -6,7 +6,8 @@ def build_transforms(cfg, is_train=True): if is_train: min_size = cfg.INPUT.MIN_SIZE_TRAIN max_size = cfg.INPUT.MAX_SIZE_TRAIN - flip_prob = 0.5 # cfg.INPUT.FLIP_PROB_TRAIN + flip_horizontal_prob = 0.5 # cfg.INPUT.FLIP_PROB_TRAIN + flip_vertical_prob = cfg.INPUT.VERTICAL_FLIP_PROB_TRAIN brightness = cfg.INPUT.BRIGHTNESS contrast = cfg.INPUT.CONTRAST saturation = cfg.INPUT.SATURATION @@ -14,7 +15,8 @@ def build_transforms(cfg, is_train=True): else: min_size = cfg.INPUT.MIN_SIZE_TEST max_size = cfg.INPUT.MAX_SIZE_TEST - flip_prob = 0 + flip_horizontal_prob = 0.0 + flip_vertical_prob = 0.0 brightness = 0.0 contrast = 0.0 saturation = 0.0 @@ -35,7 +37,8 @@ def build_transforms(cfg, is_train=True): [ color_jitter, T.Resize(min_size, max_size), - T.RandomHorizontalFlip(flip_prob), + T.RandomHorizontalFlip(flip_horizontal_prob), + T.RandomVerticalFlip(flip_vertical_prob), T.ToTensor(), normalize_transform, ] diff --git a/maskrcnn_benchmark/data/transforms/transforms.py b/maskrcnn_benchmark/data/transforms/transforms.py index fa1d93934..2d37dc72f 100644 --- a/maskrcnn_benchmark/data/transforms/transforms.py +++ b/maskrcnn_benchmark/data/transforms/transforms.py @@ -73,6 +73,15 @@ def __call__(self, image, target): target = target.transpose(0) return image, target +class RandomVerticalFlip(object): + def __init__(self, prob=0.5): + self.prob = prob + + def __call__(self, image, target): + if random.random() < self.prob: + image = F.vflip(image) + target = target.transpose(1) + return image, target class ColorJitter(object): def __init__(self,