diff --git a/torchbenchmark/models/detectron2_maskrcnn/__init__.py b/torchbenchmark/models/detectron2_maskrcnn/__init__.py index 4c97f306ad..206eeeaa46 100644 --- a/torchbenchmark/models/detectron2_maskrcnn/__init__.py +++ b/torchbenchmark/models/detectron2_maskrcnn/__init__.py @@ -17,6 +17,12 @@ # setup environment variable CURRENT_DIR = Path(os.path.dirname(os.path.realpath(__file__))) DATA_DIR = os.path.join(CURRENT_DIR.parent.parent, "data", ".data", "coco2017-minimal") +if not os.path.exists(DATA_DIR): + try: + from torchbenchmark.util.framework.fb.installer import install_data + DATA_DIR = install_data("coco2017-minimal") + except Exception: + pass assert os.path.exists( DATA_DIR ), "Couldn't find coco2017 minimal data dir, please run install.py again." @@ -79,6 +85,12 @@ def __init__(self, test, device, batch_size=None, extra_args=[]): data_cfg.test.batch_size = self.batch_size self.model = instantiate(model_cfg).to(self.device) # load model from checkpoint + if not os.path.exists(self.model_file): + try: + from torchbenchmark.util.framework.fb.installer import install_model_weights + self.model_file = install_model_weights(self.name) + except Exception: + pass DetectionCheckpointer(self.model).load(self.model_file) self.model.eval() test_loader = instantiate(data_cfg.test) diff --git a/torchbenchmark/util/framework/detectron2/model_factory.py b/torchbenchmark/util/framework/detectron2/model_factory.py index d08d6e7f7d..0eece4bb53 100644 --- a/torchbenchmark/util/framework/detectron2/model_factory.py +++ b/torchbenchmark/util/framework/detectron2/model_factory.py @@ -11,6 +11,12 @@ DATA_DIR = os.path.join( CURRENT_DIR.parent.parent.parent, "data", ".data", "coco2017-minimal" ) +if not os.path.exists(DATA_DIR): + try: + from torchbenchmark.util.framework.fb.installer import install_data + DATA_DIR = install_data("coco2017-minimal") + except Exception: + pass assert os.path.exists( DATA_DIR ), "Couldn't find coco2017 minimal data dir, please run install.py again." @@ -99,6 +105,12 @@ def __init__(self, variant, test, device, batch_size=None, extra_args=[]): assert hasattr( self, "model_file" ), f"Detectron2 models must specify its model_file." + if self.model_file and not os.path.exists(self.model_file): + try: + from torchbenchmark.util.framework.fb.installer import install_model_weights + self.model_file = install_model_weights(self.name) + except Exception: + pass if self.model_file: assert os.path.exists( self.model_file