Skip to content

Commit

Permalink
change type check
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Nov 21, 2024
1 parent 096a3a9 commit 7af1757
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 7 deletions.
5 changes: 2 additions & 3 deletions doctr/models/classification/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,9 @@ def _orientation_predictor(
else:
allowed_archs = [classification.MobileNetV3]
if is_torch_available():
# The following is required for torch compiled models
import torch
from doctr.models.utils import _get_torch_compile_type

allowed_archs.append(torch._dynamo.eval_frame.OptimizedModule)
allowed_archs.append(_get_torch_compile_type())

if not isinstance(arch, tuple(allowed_archs)):
raise ValueError(f"unknown architecture: {type(arch)}")
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/detection/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
allowed_archs = [detection.DBNet, detection.LinkNet, detection.FAST]
if is_torch_available():
# The following is required for torch compiled models
import torch
from doctr.models.utils import _get_torch_compile_type

allowed_archs.append(torch._dynamo.eval_frame.OptimizedModule)
allowed_archs.append(_get_torch_compile_type())

if not isinstance(arch, tuple(allowed_archs)):
raise ValueError(f"unknown architecture: {type(arch)}")
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/recognition/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
allowed_archs = [recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq]
if is_torch_available():
# The following is required for torch compiled models
import torch
from doctr.models.utils import _get_torch_compile_type

allowed_archs.append(torch._dynamo.eval_frame.OptimizedModule)
allowed_archs.append(_get_torch_compile_type())

if not isinstance(arch, tuple(allowed_archs)):
raise ValueError(f"unknown architecture: {type(arch)}")
Expand Down
5 changes: 5 additions & 0 deletions doctr/models/utils/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,14 @@
"export_model_to_onnx",
"_copy_tensor",
"_bf16_to_float32",
"_get_torch_compile_type",
]


def _get_torch_compile_type() -> Any:
return torch._dynamo.eval_frame.OptimizedModule


def _copy_tensor(x: torch.Tensor) -> torch.Tensor:
return x.clone().detach()

Expand Down

0 comments on commit 7af1757

Please sign in to comment.