-
Notifications
You must be signed in to change notification settings - Fork 6
/
example-run-pytorch-model-in-fiji.py
94 lines (73 loc) · 3.3 KB
/
example-run-pytorch-model-in-fiji.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Copyright (C) 2023 Institut Pasteur.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Jython script that downloads the wanted Pytorch model from the Bioimage.io repository,
downloads the engine and executes it on the sample image.
The example model downloaded is:
- Mitochondria resolution enhancement Wasserstein GAN
and can be found at: https://bioimage.io/#/?type=all&tags=Mitochondria%20resolution%20enhancement%20Wasserstein%20GAN&id=10.5281%2Fzenodo.7786492
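This script depends on Fiji/ImageJ (``ij``) and on the JDLL classes
(``io.bioimage.modelrunner``), so it cannot be run with a plain Python interpreter.
To run it with the default parameters, open it in Fiji's Script Editor, choose
Python (Jython) as the language and press Run. The model and the engine are
downloaded into the "models" and "engines" folders of the current working directory.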
"""
from io.bioimage.modelrunner.engine.installation import EngineInstall
from io.bioimage.modelrunner.bioimageio import BioimageioRepo
from io.bioimage.modelrunner.model import Model
from io.bioimage.modelrunner.tensor import Tensor
from io.bioimage.modelrunner.versionmanagement import AvailableEngines
import sys
import os
from ij import IJ
from net.imglib2.img.display.imagej import ImageJFunctions
from net.imglib2.view import Views
# Work folders are resolved against the current working directory: the
# downloaded model goes to ./models and the JDLL engine to ./engines.
models_path = os.path.join(os.getcwd(), "models")
engine_path = os.path.join(os.getcwd(), "engines")
# Name used to look the model up in the Bioimage.io repository.
bmzModelName = "Mitochondria resolution enhancement Wasserstein GAN"
# os.makedirs has no exist_ok flag in Jython (Python 2.7), so only create the
# folder when it is not already a directory (isdir implies exists).
if not os.path.isdir(models_path):
    os.makedirs(models_path)
# --- Fetch the model from Bioimage.io ---------------------------------------
print("Connecting to the Bioimage.io repository")
br = BioimageioRepo.connect()

# The model is requested by its display name and saved under models_path;
# the returned value is used later as the model's location on disk.
print("Downloading the Bioimage.io model: %s" % (bmzModelName,))
model_fn = br.downloadByName(bmzModelName, models_path)
print("Model downloaded at: " + model_fn)
# --- Install the Deep Learning engine required by the model ----------------
# Framework and version are defined once so the engine lookup and the
# installation below cannot drift apart.
engine_framework = "pytorch"
engine_version = "1.13.1"

print("Download the engine required for the model")
# os.makedirs has no exist_ok flag in Jython (Python 2.7), so check first.
if not os.path.isdir(engine_path):
    os.makedirs(engine_path)

print("Installing JDLL engine")
# NOTE(review): the True / None arguments presumably mean "CPU supported" and
# "any GPU support"; confirm against the JDLL AvailableEngines API.
supportedList = AvailableEngines.getEnginesForOsByParams(
    engine_framework, engine_version, True, None)
# Indexing an empty result would fail with an opaque IndexError.
if supportedList is None or len(supportedList) == 0:
    raise RuntimeError("No compatible JDLL engine found for "
                       + engine_framework + " " + engine_version + ".")
# Install with the GPU setting reported by the first matching engine.
gpu = supportedList[0].getGPU()
success = EngineInstall.installEngineWithArgsInDir(
    engine_framework, engine_version, True, gpu, engine_path)
if success:
    print("Engine correctly installed at: " + engine_path)
else:
    # `Error` is not a Python builtin; the original raised NameError here.
    raise RuntimeError("Error with the engine installation.")
# --- Open the model's sample image and build the input/output tensors -----
sample_path = os.path.join(model_fn, "sample_input_0.tif")
imp = IJ.openImage(sample_path)
if imp is None:
    # IJ.openImage returns null when the file cannot be opened; fail with a
    # clear message instead of an AttributeError on imp.show().
    raise RuntimeError("Could not open the sample image: " + sample_path)
imp.show()

# Reshape the image into the "bcxy" layout declared for the tensors below.
# Starting from an (x, y) image (assumes a single-channel 2D sample - TODO
# confirm for other models):
#   addDimension(.., 0, 0) twice -> (x, y, 1, 1)
#   permute(0, 2)                -> (1, y, x, 1)
#   permute(1, 3)                -> (1, 1, x, y)
wrapImg = ImageJFunctions.convertFloat(imp)
wrapImg = Views.addDimension(wrapImg, 0, 0)
wrapImg = Views.addDimension(wrapImg, 0, 0)
wrapImg = Views.permute(wrapImg, 0, 2)
wrapImg = Views.permute(wrapImg, 1, 3)
inputTensor = Tensor.build("input", "bcxy", wrapImg)
# Created without data; it is passed to model.runModel() and its data is
# read back after inference.
outputTensor = Tensor.buildEmptyTensor("output", "bcxy")
# --- Load the model, run inference and show the result ---------------------
# try/finally guarantees the model and both tensors are released even when
# model creation, loading or inference raises.
try:
    model = Model.createBioimageioModel(model_fn, engine_path)
    try:
        print("Loading model")
        model.loadModel()
        print("Running model")
        model.runModel([inputTensor], [outputTensor])
        print("Display output")
        # Drop size-1 axes so ImageJ displays a plain image.
        ImageJFunctions.show(Views.dropSingletonDimensions(outputTensor.getData()))
    finally:
        model.closeModel()
finally:
    # NOTE(review): show() may display a view backed by the tensor's data;
    # confirm that closing outputTensor does not invalidate the shown image.
    inputTensor.close()
    outputTensor.close()