Skip to content

Commit

Permalink
Enable detectron2 models in Meta internal
Browse files Browse the repository at this point in the history
Summary: As the title says

Reviewed By: aaronenyeshi

Differential Revision: D57054332

fbshipit-source-id: 64a7ab15464b1e8a2da93ece2e9dd0b90ed1d5a3
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed May 7, 2024
1 parent b0a2ff5 commit 97cf069
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
12 changes: 12 additions & 0 deletions torchbenchmark/models/detectron2_maskrcnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
# setup environment variable
CURRENT_DIR = Path(os.path.dirname(os.path.realpath(__file__)))
DATA_DIR = os.path.join(CURRENT_DIR.parent.parent, "data", ".data", "coco2017-minimal")
if not os.path.exists(DATA_DIR):
try:
from torchbenchmark.util.framework.fb.installer import install_data
DATA_DIR = install_data("coco2017-minimal")
except Exception:
pass
assert os.path.exists(
DATA_DIR
), "Couldn't find coco2017 minimal data dir, please run install.py again."
Expand Down Expand Up @@ -79,6 +85,12 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):
data_cfg.test.batch_size = self.batch_size
self.model = instantiate(model_cfg).to(self.device)
# load model from checkpoint
if not os.path.exists(self.model_file):
try:
from torchbenchmark.util.framework.fb.installer import install_model_weights
self.model_file = install_model_weights(self.name)
except Exception:
pass
DetectionCheckpointer(self.model).load(self.model_file)
self.model.eval()
test_loader = instantiate(data_cfg.test)
Expand Down
12 changes: 12 additions & 0 deletions torchbenchmark/util/framework/detectron2/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
DATA_DIR = os.path.join(
CURRENT_DIR.parent.parent.parent, "data", ".data", "coco2017-minimal"
)
if not os.path.exists(DATA_DIR):
try:
from torchbenchmark.util.framework.fb.installer import install_data
DATA_DIR = install_data("coco2017-minimal")
except Exception:
pass
assert os.path.exists(
DATA_DIR
), "Couldn't find coco2017 minimal data dir, please run install.py again."
Expand Down Expand Up @@ -99,6 +105,12 @@ def __init__(self, variant, test, device, batch_size=None, extra_args=[]):
assert hasattr(
self, "model_file"
), f"Detectron2 models must specify its model_file."
if self.model_file and not os.path.exists(self.model_file):
try:
from torchbenchmark.util.framework.fb.installer import install_model_weights
self.model_file = install_model_weights(self.name)
except Exception:
pass
if self.model_file:
assert os.path.exists(
self.model_file
Expand Down

0 comments on commit 97cf069

Please sign in to comment.