From 5409ef4ba0e9c1734bf8bdfcf874425bc7e2e970 Mon Sep 17 00:00:00 2001
From: Anne Haley
Date: Tue, 18 Jun 2024 12:52:49 +0000
Subject: [PATCH 1/5] Add multiprocessing compatibility to zarr sink and add test for behavior

---
 .../zarr/large_image_source_zarr/__init__.py | 18 +++---
 test/test_sink.py                            | 53 ++++++++++++++++++
 2 files changed, 62 insertions(+), 9 deletions(-)

diff --git a/sources/zarr/large_image_source_zarr/__init__.py b/sources/zarr/large_image_source_zarr/__init__.py
index b5e12aca3..18f669177 100644
--- a/sources/zarr/large_image_source_zarr/__init__.py
+++ b/sources/zarr/large_image_source_zarr/__init__.py
@@ -3,6 +3,7 @@
 import shutil
 import tempfile
 import threading
+import multiprocessing
 import uuid
 import warnings
 from importlib.metadata import PackageNotFoundError
@@ -110,11 +111,9 @@ def _initNew(self, path, **kwargs):
         """
         Initialize the tile class for creating a new image.
         """
-        self._tempdir = tempfile.TemporaryDirectory(path)
-        self._zarr_store = zarr.DirectoryStore(self._tempdir.name)
-        self._zarr = zarr.open(self._zarr_store, mode='w')
-        # Make unpickleable
-        self._unpickleable = True
+        self._tempdir = Path(tempfile.gettempdir(), path)
+        self._zarr_store = zarr.DirectoryStore(str(self._tempdir))
+        self._zarr = zarr.open(self._zarr_store, mode='a')
         self._largeImagePath = None
         self._dims = {}
         self.sizeX = self.sizeY = self.levels = 0
@@ -123,7 +122,8 @@ def _initNew(self, path, **kwargs):
         self._output = None
         self._editable = True
         self._bandRanges = None
-        self._addLock = threading.RLock()
+        self._threadLock = threading.RLock()
+        self._processLock = multiprocessing.Lock()
         self._framecount = 0
         self._mm_x = 0
         self._mm_y = 0
@@ -579,7 +579,7 @@ def addTile(self, tile, x=0, y=0, mask=None, axes=None, **kwargs):
         }
         tile, mask, placement, axes = self._validateNewTile(tile, mask, placement, axes)
 
