Skip to content

Commit

Permalink
Merge pull request #1551 from girder/zarr-sink-concurrency
Browse files Browse the repository at this point in the history
Zarr Sink Multiprocessing
  • Loading branch information
annehaley authored Jul 16, 2024
2 parents cf3a762 + ca6cb93 commit 9f33a36
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 18 deletions.
36 changes: 20 additions & 16 deletions sources/zarr/large_image_source_zarr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
import multiprocessing
import os
import shutil
import tempfile
Expand Down Expand Up @@ -127,11 +128,12 @@ def _initNew(self, path, **kwargs):
"""
Initialize the tile class for creating a new image.
"""
self._tempdir = tempfile.TemporaryDirectory(path)
self._zarr_store = zarr.DirectoryStore(self._tempdir.name)
self._zarr = zarr.open(self._zarr_store, mode='w')
# Make unpickleable
self._unpickleable = True
self._tempdir = Path(tempfile.gettempdir(), path)
self._created = False
if not self._tempdir.exists():
self._created = True
self._zarr_store = zarr.DirectoryStore(str(self._tempdir))
self._zarr = zarr.open(self._zarr_store, mode='a')
self._largeImagePath = None
self._dims = {}
self.sizeX = self.sizeY = self.levels = 0
Expand All @@ -140,7 +142,8 @@ def _initNew(self, path, **kwargs):
self._output = None
self._editable = True
self._bandRanges = None
self._addLock = threading.RLock()
self._threadLock = threading.RLock()
self._processLock = multiprocessing.Lock()
self._framecount = 0
self._mm_x = 0
self._mm_y = 0
Expand All @@ -157,7 +160,8 @@ def __del__(self):
except Exception:
pass
try:
self._tempdir.cleanup()
if self._created:
shutil.rmtree(self._tempdir)
except Exception:
pass

Expand Down Expand Up @@ -334,7 +338,7 @@ def _validateZarr(self):
Validate that we can read tiles from the zarr parent group in
self._zarr. Set up the appropriate class variables.
"""
if self._editable:
if self._editable and hasattr(self, '_axes'):
self._writeInternalMetadata()
found = self._scanZarrGroup(self._zarr)
if found['best'] is None:
Expand Down Expand Up @@ -596,7 +600,7 @@ def addTile(self, tile, x=0, y=0, mask=None, axes=None, **kwargs):
}
tile, mask, placement, axes = self._validateNewTile(tile, mask, placement, axes)

with self._addLock:
with self._threadLock and self._processLock:
self._axes = {k: i for i, k in enumerate(axes)}
new_dims = {
a: max(
Expand Down Expand Up @@ -673,7 +677,7 @@ def addAssociatedImage(self, image, imageKey=None):
with an image. The numpy array can have 2 or 3 dimensions.
"""
data, _ = _imageToNumpy(image)
with self._addLock:
with self._threadLock and self._processLock:
if imageKey is None:
# Each associated image should be in its own group
num_existing = len(self.getAssociatedImagesList())
Expand All @@ -688,7 +692,7 @@ def addAssociatedImage(self, image, imageKey=None):

