Skip to content

Commit

Permalink
Loading model in GUI from torch.hub
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitrypuzyrev committed Sep 25, 2024
1 parent bb0e23f commit 4d09cc3
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 25 deletions.
69 changes: 52 additions & 17 deletions RodTracker/src/RodTracker/ui/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import logging
import os
import pathlib
import urllib.request
from typing import Dict, List

Expand Down Expand Up @@ -253,18 +254,19 @@ def __init__(
self.pb_use_example.clicked.connect(self._use_example_model)

def _use_example_model(self):
example_model_file = CONFIG_DIR / "example_model.pt"
model_file_name = "model_cpu.pt"
hub_dir = torch.hub.get_dir()
model_dir = os.path.join(hub_dir, "checkpoints")
example_model_file = pathlib.Path(
os.path.join(model_dir, model_file_name)
)
_logger.info(example_model_file)

example_model_url = (
"https://zenodo.org/records/10255525/files/model_cpu.pt?download=1"
)
if not example_model_file.exists():
file_MB = int(
urllib.request.urlopen(example_model_url)
.info()
.get("Content-Length")
) / (1024**2)

if not example_model_file.exists():
msg_confirm_download = QtWidgets.QMessageBox(self.ui)
msg_confirm_download.setWindowTitle(APPNAME)
msg_confirm_download.setIcon(QtWidgets.QMessageBox.Information)
Expand All @@ -273,13 +275,10 @@ def _use_example_model(self):
<p>Attempting to download a trained Mask-RCNN model file
for detection of rods in the example data.
The model is called <b>model_cpu.pt</b> and it will be
downloaded from here:<br>
<a href="https://zenodo.org/records/10255525">
https://zenodo.org/records/10255525</a> </p>
downloaded from torch.hub </p>
<p>The file will be downloaded to <br>
<b>{example_model_file}</b><br>
and will occupy <b>≈{file_MB:.01f} MB</b>.</p>
<b>{example_model_file}</b> </p>
"""
)
msg_confirm_download.setStandardButtons(
Expand All @@ -303,10 +302,10 @@ def _use_example_model(self):
msg_box.setWindowTitle(APPNAME)

worker = pl.Worker(
lambda: torch.hub.download_url_to_file(
example_model_url,
str(example_model_file.resolve()),
progress=False,
lambda: torch.hub.load(
"ANP-Granular/ParticleTracking:develop",
"rods_example_model",
pretrained=True,
)
)
worker.signals.result.connect(lambda ret: msg_box.close())
Expand All @@ -317,7 +316,34 @@ def _use_example_model(self):
self._threads.start(worker)
msg_box.exec()
else:
self._load_model(str(example_model_file.resolve()))
_logger.info("Attempting to load the example model from cache.")
msg_box = QtWidgets.QMessageBox(
icon=QtWidgets.QMessageBox.Information,
text=(
"Loading the example model file from cache: <br>"
"<b>{example_model_file}</b>"
"<br><br><b>Please wait until this window closes.</b>"
),
parent=self.ui,
)
msg_box.setStandardButtons(QtWidgets.QMessageBox.Close)
msg_box.button(QtWidgets.QMessageBox.Close).setEnabled(False)
msg_box.setWindowTitle(APPNAME)

worker = pl.Worker(
lambda: torch.hub.load(
"ANP-Granular/ParticleTracking:develop",
"rods_example_model",
pretrained=True,
)
)
worker.signals.result.connect(lambda ret: msg_box.close())
worker.signals.result.connect(
lambda ret: self._load_model(str(example_model_file.resolve()))
)

self._threads.start(worker)
msg_box.exec()

def _use_example_model_from_zenodo(self):
example_model_file = CONFIG_DIR / "example_model.pt"
Expand Down Expand Up @@ -569,6 +595,15 @@ def _load_model(self, file: str):
self.model = torch.jit.load(file)
self.pb_detect.setEnabled(True)

def _load_example_model_from_hub(self, file: str):
self.le_model.setText(file)
self.model = torch.hub.load(
"ANP-Granular/ParticleTracking:develop",
"rods_example_model",
pretrained=True,
)
self.pb_detect.setEnabled(True)

def load_model(self):
"""Show a file selection dialog to a user to select a particle
detection model.
Expand Down
9 changes: 1 addition & 8 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,18 @@
dependencies = ["torch", "os"]

import torch
import torchvision
import os

example_model_url = "https://zenodo.org/records/10255525/files/model_cpu.pt?download=1"
model_file_name = "model_file.pt"
model_file_name = "model_cpu.pt"


def rods_example_model(pretrained: bool = False, progress: bool = False):

model: torch.ScriptModule = None
if pretrained:
# state_dict = torch.hub.load_state_dict_from_url(
# example_model_url, progress=progress
# )
hub_dir = torch.hub.get_dir()
model_dir = os.path.join(hub_dir, "checkpoints")
os.makedirs(model_dir, exist_ok=True)
model_file = os.path.join(model_dir, model_file_name)

torch.hub.download_url_to_file(example_model_url, model_file, progress=progress)
model = torch.jit.load(model_file)
return model

0 comments on commit 4d09cc3

Please sign in to comment.