Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move deformable detr safe loading code #1055

Merged
merged 6 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 4 additions & 14 deletions lib/sycamore/sycamore/transforms/detr_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Any, BinaryIO, Literal, Union, Optional
from pathlib import Path
from itertools import repeat

import requests
import json
from tenacity import retry, retry_if_exception, wait_exponential, stop_after_delay
import base64
from PIL import Image
import fasteners
from pypdf import PdfReader

from sycamore.data import Element, BoundingBox, ImageElement, TableElement
Expand All @@ -34,7 +32,6 @@
from sycamore.transforms.text_extraction.pdf_miner import PdfMinerExtractor

logger = logging.getLogger(__name__)
_DETR_LOCK_FILE = f"{Path.home()}/.cache/Aryn-Detr.lock"
_VERSION = "0.2024.07.24"


Expand Down Expand Up @@ -683,18 +680,11 @@ def __init__(self, model_name_or_path, device=None, cache: Optional[Cache] = Non
self._model_name_or_path = model_name_or_path
self.cache = cache

from sycamore.utils.pytorch_dir import get_pytorch_build_directory
from transformers import AutoImageProcessor
from sycamore.utils.model_load import load_deformable_detr

with fasteners.InterProcessLock(_DETR_LOCK_FILE):
lockfile = Path(get_pytorch_build_directory("MultiScaleDeformableAttention", False)) / "lock"
lockfile.unlink(missing_ok=True)

from transformers import AutoImageProcessor, DeformableDetrForObjectDetection

LogTime("loading_model", point=True)
with LogTime("load_model", log_start=True):
self.processor = AutoImageProcessor.from_pretrained(model_name_or_path)
self.model = DeformableDetrForObjectDetection.from_pretrained(model_name_or_path).to(self._get_device())
self.processor = AutoImageProcessor.from_pretrained(model_name_or_path)
self.model = load_deformable_detr(model_name_or_path, self._get_device())

# Note: We wrap this in a function so that we can execute on both the leader and the workers
# to account for heterogeneous systems. Currently, if you pass in an explicit device parameter
Expand Down
4 changes: 2 additions & 2 deletions lib/sycamore/sycamore/transforms/table_structure/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,9 @@ def __init__(self, model: str, device=None):
super().__init__(model, device)

def _init_structure_model(self):
from transformers import DeformableDetrForObjectDetection
from sycamore.utils.model_load import load_deformable_detr

self.structure_model = DeformableDetrForObjectDetection.from_pretrained(self.model).to(self._get_device())
self.structure_model = load_deformable_detr(self.model, self._get_device())

def extract(
self, element: TableElement, doc_image: Image.Image, union_tokens=False, apply_thresholds=True
Expand Down
33 changes: 33 additions & 0 deletions lib/sycamore/sycamore/utils/model_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from sycamore.utils.import_utils import requires_modules
from sycamore.utils.time_trace import LogTime
import fasteners
from pathlib import Path

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from transformers import DeformableDetrForObjectDetection

_DETR_LOCK_FILE = f"{Path.home()}/.cache/Aryn-Detr.lock"


@requires_modules("transformers", "local_inference")
def load_deformable_detr(model_name_or_path, device) -> "DeformableDetrForObjectDetection":
"""Load deformable detr without getting concurrency issues in
jit-ing the deformable attention kernel.
HenryL27 marked this conversation as resolved.
Show resolved Hide resolved

Refactored out of:
https://github.com/aryn-ai/sycamore/blob/7e6b62639ce9b8f63d56cb35a32837d1c97e711e/lib/sycamore/sycamore/transforms/detr_partitioner.py#L686
"""
from sycamore.utils.pytorch_dir import get_pytorch_build_directory

with fasteners.InterProcessLock(_DETR_LOCK_FILE):
lockfile = Path(get_pytorch_build_directory("MultiScaleDeformableAttention", False)) / "lock"
lockfile.unlink(missing_ok=True)

from transformers import DeformableDetrForObjectDetection

LogTime("loading_model", point=True)
with LogTime("loading_model", log_start=True):
model = DeformableDetrForObjectDetection.from_pretrained(model_name_or_path).to(device)
return model
Loading