Skip to content

Commit

Permalink
fixing bugs with diameter changing during training + denoising (#925)
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed May 19, 2024
1 parent 300b5bc commit a29fe7b
Showing 1 changed file with 30 additions and 29 deletions.
59 changes: 30 additions & 29 deletions cellpose/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -2178,26 +2178,25 @@ def train_model(self, restore=None, normalize_params=None):
weight_decay=self.training_params["weight_decay"],
n_epochs=self.training_params["n_epochs"],
model_name=self.training_params["model_name"])
diam_labels = self.model.diam_labels #.copy()
# run model on next image
io._add_model(self, self.new_model_path, load_model=False)
io._add_model(self, self.new_model_path)
diam_labels = self.model.net.diam_labels.item() #.copy()
self.new_model_ind = len(self.model_strings)
self.autorun = True
if self.autorun:
channels = self.channels.copy()
self.clear_all()
# keep same channels
self.ChannelChoose[0].setCurrentIndex(channels[0])
self.ChannelChoose[1].setCurrentIndex(channels[1])
self.diameter = diam_labels
self.Diameter.setText("%0.2f" % self.diameter)
self.logger.info(
f">>>> diameter set to diam_labels ( = {diam_labels: 0.3f} )")
self.restore = restore
self.set_normalize_params(normalize_params)
self.get_next_image(load_seg=True)
channels = self.channels.copy()
self.clear_all()
# keep same channels
self.ChannelChoose[0].setCurrentIndex(channels[0])
self.ChannelChoose[1].setCurrentIndex(channels[1])
self.diameter = diam_labels
self.Diameter.setText("%0.2f" % self.diameter)
self.logger.info(
f">>>> diameter set to diam_labels ( = {diam_labels: 0.3f} )")
self.restore = restore
self.set_normalize_params(normalize_params)
self.get_next_image(load_seg=True)

self.compute_segmentation(custom=True)
self.compute_segmentation(custom=True)
self.logger.info(
f"!!! computed masks for {os.path.split(self.filename)[1]} from new model !!!"
)
Expand Down Expand Up @@ -2265,7 +2264,7 @@ def compute_cprob(self):

def compute_denoise_model(self, model_type=None):
self.progress.setValue(0)
if 1:
try:
tic = time.time()
nstr = "cyto3" if self.DenoiseChoose.currentText(
) == "one-click" else "nuclei"
Expand Down Expand Up @@ -2314,6 +2313,7 @@ def compute_denoise_model(self, model_type=None):
else:
self.Lyr, self.Lxr = self.Ly, self.Lx
self.Ly0, self.Lx0 = self.Ly, self.Lx
diam_up = self.diameter

img_norm = self.denoise_model.eval(data, channels=channels, z_axis=0,
channel_axis=3, diameter=self.diameter,
Expand Down Expand Up @@ -2392,16 +2392,17 @@ def compute_denoise_model(self, model_type=None):

self.update_plot()

#except Exception as e:
# print("ERROR: %s"%e)
except Exception as e:
print("ERROR: %s"%e)

def compute_segmentation(self, custom=False, model_name=None):
def compute_segmentation(self, custom=False, model_name=None, load_model=True):
self.progress.setValue(0)
if 1:
try:
tic = time.time()
self.clear_all()
self.flows = [[], [], []]
self.initialize_model(model_name=model_name, custom=custom)
if load_model:
self.initialize_model(model_name=model_name, custom=custom)
self.progress.setValue(10)
do_3D = self.load_3D
stitch_threshold = float(self.stitch_threshold.text()) if not isinstance(
Expand All @@ -2419,17 +2420,17 @@ def compute_segmentation(self, custom=False, model_name=None):
niter = None if niter == 0 else niter
normalize_params = self.get_normalize_params()
print(normalize_params)
if 1:
try:
masks, flows = self.model.eval(
data, channels=channels, diameter=self.diameter,
cellprob_threshold=cellprob_threshold,
flow_threshold=flow_threshold, do_3D=do_3D, niter=niter,
normalize=normalize_params, stitch_threshold=stitch_threshold,
progress=self.progress)[:2]
# except Exception as e:
# print("NET ERROR: %s"%e)
# self.progress.setValue(0)
# return
except Exception as e:
print("NET ERROR: %s"%e)
self.progress.setValue(0)
return

self.progress.setValue(75)

Expand Down Expand Up @@ -2478,5 +2479,5 @@ def compute_segmentation(self, custom=False, model_name=None):
self.recompute_masks = True
else:
self.recompute_masks = False
# except Exception as e:
# print("ERROR: %s"%e)
except Exception as e:
print("ERROR: %s"%e)

0 comments on commit a29fe7b

Please sign in to comment.