Skip to content

Commit

Permalink
make pretrained backbone flexible in predictor (#1061)
Browse files (browse the repository at this point in the history)
  • Loading branch information
felixdittrich92 authored Sep 19, 2022
1 parent 4e763da commit 0aacb3c
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 4 deletions.
2 changes: 1 addition & 1 deletion doctr/models/detection/linknet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def _linknet(
pretrained: bool,
backbone_fn: Callable[[bool], nn.Module],
fpn_layers: List[str],
- pretrained_backbone: bool = False,
+ pretrained_backbone: bool = True,
**kwargs: Any,
) -> LinkNet:

Expand Down
8 changes: 7 additions & 1 deletion doctr/models/detection/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,20 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
f"{ROT_ARCHS}"
)

- _model = detection.__dict__[arch](pretrained=pretrained, assume_straight_pages=assume_straight_pages)
+ _model = detection.__dict__[arch](
+ pretrained=pretrained,
+ pretrained_backbone=kwargs.get("pretrained_backbone", True),
+ assume_straight_pages=assume_straight_pages,
+ )
else:
if not isinstance(arch, (detection.DBNet, detection.LinkNet)):
raise ValueError(f"unknown architecture: {type(arch)}")

_model = arch
_model.assume_straight_pages = assume_straight_pages

+ kwargs.pop("pretrained_backbone", None)

kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
kwargs["std"] = kwargs.get("std", _model.cfg["std"])
kwargs["batch_size"] = kwargs.get("batch_size", 1)
Expand Down
6 changes: 5 additions & 1 deletion doctr/models/recognition/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
if arch not in ARCHS:
raise ValueError(f"unknown architecture '{arch}'")

- _model = recognition.__dict__[arch](pretrained=pretrained)
+ _model = recognition.__dict__[arch](
+ pretrained=pretrained, pretrained_backbone=kwargs.get("pretrained_backbone", True)
+ )
else:
if not isinstance(arch, (recognition.CRNN, recognition.SAR, recognition.MASTER)):
raise ValueError(f"unknown architecture: {type(arch)}")
_model = arch

+ kwargs.pop("pretrained_backbone", None)

kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
kwargs["std"] = kwargs.get("std", _model.cfg["std"])
kwargs["batch_size"] = kwargs.get("batch_size", 32)
Expand Down
9 changes: 8 additions & 1 deletion doctr/models/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def _predictor(
det_arch: Any,
reco_arch: Any,
pretrained: bool,
+ pretrained_backbone: bool = True,
assume_straight_pages: bool = True,
preserve_aspect_ratio: bool = False,
symmetric_pad: bool = True,
Expand All @@ -30,14 +31,17 @@ def _predictor(
det_predictor = detection_predictor(
det_arch,
pretrained=pretrained,
+ pretrained_backbone=pretrained_backbone,
batch_size=det_bs,
assume_straight_pages=assume_straight_pages,
preserve_aspect_ratio=preserve_aspect_ratio,
symmetric_pad=symmetric_pad,
)

# Recognition
- reco_predictor = recognition_predictor(reco_arch, pretrained=pretrained, batch_size=reco_bs)
+ reco_predictor = recognition_predictor(
+ reco_arch, pretrained=pretrained, pretrained_backbone=pretrained_backbone, batch_size=reco_bs
+ )

return OCRPredictor(
det_predictor,
Expand All @@ -55,6 +59,7 @@ def ocr_predictor(
det_arch: Any = "db_resnet50",
reco_arch: Any = "crnn_vgg16_bn",
pretrained: bool = False,
+ pretrained_backbone: bool = True,
assume_straight_pages: bool = True,
preserve_aspect_ratio: bool = False,
symmetric_pad: bool = True,
Expand All @@ -77,6 +82,7 @@ def ocr_predictor(
reco_arch: name of the recognition architecture or the model itself to use
(e.g. 'crnn_vgg16_bn', 'sar_resnet31')
pretrained: If True, returns a model pre-trained on our OCR dataset
+ pretrained_backbone: If True, returns a model with a pretrained backbone
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
without rotated textual elements.
preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before
Expand All @@ -98,6 +104,7 @@ def ocr_predictor(
det_arch,
reco_arch,
pretrained,
+ pretrained_backbone=pretrained_backbone,
assume_straight_pages=assume_straight_pages,
preserve_aspect_ratio=preserve_aspect_ratio,
symmetric_pad=symmetric_pad,
Expand Down

0 comments on commit 0aacb3c

Please sign in to comment.