Skip to content

Commit

Permalink
Fix yolov3 and vision_maskrcnn in fbcode
Browse files Browse the repository at this point in the history
Summary: As the title says.

Reviewed By: aaronenyeshi

Differential Revision: D57071254

fbshipit-source-id: 502761181997910abe64631678f4e7eab8122555
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed May 8, 2024
1 parent 97cf069 commit d6b44d2
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 1 deletion.
6 changes: 6 additions & 0 deletions torchbenchmark/models/sam/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,19 @@ def __init__(self, test, device, batch_size=1, extra_args=[]):
# Checkpoint options are here https://github.com/facebookresearch/segment-anything#model-checkpoints
data_folder = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".data")
sam_checkpoint = os.path.join(data_folder, "sam_vit_h_4b8939.pth")
if not os.path.exists(sam_checkpoint):
from torchbenchmark.util.framework.fb.installer import install_model_weights
sam_checkpoint = install_model_weights(self.name)
model_type = "vit_h"

self.model = sam_model_registry[model_type](checkpoint=sam_checkpoint)
self.model.to(device=device)
data_folder = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".data")

image_path = os.path.join(data_folder, "truck.jpg")
if not os.path.exists(image_path):
from torchbenchmark.util.framework.fb.installer import install_data
image_path = os.path.join(install_data("truck"), "truck.jpg")
self.image = cv2.imread(image_path)
self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
self.sample_image = torch.randn((3, 256, 256)).to(device)
Expand Down
3 changes: 3 additions & 0 deletions torchbenchmark/models/torch_multimodal_clip/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ def __init__(self, test, device, batch_size=1, extra_args=[]):
self.data_folder = os.path.join(
os.path.dirname(os.path.abspath(__file__)), ".data"
)
if not os.path.exists(self.data_folder):
from torchbenchmark.util.framework.fb.installer import install_data
self.data_folder = install_data("pizza")
self.image_name = "pizza.jpg"
self.image = Image.open(os.path.join(self.data_folder, self.image_name))
self.text = ["pizza", "dog"] * 16
Expand Down
6 changes: 6 additions & 0 deletions torchbenchmark/models/vision_maskrcnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@

CURRENT_DIR = Path(os.path.dirname(os.path.realpath(__file__)))
DATA_DIR = os.path.join(CURRENT_DIR.parent.parent, "data", ".data", "coco2017-minimal")
if not os.path.exists(DATA_DIR):
try:
from torchbenchmark.util.framework.fb.installer import install_data
DATA_DIR = os.path.join(install_data("coco2017-minimal"), "coco2017-minimal")
except Exception:
pass
assert os.path.exists(
DATA_DIR
), "Couldn't find coco2017 minimal data dir, please run install.py again."
Expand Down
8 changes: 7 additions & 1 deletion torchbenchmark/models/yolov3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,15 @@

CURRENT_DIR = Path(os.path.dirname(os.path.realpath(__file__)))
DATA_DIR = os.path.join(CURRENT_DIR.parent.parent, "data", ".data", "coco128")
if not os.path.exists(DATA_DIR):
try:
from torchbenchmark.util.framework.fb.installer import install_data
DATA_DIR = os.path.join(install_data("coco128"), "coco128")
except Exception:
pass
assert os.path.exists(
DATA_DIR
), "Couldn't find coco128 data dir, please run install.py again."
), f"Couldn't find coco128 data dir: {DATA_DIR}, please run install.py again."


class Model(BenchmarkModel):
Expand Down
3 changes: 3 additions & 0 deletions torchbenchmark/models/yolov3/yolo_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def get_train(hyp):
data_dict = parse_data_cfg(data)
train_path = os.path.dirname(__file__) + "/" + data_dict["train"]
test_path = os.path.dirname(__file__) + "/" + data_dict["valid"]
if not os.path.exists(train_path):
train_path = os.path.dirname(data) + "/" + "coco128.txt"
test_path = os.path.dirname(data) + "/" + "coco128.txt"
print(train_path)
nc = 1 if opt.single_cls else int(data_dict["classes"]) # number of classes
hyp["cls"] *= nc / 80 # update coco-tuned hyp['cls'] to current dataset
Expand Down

0 comments on commit d6b44d2

Please sign in to comment.