From 4d09cc3516977dc0327e7869659cd13e8f877379 Mon Sep 17 00:00:00 2001 From: dmitrypuzyrev Date: Wed, 25 Sep 2024 22:51:40 +0200 Subject: [PATCH] Loading model in GUI from torch.hub --- RodTracker/src/RodTracker/ui/detection.py | 69 +++++++++++++++++------ hubconf.py | 9 +-- 2 files changed, 53 insertions(+), 25 deletions(-) diff --git a/RodTracker/src/RodTracker/ui/detection.py b/RodTracker/src/RodTracker/ui/detection.py index 2845f8f..ced9f4d 100644 --- a/RodTracker/src/RodTracker/ui/detection.py +++ b/RodTracker/src/RodTracker/ui/detection.py @@ -18,6 +18,7 @@ import logging import os +import pathlib import urllib.request from typing import Dict, List @@ -253,18 +254,19 @@ def __init__( self.pb_use_example.clicked.connect(self._use_example_model) def _use_example_model(self): - example_model_file = CONFIG_DIR / "example_model.pt" + model_file_name = "model_cpu.pt" + hub_dir = torch.hub.get_dir() + model_dir = os.path.join(hub_dir, "checkpoints") + example_model_file = pathlib.Path( + os.path.join(model_dir, model_file_name) + ) _logger.info(example_model_file) + example_model_url = ( "https://zenodo.org/records/10255525/files/model_cpu.pt?download=1" ) - if not example_model_file.exists(): - file_MB = int( - urllib.request.urlopen(example_model_url) - .info() - .get("Content-Length") - ) / (1024**2) + if not example_model_file.exists(): msg_confirm_download = QtWidgets.QMessageBox(self.ui) msg_confirm_download.setWindowTitle(APPNAME) msg_confirm_download.setIcon(QtWidgets.QMessageBox.Information) @@ -273,13 +275,10 @@ def _use_example_model(self):

Attempting to download a trained Mask-RCNN model file for detection of rods in the example data. The model is called model_cpu.pt and it will be - downloaded from here:
- - https://zenodo.org/records/10255525

+ downloaded from torch.hub

The file will be downloaded to
- {example_model_file}
- and will occupy ≈{file_MB:.01f} MB.

+ {example_model_file}

""" ) msg_confirm_download.setStandardButtons( @@ -303,10 +302,10 @@ def _use_example_model(self): msg_box.setWindowTitle(APPNAME) worker = pl.Worker( - lambda: torch.hub.download_url_to_file( - example_model_url, - str(example_model_file.resolve()), - progress=False, + lambda: torch.hub.load( + "ANP-Granular/ParticleTracking:develop", + "rods_example_model", + pretrained=True, ) ) worker.signals.result.connect(lambda ret: msg_box.close()) @@ -317,7 +316,34 @@ def _use_example_model(self): self._threads.start(worker) msg_box.exec() else: - self._load_model(str(example_model_file.resolve())) + _logger.info("Attempting to load the example model from cache.") + msg_box = QtWidgets.QMessageBox( + icon=QtWidgets.QMessageBox.Information, + text=( + "Loading the example model file from cache:
" + "{example_model_file}" + "

Please wait until this window closes." + ), + parent=self.ui, + ) + msg_box.setStandardButtons(QtWidgets.QMessageBox.Close) + msg_box.button(QtWidgets.QMessageBox.Close).setEnabled(False) + msg_box.setWindowTitle(APPNAME) + + worker = pl.Worker( + lambda: torch.hub.load( + "ANP-Granular/ParticleTracking:develop", + "rods_example_model", + pretrained=True, + ) + ) + worker.signals.result.connect(lambda ret: msg_box.close()) + worker.signals.result.connect( + lambda ret: self._load_model(str(example_model_file.resolve())) + ) + + self._threads.start(worker) + msg_box.exec() def _use_example_model_from_zenodo(self): example_model_file = CONFIG_DIR / "example_model.pt" @@ -569,6 +595,15 @@ def _load_model(self, file: str): self.model = torch.jit.load(file) self.pb_detect.setEnabled(True) + def _load_example_model_from_hub(self, file: str): + self.le_model.setText(file) + self.model = torch.hub.load( + "ANP-Granular/ParticleTracking:develop", + "rods_example_model", + pretrained=True, + ) + self.pb_detect.setEnabled(True) + def load_model(self): """Show a file selection dialog to a user to select a particle detection model. diff --git a/hubconf.py b/hubconf.py index 9ca1489..3cfddca 100644 --- a/hubconf.py +++ b/hubconf.py @@ -2,25 +2,18 @@ dependencies = ["torch", "os"] import torch -import torchvision import os example_model_url = "https://zenodo.org/records/10255525/files/model_cpu.pt?download=1" -model_file_name = "model_file.pt" +model_file_name = "model_cpu.pt" def rods_example_model(pretrained: bool = False, progress: bool = False): - model: torch.ScriptModule = None if pretrained: - # state_dict = torch.hub.load_state_dict_from_url( - # example_model_url, progress=progress - # ) hub_dir = torch.hub.get_dir() model_dir = os.path.join(hub_dir, "checkpoints") os.makedirs(model_dir, exist_ok=True) model_file = os.path.join(model_dir, model_file_name) torch.hub.download_url_to_file(example_model_url, model_file, progress=progress) - model = torch.jit.load(model_file) - return model