Skip to content

Commit

Permalink
move deformable detr safe loading code (#1055)
Browse files Browse the repository at this point in the history
* factor around deformable detr loading/lockfile management for use with deformable table extractor

Signed-off-by: Henry Lindeman <hmlindeman@yahoo.com>

* remove unused global variable

Signed-off-by: Henry Lindeman <hmlindeman@yahoo.com>

* move .to(device) iniside the lock

Signed-off-by: Henry Lindeman <hmlindeman@yahoo.com>

* jitpick

Signed-off-by: Henry Lindeman <hmlindeman@yahoo.com>

* set deformable table extractor choose_device detr=True

Signed-off-by: Henry Lindeman <hmlindeman@yahoo.com>

* Misc table transformers post-processing (#1077)

* misc postprocessing tweaks

Signed-off-by: Henry Lindeman <hmlindeman@yahoo.com>

* typo

Signed-off-by: Henry Lindeman <hmlindeman@yahoo.com>

---------

Signed-off-by: Henry Lindeman <hmlindeman@yahoo.com>

---------

Signed-off-by: Henry Lindeman <hmlindeman@yahoo.com>
  • Loading branch information
HenryL27 authored Dec 18, 2024
1 parent e4c213e commit 0f92e7a
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 24 deletions.
18 changes: 4 additions & 14 deletions lib/sycamore/sycamore/transforms/detr_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Any, BinaryIO, Literal, Union, Optional
from pathlib import Path
from itertools import repeat

import requests
import json
from tenacity import retry, retry_if_exception, wait_exponential, stop_after_delay
import base64
from PIL import Image
import fasteners
from pypdf import PdfReader

from sycamore.data import Element, BoundingBox, ImageElement, TableElement
Expand All @@ -34,7 +32,6 @@
from sycamore.transforms.text_extraction.pdf_miner import PdfMinerExtractor

logger = logging.getLogger(__name__)
_DETR_LOCK_FILE = f"{Path.home()}/.cache/Aryn-Detr.lock"
_VERSION = "0.2024.07.24"


Expand Down Expand Up @@ -688,18 +685,11 @@ def __init__(self, model_name_or_path, device=None, cache: Optional[Cache] = Non
self._model_name_or_path = model_name_or_path
self.cache = cache

from sycamore.utils.pytorch_dir import get_pytorch_build_directory
from transformers import AutoImageProcessor
from sycamore.utils.model_load import load_deformable_detr

with fasteners.InterProcessLock(_DETR_LOCK_FILE):
lockfile = Path(get_pytorch_build_directory("MultiScaleDeformableAttention", False)) / "lock"
lockfile.unlink(missing_ok=True)

from transformers import AutoImageProcessor, DeformableDetrForObjectDetection

LogTime("loading_model", point=True)
with LogTime("load_model", log_start=True):
self.processor = AutoImageProcessor.from_pretrained(model_name_or_path)
self.model = DeformableDetrForObjectDetection.from_pretrained(model_name_or_path).to(self._get_device())
self.processor = AutoImageProcessor.from_pretrained(model_name_or_path)
self.model = load_deformable_detr(model_name_or_path, self._get_device())

# Note: We wrap this in a function so that we can execute on both the leader and the workers
# to account for heterogeneous systems. Currently, if you pass in an explicit device parameter
Expand Down
7 changes: 5 additions & 2 deletions lib/sycamore/sycamore/transforms/table_structure/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,12 @@ def __init__(self, model: str, device=None):
super().__init__(model, device)

def _init_structure_model(self):
from transformers import DeformableDetrForObjectDetection
from sycamore.utils.model_load import load_deformable_detr

self.structure_model = DeformableDetrForObjectDetection.from_pretrained(self.model).to(self._get_device())
self.structure_model = load_deformable_detr(self.model, self._get_device())

def _get_device(self) -> str:
return choose_device(self.device, detr=True)

def extract(
self, element: TableElement, doc_image: Image.Image, union_tokens=False, apply_thresholds=True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,23 @@ def apply_class_thresholds(bboxes, labels, scores, class_names, class_thresholds
return bboxes, scores, labels


def apply_class_thresholds_or_take_best(bboxes, labels, scores, class_names, class_thresholds, epsilon=0.05):
"""
Filter out bounding boxes whose confidence is below the confidence threshold for its
associated class threshold, defining the threshold as whichever is lower between what
is written in the class_thresholds dict and the highest score for the class minus epsilon
"""
new_class_thresholds = {k: v for k, v in class_thresholds.items()}
max_row_score = max(sc for (sc, lbl) in zip(scores, labels) if class_names[lbl] == "table row")
max_col_score = max(sc for (sc, lbl) in zip(scores, labels) if class_names[lbl] == "table column")
if max_row_score - epsilon < class_thresholds["table row"]:
new_class_thresholds["table row"] = max_row_score - epsilon
if max_col_score - epsilon < class_thresholds["table column"]:
new_class_thresholds["table column"] = max_col_score - epsilon
new_class_thresholds["table"] = 0.0
return apply_class_thresholds(bboxes, labels, scores, class_names, new_class_thresholds)


def iob(coords1, coords2) -> float:
return BoundingBox(*coords1).iob(BoundingBox(*coords2))

Expand Down Expand Up @@ -83,7 +100,7 @@ def outputs_to_objects(outputs, img_size, id2label, apply_thresholds: bool = Fal
pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]

if apply_thresholds:
pred_bboxes, pred_scores, pred_labels = apply_class_thresholds(
pred_bboxes, pred_scores, pred_labels = apply_class_thresholds_or_take_best(
pred_bboxes, pred_labels, pred_scores, id2label, DEFAULT_STRUCTURE_CLASS_THRESHOLDS
)

Expand Down Expand Up @@ -287,20 +304,32 @@ def slot_into_containers(
# If the container starts after the package ends, break
if not _early_exit_vertical and container["bbox"][0] > package["bbox"][2]:
if len(match_scores) == 0:
match_scores.append({"container": container, "container_num": container_num, "score": 0})
match_scores.append(
{"container": container, "container_num": container_num, "score": 0, "score_2": 0}
)
break
elif _early_exit_vertical and container["bbox"][1] > package["bbox"][3]:
if len(match_scores) == 0:
match_scores.append({"container": container, "container_num": container_num, "score": 0})
match_scores.append(
{"container": container, "container_num": container_num, "score": 0, "score_2": 0}
)
break
container_rect = BoundingBox(*container["bbox"])
intersect_area = container_rect.intersect(package_rect).area
overlap_fraction = intersect_area / package_area
match_scores.append({"container": container, "container_num": container_num, "score": overlap_fraction})
opposite_overlap_fraction = intersect_area / (container_rect.area or 1)
match_scores.append(
{
"container": container,
"container_num": container_num,
"score": overlap_fraction,
"score_2": opposite_overlap_fraction,
}
)

# Don't sort if you don't have to
if unique_assignment:
sorted_match_scores = [max(match_scores, key=lambda x: x["score"])]
sorted_match_scores = [max(match_scores, key=lambda x: (x["score"], x["score_2"]))]
else:
sorted_match_scores = sort_objects_by_score(match_scores)

Expand Down Expand Up @@ -330,7 +359,7 @@ def sort_objects_by_score(objects, reverse=True):
sign = -1
else:
sign = 1
return sorted(objects, key=lambda k: sign * k["score"])
return sorted(objects, key=lambda k: (sign * k["score"], sign * k.get("score_2", 0)))


def remove_objects_without_content(page_spans, objects):
Expand Down Expand Up @@ -921,10 +950,10 @@ def objects_to_structures(objects, tokens, class_thresholds):
if len(tables) == 0:
return {}
if len(tables) > 1:
tables.sort(key=lambda x: x["score"], reverse=True)
tables.sort(key=lambda x: BoundingBox(*x["bbox"]).area, reverse=True)
import logging

logging.warning("Got multiple tables in document. Using only the highest-scoring one")
logging.warning("Got multiple tables in document. Using only the biggest one")

table = tables[0]
structure = {}
Expand Down
33 changes: 33 additions & 0 deletions lib/sycamore/sycamore/utils/model_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from sycamore.utils.import_utils import requires_modules
from sycamore.utils.time_trace import LogTime
import fasteners
from pathlib import Path

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from transformers import DeformableDetrForObjectDetection

_DETR_LOCK_FILE = f"{Path.home()}/.cache/Aryn-Detr.lock"


@requires_modules("transformers", "local_inference")
def load_deformable_detr(model_name_or_path, device) -> "DeformableDetrForObjectDetection":
"""Load deformable detr without getting concurrency issues in
jitc-ing the deformable attention kernel.
Refactored out of:
https://github.com/aryn-ai/sycamore/blob/7e6b62639ce9b8f63d56cb35a32837d1c97e711e/lib/sycamore/sycamore/transforms/detr_partitioner.py#L686
"""
from sycamore.utils.pytorch_dir import get_pytorch_build_directory

with fasteners.InterProcessLock(_DETR_LOCK_FILE):
lockfile = Path(get_pytorch_build_directory("MultiScaleDeformableAttention", False)) / "lock"
lockfile.unlink(missing_ok=True)

from transformers import DeformableDetrForObjectDetection

LogTime("loading_model", point=True)
with LogTime("loading_model", log_start=True):
model = DeformableDetrForObjectDetection.from_pretrained(model_name_or_path).to(device)
return model

0 comments on commit 0f92e7a

Please sign in to comment.