Skip to content

Commit

Permalink
bug w multi-chan and fix warning from pytorch
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Sep 7, 2024
1 parent 79bd598 commit 33d1d48
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 8 deletions.
7 changes: 4 additions & 3 deletions cellpose/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def one_chan_cellpose(device, model_type="cyto2", pretrained_model=None):
net1 = resnet_torch.CPnet([nchan, *nbase], nout=3, sz=3).to(device)
filename = model_path(model_type,
0) if pretrained_model is None else pretrained_model
weights = torch.load(filename)
weights = torch.load(filename, weights_only=True)
zp = 0
print(filename)
for name in net1.state_dict():
Expand Down Expand Up @@ -493,11 +493,12 @@ class CellposeDenoiseModel():
""" model to run Cellpose and Image restoration """

def __init__(self, gpu=False, pretrained_model=False, model_type=None,
restore_type="denoise_cyto3", chan2_restore=False, device=None):
restore_type="denoise_cyto3", nchan=2,
chan2_restore=False, device=None):

self.dn = DenoiseModel(gpu=gpu, model_type=restore_type, chan2=chan2_restore,
device=device)
self.cp = CellposeModel(gpu=gpu, model_type=model_type,
self.cp = CellposeModel(gpu=gpu, model_type=model_type, nchan=nchan,
pretrained_model=pretrained_model, device=device)

def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
Expand Down
2 changes: 1 addition & 1 deletion cellpose/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def load_bioimageio_cpnet_model(path_model_weight, nchan=2):
"max_pool": True,
}
cpnet_biio = CPnetBioImageIO(**cpnet_kwargs)
state_dict_cuda = torch.load(path_model_weight, map_location=torch.device("cpu"))
state_dict_cuda = torch.load(path_model_weight, map_location=torch.device("cpu"), weights_only=True)
cpnet_biio.load_state_dict(state_dict_cuda)
cpnet_biio.eval() # crucial for the prediction results
return cpnet_biio, cpnet_kwargs
Expand Down
4 changes: 2 additions & 2 deletions cellpose/resnet_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,11 +268,11 @@ def load_model(self, filename, device=None):
device (torch.device, optional): The device to load the model on. Defaults to None.
"""
if (device is not None) and (device.type != "cpu"):
state_dict = torch.load(filename, map_location=device)
state_dict = torch.load(filename, map_location=device, weights_only=True)
else:
self.__init__(self.nbase, self.nout, self.sz, self.mkldnn, self.conv_3D,
self.diam_mean)
state_dict = torch.load(filename, map_location=torch.device("cpu"))
state_dict = torch.load(filename, map_location=torch.device("cpu"), weights_only=True)

self.load_state_dict(state_dict)

Expand Down
4 changes: 2 additions & 2 deletions cellpose/segformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ def load_model(self, filename, device=None):
device (torch.device, optional): The device to load the model on. Defaults to None.
"""
if (device is not None) and (device.type != "cpu"):
state_dict = torch.load(filename, map_location=device)
state_dict = torch.load(filename, map_location=device, weights_only=True)
else:
self.__init__(encoder=self.encoder, decoder=self.decoder,
diam_mean=self.diam_mean)
state_dict = torch.load(filename, map_location=torch.device("cpu"))
state_dict = torch.load(filename, map_location=torch.device("cpu"), weights_only=True)

self.load_state_dict(
dict([(name, param) for name, param in state_dict.items()]),
Expand Down

0 comments on commit 33d1d48

Please sign in to comment.