def _writeInternalMetadata(self):
self._checkEditable()
with self._addLock:
with self._threadLock and self._processLock:
name = str(self._tempdir.name).split('/')[-1]
arrays = dict(self._zarr.arrays())
channel_axis = self._axes.get('c', self._axes.get('s'))
Expand All @@ -700,8 +704,8 @@ def _writeInternalMetadata(self):
for arr_name in arrays:
level = int(arr_name)
scale = [1.0 for a in sorted_axes]
scale[self._axes.get('x')] = self._mm_x * (2 ** level)
scale[self._axes.get('y')] = self._mm_y * (2 ** level)
scale[self._axes.get('x')] = (self._mm_x or 0) * (2 ** level)
scale[self._axes.get('y')] = (self._mm_y or 0) * (2 ** level)
dataset_metadata = {
'path': arr_name,
'coordinateTransformations': [{
Expand Down Expand Up @@ -974,7 +978,7 @@ def write(
**frame_position,
)

source._writeInternalMetadata()
source._validateZarr()

if suffix in ['.zarr', '.db', '.sqlite', '.zip']:
if resample is None:
Expand All @@ -987,7 +991,7 @@ def write(
source._writeInternalMetadata() # rewrite with new level datasets

if suffix == '.zarr':
shutil.copytree(source._tempdir.name, path)
shutil.copytree(str(source._tempdir), path)
elif suffix in ['.db', '.sqlite']:
sqlite_store = zarr.SQLiteStore(path)
zarr.copy_store(source._zarr_store, sqlite_store, if_exists='replace')
Expand All @@ -1000,7 +1004,7 @@ def write(
else:
from large_image_converter import convert

attrs_path = Path(source._tempdir.name) / '.zattrs'
attrs_path = source._tempdir / '.zattrs'
params = {}
if lossy and self.dtype == np.uint8:
params['compression'] = 'jpeg'
Expand Down
8 changes: 6 additions & 2 deletions test/test_pickle.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pickle

import large_image_source_vips
import pytest

import large_image
Expand Down Expand Up @@ -61,6 +62,9 @@ def testPickleTile():


def testPickleNew():
    """Verify pickling support of freshly created sources.

    A new zarr sink (the default from ``large_image.new()``) must be
    picklable so it can be shipped to worker processes; a new vips source
    remains unpicklable and must raise ``pickle.PicklingError``.

    NOTE(review): the scraped diff had interleaved the deleted lines
    (``ts`` / ``pickle.dumps(ts)``) with the added ones; this is the
    reconstructed post-commit version.
    """
    ts_zarr = large_image.new()
    pickle.dumps(ts_zarr)

    ts_vips = large_image_source_vips.new()
    with pytest.raises(pickle.PicklingError):
        pickle.dumps(ts_vips)
54 changes: 54 additions & 0 deletions test/test_sink.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
from multiprocessing.pool import Pool, ThreadPool

import large_image_source_test
import large_image_source_zarr
Expand Down Expand Up @@ -481,3 +482,56 @@ def testAddAssociatedImages(tmp_path):
assert isinstance(retrieved, Image.Image)
# PIL Image size doesn't include bands and swaps x & y
assert retrieved.size == (expected_size[1], expected_size[0])


def _add_tile_from_seed_data(sink, seed_data, position):
tile = seed_data[
position['z'],
position['y'],
position['x'],
position['s'],
]
sink.addTile(
tile,
position['x'].start,
position['y'].start,
z=position['z'],
)


@pytest.mark.singular()
def testConcurrency(tmp_path):
    """Write one synthetic volume through a zarr sink concurrently.

    Runs the same tiling job once with a process pool and once with a
    thread pool, then re-opens the written file and checks the stored
    level-0 array matches the random seed data exactly.
    """
    output_file = tmp_path / 'test.db'
    max_workers = 5
    tile_size = (100, 100)
    target_shape = (4, 1000, 1000, 5)
    seed_data = np.random.random(target_shape)

    # One entry per (frame, row band, column band) tile of the volume.
    tile_positions = [
        {
            'z': frame,
            'y': slice(row * tile_size[0], (row + 1) * tile_size[0]),
            'x': slice(col * tile_size[1], (col + 1) * tile_size[1]),
            's': slice(0, target_shape[3]),
        }
        for frame in range(target_shape[0])
        for row in range(target_shape[1] // tile_size[0])
        for col in range(target_shape[2] // tile_size[1])
    ]

    for pool_class in (Pool, ThreadPool):
        sink = large_image_source_zarr.new()
        # Allocate the full output array up front by writing the last
        # (bottom-right, last-frame) tile before fanning out workers.
        _add_tile_from_seed_data(sink, seed_data, tile_positions[-1])
        work_items = [
            (sink, seed_data, position)
            for position in tile_positions[:-1]
        ]
        with pool_class(max_workers) as pool:
            pool.starmap(_add_tile_from_seed_data, work_items)
        sink.write(output_file)

        written = large_image_source_zarr.open(output_file)
        written_arrays = dict(written._zarr.arrays())
        data = np.array(written_arrays.get('0'))
        assert len(written_arrays) == written.levels
        assert data is not None
        assert data.shape == seed_data.shape
        assert np.allclose(data, seed_data)

0 comments on commit 9f33a36

Please sign in to comment.