From 4d09cc3516977dc0327e7869659cd13e8f877379 Mon Sep 17 00:00:00 2001
From: dmitrypuzyrev
Date: Wed, 25 Sep 2024 22:51:40 +0200
Subject: [PATCH] Loading model in GUI from torch.hub
---
RodTracker/src/RodTracker/ui/detection.py | 69 +++++++++++++++++------
hubconf.py | 9 +--
2 files changed, 53 insertions(+), 25 deletions(-)
diff --git a/RodTracker/src/RodTracker/ui/detection.py b/RodTracker/src/RodTracker/ui/detection.py
index 2845f8f..ced9f4d 100644
--- a/RodTracker/src/RodTracker/ui/detection.py
+++ b/RodTracker/src/RodTracker/ui/detection.py
@@ -18,6 +18,7 @@
import logging
import os
+import pathlib
import urllib.request
from typing import Dict, List
@@ -253,18 +254,19 @@ def __init__(
self.pb_use_example.clicked.connect(self._use_example_model)
def _use_example_model(self):
- example_model_file = CONFIG_DIR / "example_model.pt"
+ model_file_name = "model_cpu.pt"
+ hub_dir = torch.hub.get_dir()
+ model_dir = os.path.join(hub_dir, "checkpoints")
+ example_model_file = pathlib.Path(
+ os.path.join(model_dir, model_file_name)
+ )
_logger.info(example_model_file)
+
example_model_url = (
"https://zenodo.org/records/10255525/files/model_cpu.pt?download=1"
)
- if not example_model_file.exists():
- file_MB = int(
- urllib.request.urlopen(example_model_url)
- .info()
- .get("Content-Length")
- ) / (1024**2)
+ if not example_model_file.exists():
msg_confirm_download = QtWidgets.QMessageBox(self.ui)
msg_confirm_download.setWindowTitle(APPNAME)
msg_confirm_download.setIcon(QtWidgets.QMessageBox.Information)
@@ -273,13 +275,10 @@ def _use_example_model(self):
Attempting to download a trained Mask-RCNN model file
for detection of rods in the example data.
The model is called model_cpu.pt and it will be
- downloaded from here:
-
- https://zenodo.org/records/10255525
+ downloaded from torch.hub
The file will be downloaded to
- {example_model_file}
- and will occupy ≈{file_MB:.01f} MB.
+ {example_model_file}
"""
)
msg_confirm_download.setStandardButtons(
@@ -303,10 +302,10 @@ def _use_example_model(self):
msg_box.setWindowTitle(APPNAME)
worker = pl.Worker(
- lambda: torch.hub.download_url_to_file(
- example_model_url,
- str(example_model_file.resolve()),
- progress=False,
+ lambda: torch.hub.load(
+ "ANP-Granular/ParticleTracking:develop",
+ "rods_example_model",
+ pretrained=True,
)
)
worker.signals.result.connect(lambda ret: msg_box.close())
@@ -317,7 +316,34 @@ def _use_example_model(self):
self._threads.start(worker)
msg_box.exec()
else:
- self._load_model(str(example_model_file.resolve()))
+ _logger.info("Attempting to load the example model from cache.")
+ msg_box = QtWidgets.QMessageBox(
+ icon=QtWidgets.QMessageBox.Information,
+ text=(
+ "Loading the example model file from cache:
"
+ "{example_model_file}"
+ "
Please wait until this window closes."
+ ),
+ parent=self.ui,
+ )
+ msg_box.setStandardButtons(QtWidgets.QMessageBox.Close)
+ msg_box.button(QtWidgets.QMessageBox.Close).setEnabled(False)
+ msg_box.setWindowTitle(APPNAME)
+
+ worker = pl.Worker(
+ lambda: torch.hub.load(
+ "ANP-Granular/ParticleTracking:develop",
+ "rods_example_model",
+ pretrained=True,
+ )
+ )
+ worker.signals.result.connect(lambda ret: msg_box.close())
+ worker.signals.result.connect(
+ lambda ret: self._load_model(str(example_model_file.resolve()))
+ )
+
+ self._threads.start(worker)
+ msg_box.exec()
def _use_example_model_from_zenodo(self):
example_model_file = CONFIG_DIR / "example_model.pt"
@@ -569,6 +595,15 @@ def _load_model(self, file: str):
self.model = torch.jit.load(file)
self.pb_detect.setEnabled(True)
+ def _load_example_model_from_hub(self, file: str):
+ self.le_model.setText(file)
+ self.model = torch.hub.load(
+ "ANP-Granular/ParticleTracking:develop",
+ "rods_example_model",
+ pretrained=True,
+ )
+ self.pb_detect.setEnabled(True)
+
def load_model(self):
"""Show a file selection dialog to a user to select a particle
detection model.
diff --git a/hubconf.py b/hubconf.py
index 9ca1489..3cfddca 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -2,25 +2,18 @@
dependencies = ["torch", "os"]
import torch
-import torchvision
import os
example_model_url = "https://zenodo.org/records/10255525/files/model_cpu.pt?download=1"
-model_file_name = "model_file.pt"
+model_file_name = "model_cpu.pt"
def rods_example_model(pretrained: bool = False, progress: bool = False):
- model: torch.ScriptModule = None
if pretrained:
- # state_dict = torch.hub.load_state_dict_from_url(
- # example_model_url, progress=progress
- # )
hub_dir = torch.hub.get_dir()
model_dir = os.path.join(hub_dir, "checkpoints")
os.makedirs(model_dir, exist_ok=True)
model_file = os.path.join(model_dir, model_file_name)
torch.hub.download_url_to_file(example_model_url, model_file, progress=progress)
- model = torch.jit.load(model_file)
- return model