-        with self._addLock:
+        with self._threadLock and self._processLock:
             self._axes = {k: i for i, k in enumerate(axes)}
             new_dims = {
                 a: max(
@@ -656,7 +656,7 @@ def addAssociatedImage(self, image, imageKey=None):
             with an image.  The numpy array can have 2 or 3 dimensions.
         """
         data, _ = _imageToNumpy(image)
-        with self._addLock:
+        with self._threadLock and self._processLock:
             if imageKey is None:
                 # Each associated image should be in its own group
                 num_existing = len(self.getAssociatedImagesList())
@@ -671,7 +671,7 @@ def addAssociatedImage(self, image, imageKey=None):
 
     def _writeInternalMetadata(self):
         self._checkEditable()
-        with self._addLock:
+        with self._threadLock and self._processLock:
             name = str(self._tempdir.name).split('/')[-1]
             arrays = dict(self._zarr.arrays())
             channel_axis = self._axes.get('s') or self._axes.get('c')
diff --git a/test/test_sink.py b/test/test_sink.py
index a947f2753..04ca6ee97 100644
--- a/test/test_sink.py
+++ b/test/test_sink.py
@@ -5,6 +5,7 @@
 import numpy as np
 import pytest
 from PIL import Image
+from multiprocessing.pool import Pool, ThreadPool
 
 import large_image
 from large_image.constants import NEW_IMAGE_PATH_FLAG
@@ -468,3 +469,55 @@ def testAddAssociatedImages(tmp_path):
     assert isinstance(retrieved, Image.Image)
     # PIL Image size doesn't include bands and swaps x & y
     assert retrieved.size == (expected_size[1], expected_size[0])
+
+
+def _add_tile_from_seed_data(sink, seed_data, position):
+    tile = seed_data[
+        position['z'],
+        position['y'],
+        position['x'],
+        position['s'],
+    ]
+    sink.addTile(
+        tile,
+        position['x'].start,
+        position['y'].start,
+        z=position['z'],
+    )
+
+
+def testConcurrency(tmp_path):
+    output_file = tmp_path / 'test.db'
+    max_workers = 5
+    tile_size = (100, 100)
+    target_shape = (4, 1000, 1000, 5)
+    tile_positions = []
+    seed_data = np.random.random(target_shape)
+
+    for z in range(target_shape[0]):
+        for y in range(int(target_shape[1] / tile_size[0])):
+            for x in range(int(target_shape[2] / tile_size[1])):
+                tile_positions.append({
+                    'z': z,
+                    'y': slice(y * tile_size[0], (y + 1) * tile_size[0]),
+                    'x': slice(x * tile_size[1], (x + 1) * tile_size[1]),
+                    's': slice(0, target_shape[3])
+                })
+
+    for pool_class in [Pool, ThreadPool]:
+        sink = large_image_source_zarr.new()
+        # allocate space by adding last tile first
+        _add_tile_from_seed_data(sink, seed_data, tile_positions[-1])
+        with pool_class(max_workers) as pool:
+            pool.starmap(_add_tile_from_seed_data, [
+                (sink, seed_data, position)
+                for position in tile_positions[:-1]
+            ])
+        sink.write(output_file)
+        written = large_image_source_zarr.open(output_file)
+        written_arrays = dict(written._zarr.arrays())
+        data = np.array(written_arrays.get('0'))
+        assert len(written_arrays) == written.levels
+        assert data is not None
+        assert data.shape == seed_data.shape
+        assert np.allclose(data, seed_data)

From 3e33bfefad3a5b79cf49cbc1f48f3434d5fc9cd1 Mon Sep 17 00:00:00 2001
From: Anne Haley
Date: Tue, 18 Jun 2024 13:03:08 +0000
Subject: [PATCH 2/5] Reorder imports and add trailing comma

---
 sources/zarr/large_image_source_zarr/__init__.py | 2 +-
 test/test_sink.py                                | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/sources/zarr/large_image_source_zarr/__init__.py b/sources/zarr/large_image_source_zarr/__init__.py
index 18f669177..dc6f61382 100644
--- a/sources/zarr/large_image_source_zarr/__init__.py
+++ b/sources/zarr/large_image_source_zarr/__init__.py
@@ -1,9 +1,9 @@
 import math
+import multiprocessing
 import os
 import shutil
 import tempfile
 import threading
-import multiprocessing
 import uuid
 import warnings
 from importlib.metadata import PackageNotFoundError
diff --git a/test/test_sink.py b/test/test_sink.py
index 04ca6ee97..7c6af5535 100644
--- a/test/test_sink.py
+++ b/test/test_sink.py
@@ -1,11 +1,11 @@
 import math
+from multiprocessing.pool import Pool, ThreadPool
 
 import large_image_source_test
 import large_image_source_zarr
 import numpy as np
 import pytest
 from PIL import Image
-from multiprocessing.pool import Pool, ThreadPool
 
 import large_image
 from large_image.constants import NEW_IMAGE_PATH_FLAG
@@ -501,7 +501,7 @@ def testConcurrency(tmp_path):
                     'z': z,
                     'y': slice(y * tile_size[0], (y + 1) * tile_size[0]),
                     'x': slice(x * tile_size[1], (x + 1) * tile_size[1]),
-                    's': slice(0, target_shape[3])
+                    's': slice(0, target_shape[3]),
                 })
 
     for pool_class in [Pool, ThreadPool]:

From 06b18c5ed3bbcd19db6245921397ad996d73f4b9 Mon Sep 17 00:00:00 2001
From: Anne Haley
Date: Mon, 24 Jun 2024 13:25:14 +0000
Subject: [PATCH 3/5] Fix test behavior

---
 .../zarr/large_image_source_zarr/__init__.py | 18 +++++++++++-------
 test/test_pickle.py                          |  8 ++++++--
 2 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/sources/zarr/large_image_source_zarr/__init__.py b/sources/zarr/large_image_source_zarr/__init__.py
index dc6f61382..f3788f9a1 100644
--- a/sources/zarr/large_image_source_zarr/__init__.py
+++ b/sources/zarr/large_image_source_zarr/__init__.py
@@ -112,6 +112,9 @@ def _initNew(self, path, **kwargs):
         Initialize the tile class for creating a new image.
         """
         self._tempdir = Path(tempfile.gettempdir(), path)
+        self._created = False
+        if not self._tempdir.exists():
+            self._created = True
         self._zarr_store = zarr.DirectoryStore(str(self._tempdir))
         self._zarr = zarr.open(self._zarr_store, mode='a')
         self._largeImagePath = None
@@ -140,7 +143,8 @@ def __del__(self):
             except Exception:
                 pass
             try:
-                self._tempdir.cleanup()
+                if self._created:
+                    shutil.rmtree(self._tempdir)
             except Exception:
                 pass
 
@@ -317,7 +321,7 @@ def _validateZarr(self):
         Validate that we can read tiles from the zarr parent group in
         self._zarr.  Set up the appropriate class variables.
         """
-        if self._editable:
+        if self._editable and hasattr(self, '_axes'):
             self._writeInternalMetadata()
         found = self._scanZarrGroup(self._zarr)
         if found['best'] is None:
@@ -683,8 +687,8 @@ def _writeInternalMetadata(self):
             for arr_name in arrays:
                 level = int(arr_name)
                 scale = [1.0 for a in sorted_axes]
-                scale[self._axes.get('x')] = self._mm_x * (2 ** level)
-                scale[self._axes.get('y')] = self._mm_y * (2 ** level)
+                scale[self._axes.get('x')] = (self._mm_x or 0) * (2 ** level)
+                scale[self._axes.get('y')] = (self._mm_y or 0) * (2 ** level)
                 dataset_metadata = {
                     'path': arr_name,
                     'coordinateTransformations': [{
@@ -951,7 +955,7 @@ def write(
                         **frame_position,
                     )
 
-        source._writeInternalMetadata()
+        source._validateZarr()
 
         if suffix in ['.zarr', '.db', '.sqlite', '.zip']:
             if resample is None:
@@ -964,7 +968,7 @@ def write(
             source._writeInternalMetadata()  # rewrite with new level datasets
 
             if suffix == '.zarr':
-                shutil.copytree(source._tempdir.name, path)
+                shutil.copytree(str(source._tempdir), path)
             elif suffix in ['.db', '.sqlite']:
                 sqlite_store = zarr.SQLiteStore(path)
                 zarr.copy_store(source._zarr_store, sqlite_store, if_exists='replace')
@@ -977,7 +981,7 @@ def write(
         else:
             from large_image_converter import convert
 
-            attrs_path = Path(source._tempdir.name) / '.zattrs'
+            attrs_path = source._tempdir / '.zattrs'
             params = {}
             if lossy and self.dtype == np.uint8:
                 params['compression'] = 'jpeg'
diff --git a/test/test_pickle.py b/test/test_pickle.py
index 670ea761e..c2f8ce7c2 100644
--- a/test/test_pickle.py
+++ b/test/test_pickle.py
@@ -3,6 +3,7 @@
 import pytest
 
 import large_image
+import large_image_source_vips
 
 from .datastore import datastore
 
@@ -61,6 +62,9 @@ def testPickleTile():
 
 
 def testPickleNew():
-    ts = large_image.new()
+    ts_zarr = large_image.new()
+    pickle.dumps(ts_zarr)
+
+    ts_vips = large_image_source_vips.new()
     with pytest.raises(pickle.PicklingError):
-        pickle.dumps(ts)
+        pickle.dumps(ts_vips)

From 41f7b2eb26410f01a24b3f4179f4e356c0e21ccf Mon Sep 17 00:00:00 2001
From: Anne Haley
Date: Mon, 24 Jun 2024 13:36:26 +0000
Subject: [PATCH 4/5] Reorder imports in test_pickle.py

---
 test/test_pickle.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_pickle.py b/test/test_pickle.py
index c2f8ce7c2..aede88914 100644
--- a/test/test_pickle.py
+++ b/test/test_pickle.py
@@ -1,9 +1,9 @@
 import pickle
 
+import large_image_source_vips
 import pytest
 
 import large_image
-import large_image_source_vips
 
 from .datastore import datastore
 

From c2e1483059568531f49af7ab8c955d6a90d980f1 Mon Sep 17 00:00:00 2001
From: Anne Haley
Date: Tue, 9 Jul 2024 17:23:48 +0000
Subject: [PATCH 5/5] fix: Add singular decorator to concurrency test

---
 test/test_sink.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/test_sink.py b/test/test_sink.py
index 7c6af5535..f6c948d35 100644
--- a/test/test_sink.py
+++ b/test/test_sink.py
@@ -486,6 +486,7 @@ def _add_tile_from_seed_data(sink, seed_data, position):
     )
 
 
+@pytest.mark.singular()
 def testConcurrency(tmp_path):
     output_file = tmp_path / 'test.db'
     max_workers = 5