diff --git a/sources/zarr/large_image_source_zarr/__init__.py b/sources/zarr/large_image_source_zarr/__init__.py
index aedce1808..c48490b7b 100644
--- a/sources/zarr/large_image_source_zarr/__init__.py
+++ b/sources/zarr/large_image_source_zarr/__init__.py
@@ -1,4 +1,5 @@
 import math
+import multiprocessing
 import os
 import shutil
 import tempfile
@@ -110,11 +111,12 @@ def _initNew(self, path, **kwargs):
         """
         Initialize the tile class for creating a new image.
         """
-        self._tempdir = tempfile.TemporaryDirectory(path)
-        self._zarr_store = zarr.DirectoryStore(self._tempdir.name)
-        self._zarr = zarr.open(self._zarr_store, mode='w')
-        # Make unpickleable
-        self._unpickleable = True
+        self._tempdir = Path(tempfile.gettempdir(), path)
+        self._created = False
+        if not self._tempdir.exists():
+            self._created = True
+        self._zarr_store = zarr.DirectoryStore(str(self._tempdir))
+        self._zarr = zarr.open(self._zarr_store, mode='a')
         self._largeImagePath = None
         self._dims = {}
         self.sizeX = self.sizeY = self.levels = 0
@@ -123,7 +125,8 @@ def _initNew(self, path, **kwargs):
         self._output = None
         self._editable = True
         self._bandRanges = None
-        self._addLock = threading.RLock()
+        self._threadLock = threading.RLock()
+        self._processLock = multiprocessing.Lock()
         self._framecount = 0
         self._mm_x = 0
         self._mm_y = 0
@@ -140,7 +143,8 @@ def __del__(self):
         except Exception:
             pass
         try:
-            self._tempdir.cleanup()
+            if self._created:
+                shutil.rmtree(self._tempdir)
         except Exception:
             pass
 
@@ -317,7 +321,7 @@ def _validateZarr(self):
         Validate that we can read tiles from the zarr parent group in
         self._zarr.  Set up the appropriate class variables.
         """
-        if self._editable:
+        if self._editable and hasattr(self, '_axes'):
             self._writeInternalMetadata()
         found = self._scanZarrGroup(self._zarr)
         if found['best'] is None:
@@ -579,7 +583,7 @@ def addTile(self, tile, x=0, y=0, mask=None, axes=None, **kwargs):
         }
         tile, mask, placement, axes = self._validateNewTile(tile, mask, placement, axes)
 
-        with self._addLock:
+        with self._threadLock, self._processLock:
             self._axes = {k: i for i, k in enumerate(axes)}
             new_dims = {
                 a: max(
@@ -656,7 +660,7 @@ def addAssociatedImage(self, image, imageKey=None):
         with an image.  The numpy array can have 2 or 3 dimensions.
""" data, _ = _imageToNumpy(image) - with self._addLock: + with self._threadLock and self._processLock: if imageKey is None: # Each associated image should be in its own group num_existing = len(self.getAssociatedImagesList()) @@ -671,7 +675,7 @@ def addAssociatedImage(self, image, imageKey=None): def _writeInternalMetadata(self): self._checkEditable() - with self._addLock: + with self._threadLock and self._processLock: name = str(self._tempdir.name).split('/')[-1] arrays = dict(self._zarr.arrays()) channel_axis = self._axes.get('c', self._axes.get('s')) @@ -683,8 +687,8 @@ def _writeInternalMetadata(self): for arr_name in arrays: level = int(arr_name) scale = [1.0 for a in sorted_axes] - scale[self._axes.get('x')] = self._mm_x * (2 ** level) - scale[self._axes.get('y')] = self._mm_y * (2 ** level) + scale[self._axes.get('x')] = (self._mm_x or 0) * (2 ** level) + scale[self._axes.get('y')] = (self._mm_y or 0) * (2 ** level) dataset_metadata = { 'path': arr_name, 'coordinateTransformations': [{ @@ -957,7 +961,7 @@ def write( **frame_position, ) - source._writeInternalMetadata() + source._validateZarr() if suffix in ['.zarr', '.db', '.sqlite', '.zip']: if resample is None: @@ -970,7 +974,7 @@ def write( source._writeInternalMetadata() # rewrite with new level datasets if suffix == '.zarr': - shutil.copytree(source._tempdir.name, path) + shutil.copytree(str(source._tempdir), path) elif suffix in ['.db', '.sqlite']: sqlite_store = zarr.SQLiteStore(path) zarr.copy_store(source._zarr_store, sqlite_store, if_exists='replace') @@ -983,7 +987,7 @@ def write( else: from large_image_converter import convert - attrs_path = Path(source._tempdir.name) / '.zattrs' + attrs_path = source._tempdir / '.zattrs' params = {} if lossy and self.dtype == np.uint8: params['compression'] = 'jpeg' diff --git a/test/test_pickle.py b/test/test_pickle.py index 670ea761e..aede88914 100644 --- a/test/test_pickle.py +++ b/test/test_pickle.py @@ -1,5 +1,6 @@ import pickle +import large_image_source_vips import pytest import large_image @@ -61,6 +62,9 @@ def testPickleTile(): def testPickleNew(): - ts = large_image.new() + ts_zarr = large_image.new() + pickle.dumps(ts_zarr) + + ts_vips = large_image_source_vips.new() with pytest.raises(pickle.PicklingError): - pickle.dumps(ts) + pickle.dumps(ts_vips) diff --git a/test/test_sink.py b/test/test_sink.py index 1679321c5..1b59f8f31 100644 --- a/test/test_sink.py +++ b/test/test_sink.py @@ -1,4 +1,5 @@ import math +from multiprocessing.pool import Pool, ThreadPool import large_image_source_test import large_image_source_zarr @@ -481,3 +482,56 @@ def testAddAssociatedImages(tmp_path): assert isinstance(retrieved, Image.Image) # PIL Image size doesn't include bands and swaps x & y assert retrieved.size == (expected_size[1], expected_size[0]) + + +def _add_tile_from_seed_data(sink, seed_data, position): + tile = seed_data[ + position['z'], + position['y'], + position['x'], + position['s'], + ] + sink.addTile( + tile, + position['x'].start, + position['y'].start, + z=position['z'], + ) + + +@pytest.mark.singular() +def testConcurrency(tmp_path): + output_file = tmp_path / 'test.db' + max_workers = 5 + tile_size = (100, 100) + target_shape = (4, 1000, 1000, 5) + tile_positions = [] + seed_data = np.random.random(target_shape) + + for z in range(target_shape[0]): + for y in range(int(target_shape[1] / tile_size[0])): + for x in range(int(target_shape[2] / tile_size[1])): + tile_positions.append({ + 'z': z, + 'y': slice(y * tile_size[0], (y + 1) * tile_size[0]), + 'x': slice(x * 
+                    's': slice(0, target_shape[3]),
+                })
+
+    for pool_class in [Pool, ThreadPool]:
+        sink = large_image_source_zarr.new()
+        # allocate space by adding last tile first
+        _add_tile_from_seed_data(sink, seed_data, tile_positions[-1])
+        with pool_class(max_workers) as pool:
+            pool.starmap(_add_tile_from_seed_data, [
+                (sink, seed_data, position)
+                for position in tile_positions[:-1]
+            ])
+        sink.write(output_file)
+        written = large_image_source_zarr.open(output_file)
+        written_arrays = dict(written._zarr.arrays())
+        data = np.array(written_arrays.get('0'))
+        assert len(written_arrays) == written.levels
+        assert data is not None
+        assert data.shape == seed_data.shape
+        assert np.allclose(data, seed_data)
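
Note on the locking pattern: every mutation above is guarded by both a thread lock and a process lock. A stripped-down sketch of that pattern follows; the names here are illustrative, not from the patch:

    import multiprocessing
    import threading

    thread_lock = threading.RLock()
    process_lock = multiprocessing.Lock()

    def guarded_write(write_fn):
        # `with a, b:` acquires both locks in order and releases them in
        # reverse order: the RLock serializes threads within one process,
        # while the multiprocessing.Lock serializes writers across
        # processes sharing the same store.
        with thread_lock, process_lock:
            write_fn()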
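
For reference, the now-pickleable sink can be driven from a process pool roughly as below. This is a minimal usage sketch modeled on testConcurrency, not part of the patch; the tile shape, worker count, and output name are illustrative assumptions:

    from multiprocessing.pool import Pool

    import large_image_source_zarr
    import numpy as np

    def write_tile(sink, tile, x, y):
        # each worker receives a pickled copy of the sink that reopens the
        # same on-disk store in append mode
        sink.addTile(tile, x, y)

    if __name__ == '__main__':
        sink = large_image_source_zarr.new()
        tiles = [(np.random.random((100, 100, 3)), x, y)
                 for y in range(0, 500, 100) for x in range(0, 500, 100)]
        # add the last (bottom-right) tile first so the full-size array is
        # allocated before the workers write concurrently
        write_tile(sink, *tiles[-1])
        with Pool(4) as pool:
            pool.starmap(write_tile, [(sink,) + t for t in tiles[:-1]])
        sink.write('example.db')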