-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
move deformable detr safe loading code (#1055)
* factor around deformable detr loading/lockfile management for use with deformable table extractor Signed-off-by: Henry Lindeman <hmlindeman@yahoo.com> * remove unused global variable Signed-off-by: Henry Lindeman <hmlindeman@yahoo.com> * move .to(device) iniside the lock Signed-off-by: Henry Lindeman <hmlindeman@yahoo.com> * jitpick Signed-off-by: Henry Lindeman <hmlindeman@yahoo.com> * set deformable table extractor choose_device detr=True Signed-off-by: Henry Lindeman <hmlindeman@yahoo.com> * Misc table transformers post-processing (#1077) * misc postprocessing tweaks Signed-off-by: Henry Lindeman <hmlindeman@yahoo.com> * typo Signed-off-by: Henry Lindeman <hmlindeman@yahoo.com> --------- Signed-off-by: Henry Lindeman <hmlindeman@yahoo.com> --------- Signed-off-by: Henry Lindeman <hmlindeman@yahoo.com>
- Loading branch information
Showing
4 changed files
with
79 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from sycamore.utils.import_utils import requires_modules | ||
from sycamore.utils.time_trace import LogTime | ||
import fasteners | ||
from pathlib import Path | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
if TYPE_CHECKING: | ||
from transformers import DeformableDetrForObjectDetection | ||
|
||
_DETR_LOCK_FILE = f"{Path.home()}/.cache/Aryn-Detr.lock" | ||
|
||
|
||
@requires_modules("transformers", "local_inference") | ||
def load_deformable_detr(model_name_or_path, device) -> "DeformableDetrForObjectDetection": | ||
"""Load deformable detr without getting concurrency issues in | ||
jitc-ing the deformable attention kernel. | ||
Refactored out of: | ||
https://github.com/aryn-ai/sycamore/blob/7e6b62639ce9b8f63d56cb35a32837d1c97e711e/lib/sycamore/sycamore/transforms/detr_partitioner.py#L686 | ||
""" | ||
from sycamore.utils.pytorch_dir import get_pytorch_build_directory | ||
|
||
with fasteners.InterProcessLock(_DETR_LOCK_FILE): | ||
lockfile = Path(get_pytorch_build_directory("MultiScaleDeformableAttention", False)) / "lock" | ||
lockfile.unlink(missing_ok=True) | ||
|
||
from transformers import DeformableDetrForObjectDetection | ||
|
||
LogTime("loading_model", point=True) | ||
with LogTime("loading_model", log_start=True): | ||
model = DeformableDetrForObjectDetection.from_pretrained(model_name_or_path).to(device) | ||
return model |