Skip to content

Commit

Permalink
Merge branch 'issue-937' of https://github.com/dpeerlab/cellpose into…
Browse files Browse the repository at this point in the history
… dpeerlab-issue-937
  • Loading branch information
carsen-stringer committed Sep 7, 2024
2 parents 33d1d48 + 1feef06 commit 446a498
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 2 deletions.
42 changes: 40 additions & 2 deletions cellpose/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,44 @@ def normalize_img(img, normalize=True, norm3D=False, invert=False, lowhigh=None,

return img_norm

def resize_safe(img, Ly, Lx, interpolation=cv2.INTER_LINEAR):
"""OpenCV resize function does not support uint32.
This function converts the image to float32 before resizing and then converts it back to uint32. Not safe!
References issue: https://github.com/MouseLand/cellpose/issues/937
Implications:
* Runtime: Runtime increases by 5x-50x due to type casting. However, with resizing being very efficient, this is not
a big issue. A 10,000x10,000 image takes 0.47s instead of 0.016s to cast and resize on 32 cores on GPU.
* Memory: However, memory usage increases. Not tested by how much.
Args:
img (ndarray): Image of size [Ly x Lx].
Ly (int): Desired height of the resized image.
Lx (int): Desired width of the resized image.
interpolation (int, optional): OpenCV interpolation method. Defaults to cv2.INTER_LINEAR.
Returns:
ndarray: Resized image of size [Ly x Lx].
"""

# cast image
cast = img.dtype == np.uint32
if cast:
#
img = img.astype(np.float32)

# resize
img = cv2.resize(img, (Lx, Ly), interpolation=interpolation)

# cast back
if cast:
transforms_logger.warning("resizing image from uint32 to float32 and back to uint32")
img = img.round().astype(np.uint32)

return img


def resize_image(img0, Ly=None, Lx=None, rsz=None, interpolation=cv2.INTER_LINEAR,
no_channels=False):
Expand Down Expand Up @@ -735,10 +773,10 @@ def resize_image(img0, Ly=None, Lx=None, rsz=None, interpolation=cv2.INTER_LINEA
else:
imgs = np.zeros((img0.shape[0], Ly, Lx, img0.shape[-1]), np.float32)
for i, img in enumerate(img0):
imgi = cv2.resize(img, (Lx, Ly), interpolation=interpolation)
imgi = resize_safe(img, Ly, Lx, interpolation=interpolation)
imgs[i] = imgi if imgi.ndim > 2 else imgi[..., np.newaxis]
else:
imgs = cv2.resize(img0, (Lx, Ly), interpolation=interpolation)
imgs = resize_safe(img0, Ly, Lx, interpolation=interpolation)
return imgs


Expand Down
19 changes: 19 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,22 @@ def test_normalize_img(data_dir):

img_norm = normalize_img(img, norm3D=False, sharpen_radius=8)
assert img_norm.shape == img.shape

def test_resize(data_dir):
img = io.imread(str(data_dir.joinpath('2D').joinpath('rgb_2D_tif.tif')))

Lx = 100
Ly = 200

img8 = resize_image(img.astype("uint8"), Lx=Lx, Ly=Ly)
assert img8.shape == (Ly, Lx, 3)
assert img8.dtype == np.uint8

img16 = resize_image(img.astype("uint16"), Lx=Lx, Ly=Ly)
assert img16.shape == (Ly, Lx, 3)
assert img16.dtype == np.uint16

img32 = resize_image(img.astype("uint32"), Lx=Lx, Ly=Ly)
assert img32.shape == (Ly, Lx, 3)
assert img32.dtype == np.uint32

0 comments on commit 446a498

Please sign in to comment.