diff --git a/references/classification/train.py b/references/classification/train.py
index eb8b56c1ad0..8b632aed556 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -230,7 +230,7 @@ def main(args):
     criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
 
     if args.norm_weight_decay is None:
-        parameters = model.parameters()
+        parameters = [p for p in model.parameters() if p.requires_grad]
     else:
         param_groups = torchvision.ops._utils.split_normalization_params(model)
         wd_groups = [args.norm_weight_decay, args.weight_decay]
diff --git a/references/detection/presets.py b/references/detection/presets.py
index eb1fdfc4c42..779f3f218ca 100644
--- a/references/detection/presets.py
+++ b/references/detection/presets.py
@@ -3,7 +3,7 @@
 
 
 class DetectionPresetTrain:
-    def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
+    def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
         if data_augmentation == "hflip":
             self.transforms = T.Compose(
                 [
@@ -12,6 +12,27 @@ def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0
                     T.ConvertImageDtype(torch.float),
                 ]
             )
+        elif data_augmentation == "lsj":
+            self.transforms = T.Compose(
+                [
+                    T.ScaleJitter(target_size=(1024, 1024)),
+                    T.FixedSizeCrop(size=(1024, 1024), fill=mean),
+                    T.RandomHorizontalFlip(p=hflip_prob),
+                    T.PILToTensor(),
+                    T.ConvertImageDtype(torch.float),
+                ]
+            )
+        elif data_augmentation == "multiscale":
+            self.transforms = T.Compose(
+                [
+                    T.RandomShortestSize(
+                        min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
+                    ),
+                    T.RandomHorizontalFlip(p=hflip_prob),
+                    T.PILToTensor(),
+                    T.ConvertImageDtype(torch.float),
+                ]
+            )
         elif data_augmentation == "ssd":
             self.transforms = T.Compose(
                 [
diff --git a/references/detection/train.py b/references/detection/train.py
index b6634061503..758171013e8 100644
--- a/references/detection/train.py
+++ b/references/detection/train.py
@@ -68,6 +68,7 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         "-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 4)"
     )
+    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
     parser.add_argument(
         "--lr",
         default=0.02,
@@ -84,6 +85,12 @@ def get_args_parser(add_help=True):
         help="weight decay (default: 1e-4)",
         dest="weight_decay",
     )
+    parser.add_argument(
+        "--norm-weight-decay",
+        default=None,
+        type=float,
+        help="weight decay for Normalization layers (default: None, same value as --wd)",
+    )
     parser.add_argument(
         "--lr-scheduler", default="multisteplr", type=str, help="name of lr scheduler (default: multisteplr)"
     )
@@ -176,6 +183,8 @@ def main(args):
 
     print("Creating model")
     kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers}
+    if args.data_augmentation in ["multiscale", "lsj"]:
+        kwargs["_skip_resize"] = True
     if "rcnn" in args.model:
         if args.rpn_score_thresh is not None:
             kwargs["rpn_score_thresh"] = args.rpn_score_thresh
@@ -191,8 +200,26 @@ def main(args):
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
         model_without_ddp = model.module
 
-    params = [p for p in model.parameters() if p.requires_grad]
-    optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+    if args.norm_weight_decay is None:
+        parameters = [p for p in model.parameters() if p.requires_grad]
+    else:
+        param_groups = torchvision.ops._utils.split_normalization_params(model)
+        wd_groups = [args.norm_weight_decay, args.weight_decay]
+        parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
+
+    opt_name = args.opt.lower()
+    if opt_name.startswith("sgd"):
+        optimizer = torch.optim.SGD(
+            parameters,
+            lr=args.lr,
+            momentum=args.momentum,
+            weight_decay=args.weight_decay,
+            nesterov="nesterov" in opt_name,
+        )
+    elif opt_name == "adamw":
+        optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
+    else:
+        raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD and AdamW are supported.")
 
     scaler = torch.cuda.amp.GradScaler() if args.amp else None
diff --git a/test/test_extended_models.py b/test/test_extended_models.py
index a07b501e15b..577be1d2cd6 100644
--- a/test/test_extended_models.py
+++ b/test/test_extended_models.py
@@ -64,7 +64,6 @@ def test_get_weight(name, weight):
 )
 def test_naming_conventions(model_fn):
     weights_enum = _get_model_weights(model_fn)
-    print(weights_enum)
     assert weights_enum is not None
     assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT")
diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py
index 244fdfa4e7d..37e32a830e0 100644
--- a/torchvision/models/detection/faster_rcnn.py
+++ b/torchvision/models/detection/faster_rcnn.py
@@ -187,6 +187,7 @@ def __init__(
         box_batch_size_per_image=512,
         box_positive_fraction=0.25,
         bbox_reg_weights=None,
+        **kwargs,
     ):
 
         if not hasattr(backbone, "out_channels"):
@@ -268,7 +269,7 @@ def __init__(
             image_mean = [0.485, 0.456, 0.406]
         if image_std is None:
             image_std = [0.229, 0.224, 0.225]
-        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
+        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
 
         super().__init__(backbone, rpn, roi_heads, transform)
diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py
index bb9dbe65ae7..ae3f4db0cba 100644
--- a/torchvision/models/detection/fcos.py
+++ b/torchvision/models/detection/fcos.py
@@ -373,6 +373,7 @@ def __init__(
         nms_thresh: float = 0.6,
         detections_per_img: int = 100,
         topk_candidates: int = 1000,
+        **kwargs,
     ):
         super().__init__()
         _log_api_usage_once(self)
@@ -410,7 +411,7 @@ def __init__(
             image_mean = [0.485, 0.456, 0.406]
         if image_std is None:
             image_std = [0.229, 0.224, 0.225]
-        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
+        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
 
         self.center_sampling_radius = center_sampling_radius
         self.score_thresh = score_thresh
diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py
index dc03c693e1c..c7df4910009 100644
--- a/torchvision/models/detection/keypoint_rcnn.py
+++ b/torchvision/models/detection/keypoint_rcnn.py
@@ -198,6 +198,7 @@ def __init__(
         keypoint_head=None,
         keypoint_predictor=None,
         num_keypoints=None,
+        **kwargs,
     ):
 
         if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))):
@@ -259,6 +260,7 @@ def __init__(
             box_batch_size_per_image,
             box_positive_fraction,
             bbox_reg_weights,
+            **kwargs,
         )
 
         self.roi_heads.keypoint_roi_pool = keypoint_roi_pool
diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py
index a6cb731c0df..d46cd721513 100644
--- a/torchvision/models/detection/mask_rcnn.py
+++ b/torchvision/models/detection/mask_rcnn.py
@@ -195,6 +195,7 @@ def __init__(
         mask_roi_pool=None,
         mask_head=None,
         mask_predictor=None,
+        **kwargs,
     ):
 
         if not isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None))):
@@ -254,6 +255,7 @@ def __init__(
             box_batch_size_per_image,
             box_positive_fraction,
             bbox_reg_weights,
+            **kwargs,
         )
 
         self.roi_heads.mask_roi_pool = mask_roi_pool
diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py
index c8b0de661f0..910defec80c 100644
--- a/torchvision/models/detection/retinanet.py
+++ b/torchvision/models/detection/retinanet.py
@@ -342,6 +342,7 @@ def __init__(
         fg_iou_thresh=0.5,
         bg_iou_thresh=0.4,
         topk_candidates=1000,
+        **kwargs,
     ):
         super().__init__()
         _log_api_usage_once(self)
@@ -383,7 +384,7 @@ def __init__(
             image_mean = [0.485, 0.456, 0.406]
         if image_std is None:
             image_std = [0.229, 0.224, 0.225]
-        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
+        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
 
         self.score_thresh = score_thresh
         self.nms_thresh = nms_thresh
diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py
index 4f9d7546c2b..537371fdc27 100644
--- a/torchvision/models/detection/ssd.py
+++ b/torchvision/models/detection/ssd.py
@@ -195,6 +195,7 @@ def __init__(
         iou_thresh: float = 0.5,
         topk_candidates: int = 400,
         positive_fraction: float = 0.25,
+        **kwargs: Any,
     ):
         super().__init__()
         _log_api_usage_once(self)
@@ -227,7 +228,7 @@ def __init__(
         if image_std is None:
             image_std = [0.229, 0.224, 0.225]
         self.transform = GeneralizedRCNNTransform(
-            min(size), max(size), image_mean, image_std, size_divisible=1, fixed_size=size
+            min(size), max(size), image_mean, image_std, size_divisible=1, fixed_size=size, **kwargs
         )
 
         self.score_thresh = score_thresh
diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py
index ac902ac0fd6..4f653a86acd 100644
--- a/torchvision/models/detection/transform.py
+++ b/torchvision/models/detection/transform.py
@@ -1,5 +1,5 @@
 import math
-from typing import List, Tuple, Dict, Optional
+from typing import List, Tuple, Dict, Optional, Any
 
 import torch
 import torchvision
@@ -91,6 +91,7 @@ def __init__(
         image_std: List[float],
         size_divisible: int = 32,
         fixed_size: Optional[Tuple[int, int]] = None,
+        **kwargs: Any,
     ):
         super().__init__()
         if not isinstance(min_size, (list, tuple)):
@@ -101,6 +102,7 @@ def __init__(
         self.image_std = image_std
         self.size_divisible = size_divisible
         self.fixed_size = fixed_size
+        self._skip_resize = kwargs.pop("_skip_resize", False)
 
     def forward(
         self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
@@ -170,6 +172,8 @@ def resize(
     ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
         h, w = image.shape[-2:]
         if self.training:
+            if self._skip_resize:
+                return image, target
             size = float(self.torch_choice(self.min_size))
         else:
             # FIXME assume for now that testing uses the largest scale
diff --git a/torchvision/ops/_utils.py b/torchvision/ops/_utils.py
index 107785266a1..8a02490ab13 100644
--- a/torchvision/ops/_utils.py
+++ b/torchvision/ops/_utils.py
@@ -43,7 +43,13 @@ def split_normalization_params(
 ) -> Tuple[List[Tensor], List[Tensor]]:
     # Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
     if not norm_classes:
-        norm_classes = [nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm]
+        norm_classes = [
+            nn.modules.batchnorm._BatchNorm,
+            nn.LayerNorm,
+            nn.GroupNorm,
+            nn.modules.instancenorm._InstanceNorm,
+            nn.LocalResponseNorm,
+        ]
     for t in norm_classes:
         if not issubclass(t, nn.Module):
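
Note: a minimal sketch of what the `_skip_resize` plumbing above does once it reaches `GeneralizedRCNNTransform` (the image size below is arbitrary; normalization and batching still run, only the training-time resize is skipped):

import torch
from torchvision.models.detection.transform import GeneralizedRCNNTransform

# With _skip_resize=True, resize() becomes a no-op in training mode, so
# presets such as ScaleJitter + FixedSizeCrop ("lsj") or
# RandomShortestSize ("multiscale") keep control of the final image size.
transform = GeneralizedRCNNTransform(
    min_size=800,
    max_size=1333,
    image_mean=[0.485, 0.456, 0.406],
    image_std=[0.229, 0.224, 0.225],
    _skip_resize=True,
)
transform.train()

# A 1024x1024 crop stays 1024x1024; without _skip_resize it would be
# rescaled so that its shorter side matches min_size (800 here).
images, _ = transform([torch.rand(3, 1024, 1024)])
print(images.tensors.shape)  # torch.Size([1, 3, 1024, 1024])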
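
And a sketch of the `--norm-weight-decay` path, mirroring the optimizer setup added to references/detection/train.py above (resnet18 and the hyperparameter values are placeholders, not part of the patch):

import torch
import torchvision

# split_normalization_params now also catches _InstanceNorm and
# LocalResponseNorm layers; each split gets its own weight-decay group.
model = torchvision.models.resnet18()
norm_params, other_params = torchvision.ops._utils.split_normalization_params(model)

wd_groups = [0.0, 1e-4]  # [norm_weight_decay, weight_decay], example values
parameters = [
    {"params": p, "weight_decay": w}
    for p, w in zip([norm_params, other_params], wd_groups)
    if p
]
optimizer = torch.optim.SGD(parameters, lr=0.02, momentum=0.9, weight_decay=1e-4)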