From 0aacb3c5aee3b52c1fe1542cee55ca3e590bf4c7 Mon Sep 17 00:00:00 2001 From: Felix Dittrich Date: Mon, 19 Sep 2022 08:52:04 +0200 Subject: [PATCH] make pretrained backbone flexible in predictor (#1061) --- doctr/models/detection/linknet/pytorch.py | 2 +- doctr/models/detection/zoo.py | 8 +++++++- doctr/models/recognition/zoo.py | 6 +++++- doctr/models/zoo.py | 9 ++++++++- 4 files changed, 21 insertions(+), 4 deletions(-) diff --git a/doctr/models/detection/linknet/pytorch.py b/doctr/models/detection/linknet/pytorch.py index f01fc3947d..d430a98224 100644 --- a/doctr/models/detection/linknet/pytorch.py +++ b/doctr/models/detection/linknet/pytorch.py @@ -241,7 +241,7 @@ def _linknet( pretrained: bool, backbone_fn: Callable[[bool], nn.Module], fpn_layers: List[str], - pretrained_backbone: bool = False, + pretrained_backbone: bool = True, **kwargs: Any, ) -> LinkNet: diff --git a/doctr/models/detection/zoo.py b/doctr/models/detection/zoo.py index 3578efa573..b76b39fac8 100644 --- a/doctr/models/detection/zoo.py +++ b/doctr/models/detection/zoo.py @@ -46,7 +46,11 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, f"{ROT_ARCHS}" ) - _model = detection.__dict__[arch](pretrained=pretrained, assume_straight_pages=assume_straight_pages) + _model = detection.__dict__[arch]( + pretrained=pretrained, + pretrained_backbone=kwargs.get("pretrained_backbone", True), + assume_straight_pages=assume_straight_pages, + ) else: if not isinstance(arch, (detection.DBNet, detection.LinkNet)): raise ValueError(f"unknown architecture: {type(arch)}") @@ -54,6 +58,8 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, _model = arch _model.assume_straight_pages = assume_straight_pages + kwargs.pop("pretrained_backbone", None) + kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) kwargs["std"] = kwargs.get("std", _model.cfg["std"]) kwargs["batch_size"] = kwargs.get("batch_size", 1) diff --git a/doctr/models/recognition/zoo.py b/doctr/models/recognition/zoo.py index bf16eb31d6..f387e6ee7c 100644 --- a/doctr/models/recognition/zoo.py +++ b/doctr/models/recognition/zoo.py @@ -23,12 +23,16 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict if arch not in ARCHS: raise ValueError(f"unknown architecture '{arch}'") - _model = recognition.__dict__[arch](pretrained=pretrained) + _model = recognition.__dict__[arch]( + pretrained=pretrained, pretrained_backbone=kwargs.get("pretrained_backbone", True) + ) else: if not isinstance(arch, (recognition.CRNN, recognition.SAR, recognition.MASTER)): raise ValueError(f"unknown architecture: {type(arch)}") _model = arch + kwargs.pop("pretrained_backbone", None) + kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) kwargs["std"] = kwargs.get("std", _model.cfg["std"]) kwargs["batch_size"] = kwargs.get("batch_size", 32) diff --git a/doctr/models/zoo.py b/doctr/models/zoo.py index 3dca466b4f..a9621df4e3 100644 --- a/doctr/models/zoo.py +++ b/doctr/models/zoo.py @@ -16,6 +16,7 @@ def _predictor( det_arch: Any, reco_arch: Any, pretrained: bool, + pretrained_backbone: bool = True, assume_straight_pages: bool = True, preserve_aspect_ratio: bool = False, symmetric_pad: bool = True, @@ -30,6 +31,7 @@ def _predictor( det_predictor = detection_predictor( det_arch, pretrained=pretrained, + pretrained_backbone=pretrained_backbone, batch_size=det_bs, assume_straight_pages=assume_straight_pages, preserve_aspect_ratio=preserve_aspect_ratio, @@ -37,7 +39,9 @@ def _predictor( ) # Recognition - reco_predictor = recognition_predictor(reco_arch, pretrained=pretrained, batch_size=reco_bs) + reco_predictor = recognition_predictor( + reco_arch, pretrained=pretrained, pretrained_backbone=pretrained_backbone, batch_size=reco_bs + ) return OCRPredictor( det_predictor, @@ -55,6 +59,7 @@ def ocr_predictor( det_arch: Any = "db_resnet50", reco_arch: Any = "crnn_vgg16_bn", pretrained: bool = False, + pretrained_backbone: bool = True, assume_straight_pages: bool = True, preserve_aspect_ratio: bool = False, symmetric_pad: bool = True, @@ -77,6 +82,7 @@ def ocr_predictor( reco_arch: name of the recognition architecture or the model itself to use (e.g. 'crnn_vgg16_bn', 'sar_resnet31') pretrained: If True, returns a model pre-trained on our OCR dataset + pretrained_backbone: If True, returns a model with a pretrained backbone assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages without rotated textual elements. preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before @@ -98,6 +104,7 @@ def ocr_predictor( det_arch, reco_arch, pretrained, + pretrained_backbone=pretrained_backbone, assume_straight_pages=assume_straight_pages, preserve_aspect_ratio=preserve_aspect_ratio, symmetric_pad=symmetric_pad,