Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make pretrained backbone flexible in predictor #1061

Merged
merged 3 commits into from
Sep 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doctr/models/detection/linknet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def _linknet(
pretrained: bool,
backbone_fn: Callable[[bool], nn.Module],
fpn_layers: List[str],
pretrained_backbone: bool = False,
pretrained_backbone: bool = True,
**kwargs: Any,
) -> LinkNet:

Expand Down
8 changes: 7 additions & 1 deletion doctr/models/detection/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,20 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
f"{ROT_ARCHS}"
)

_model = detection.__dict__[arch](pretrained=pretrained, assume_straight_pages=assume_straight_pages)
_model = detection.__dict__[arch](
pretrained=pretrained,
pretrained_backbone=kwargs.get("pretrained_backbone", True),
assume_straight_pages=assume_straight_pages,
)
else:
if not isinstance(arch, (detection.DBNet, detection.LinkNet)):
raise ValueError(f"unknown architecture: {type(arch)}")

_model = arch
_model.assume_straight_pages = assume_straight_pages

kwargs.pop("pretrained_backbone", None)
felixdittrich92 marked this conversation as resolved.
Show resolved Hide resolved

kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
kwargs["std"] = kwargs.get("std", _model.cfg["std"])
kwargs["batch_size"] = kwargs.get("batch_size", 1)
Expand Down
6 changes: 5 additions & 1 deletion doctr/models/recognition/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
if arch not in ARCHS:
raise ValueError(f"unknown architecture '{arch}'")

_model = recognition.__dict__[arch](pretrained=pretrained)
_model = recognition.__dict__[arch](
pretrained=pretrained, pretrained_backbone=kwargs.get("pretrained_backbone", True)
)
else:
if not isinstance(arch, (recognition.CRNN, recognition.SAR, recognition.MASTER)):
raise ValueError(f"unknown architecture: {type(arch)}")
_model = arch

kwargs.pop("pretrained_backbone", None)

kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
kwargs["std"] = kwargs.get("std", _model.cfg["std"])
kwargs["batch_size"] = kwargs.get("batch_size", 32)
Expand Down
9 changes: 8 additions & 1 deletion doctr/models/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def _predictor(
det_arch: Any,
reco_arch: Any,
pretrained: bool,
pretrained_backbone: bool = True,
assume_straight_pages: bool = True,
preserve_aspect_ratio: bool = False,
symmetric_pad: bool = True,
Expand All @@ -30,14 +31,17 @@ def _predictor(
det_predictor = detection_predictor(
det_arch,
pretrained=pretrained,
pretrained_backbone=pretrained_backbone,
batch_size=det_bs,
assume_straight_pages=assume_straight_pages,
preserve_aspect_ratio=preserve_aspect_ratio,
symmetric_pad=symmetric_pad,
)

# Recognition
reco_predictor = recognition_predictor(reco_arch, pretrained=pretrained, batch_size=reco_bs)
reco_predictor = recognition_predictor(
reco_arch, pretrained=pretrained, pretrained_backbone=pretrained_backbone, batch_size=reco_bs
)

return OCRPredictor(
det_predictor,
Expand All @@ -55,6 +59,7 @@ def ocr_predictor(
det_arch: Any = "db_resnet50",
reco_arch: Any = "crnn_vgg16_bn",
pretrained: bool = False,
pretrained_backbone: bool = True,
assume_straight_pages: bool = True,
preserve_aspect_ratio: bool = False,
symmetric_pad: bool = True,
Expand All @@ -77,6 +82,7 @@ def ocr_predictor(
reco_arch: name of the recognition architecture or the model itself to use
(e.g. 'crnn_vgg16_bn', 'sar_resnet31')
pretrained: If True, returns a model pre-trained on our OCR dataset
pretrained_backbone: If True, returns a model with a pretrained backbone
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
without rotated textual elements.
preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before
Expand All @@ -98,6 +104,7 @@ def ocr_predictor(
det_arch,
reco_arch,
pretrained,
pretrained_backbone=pretrained_backbone,
assume_straight_pages=assume_straight_pages,
preserve_aspect_ratio=preserve_aspect_ratio,
symmetric_pad=symmetric_pad,
Expand Down