Zarr Sink Multiprocessing #1551

Merged: 8 commits, Jul 16, 2024
sources/zarr/large_image_source_zarr/__init__.py (36 changes: 20 additions & 16 deletions)
@@ -1,4 +1,5 @@
 import math
+import multiprocessing
 import os
 import shutil
 import tempfile
@@ -110,11 +111,12 @@ def _initNew(self, path, **kwargs):
         """
         Initialize the tile class for creating a new image.
         """
-        self._tempdir = tempfile.TemporaryDirectory(path)
-        self._zarr_store = zarr.DirectoryStore(self._tempdir.name)
-        self._zarr = zarr.open(self._zarr_store, mode='w')
-        # Make unpickleable
-        self._unpickleable = True
+        self._tempdir = Path(tempfile.gettempdir(), path)
+        self._created = False
+        if not self._tempdir.exists():
+            self._created = True
+        self._zarr_store = zarr.DirectoryStore(str(self._tempdir))
+        self._zarr = zarr.open(self._zarr_store, mode='a')
         self._largeImagePath = None
         self._dims = {}
         self.sizeX = self.sizeY = self.levels = 0
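The change above swaps a per-instance `TemporaryDirectory` for a deterministic path under `tempfile.gettempdir()`, and opens the store with `mode='a'` (read/write, create only if missing) instead of `'w'` (truncate). That is what lets another process build its own handle to the same on-disk store, while `_created` records which instance owns cleanup. A minimal sketch of the attach behavior, assuming zarr's v2 API; the path and array name are illustrative:

    import zarr

    store_a = zarr.DirectoryStore('/tmp/shared.zarr')
    root_a = zarr.open(store_a, mode='a')   # creates the store on first use
    root_a.zeros('0', shape=(100, 100), chunks=(10, 10))

    # A second, independently constructed handle (e.g. in a worker process)
    # sees the same data, because all state lives on disk.
    store_b = zarr.DirectoryStore('/tmp/shared.zarr')
    root_b = zarr.open(store_b, mode='a')
    assert root_b['0'].shape == (100, 100)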
@@ -123,7 +125,8 @@ def _initNew(self, path, **kwargs):
         self._output = None
         self._editable = True
         self._bandRanges = None
-        self._addLock = threading.RLock()
+        self._threadLock = threading.RLock()
+        self._processLock = multiprocessing.Lock()
         self._framecount = 0
         self._mm_x = 0
         self._mm_y = 0
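Two locks now replace the old `_addLock`: the `threading.RLock` serializes threads within a process, and the `multiprocessing.Lock` serializes cooperating processes. One Python subtlety worth flagging for anyone reusing the `with self._threadLock and self._processLock:` pattern that appears below: `and` is evaluated before `with`, and since an `RLock` object is always truthy the expression yields only the process lock, so only that lock is actually entered. Holding both requires listing them; a standalone sketch with illustrative names:

    import multiprocessing
    import threading

    thread_lock = threading.RLock()
    process_lock = multiprocessing.Lock()

    # `with thread_lock and process_lock:` would enter only process_lock.
    with thread_lock, process_lock:   # acquires both, releases in reverse order
        pass                          # critical section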
Expand All @@ -140,7 +143,8 @@ def __del__(self):
except Exception:
pass
try:
self._tempdir.cleanup()
if self._created:
shutil.rmtree(self._tempdir)
except Exception:
pass

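Cleanup follows the same ownership rule: only the instance that created the directory removes it, so a worker's destructor cannot delete a store the parent process is still writing. The pattern in isolation, with a hypothetical class name:

    import shutil
    import tempfile
    from pathlib import Path

    class SharedDir:
        def __init__(self, name):
            self.path = Path(tempfile.gettempdir(), name)
            self.created = not self.path.exists()
            self.path.mkdir(parents=True, exist_ok=True)

        def __del__(self):
            try:
                if self.created:
                    shutil.rmtree(self.path)   # only the creator cleans up
            except Exception:
                pass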
@@ -317,7 +321,7 @@ def _validateZarr(self):
         Validate that we can read tiles from the zarr parent group in
         self._zarr. Set up the appropriate class variables.
         """
-        if self._editable:
+        if self._editable and hasattr(self, '_axes'):
             self._writeInternalMetadata()
         found = self._scanZarrGroup(self._zarr)
         if found['best'] is None:
@@ -579,7 +583,7 @@ def addTile(self, tile, x=0, y=0, mask=None, axes=None, **kwargs):
         }
         tile, mask, placement, axes = self._validateNewTile(tile, mask, placement, axes)

-        with self._addLock:
+        with self._threadLock and self._processLock:
             self._axes = {k: i for i, k in enumerate(axes)}
             new_dims = {
                 a: max(
@@ -656,7 +660,7 @@ def addAssociatedImage(self, image, imageKey=None):
         with an image. The numpy array can have 2 or 3 dimensions.
         """
         data, _ = _imageToNumpy(image)
-        with self._addLock:
+        with self._threadLock and self._processLock:
             if imageKey is None:
                 # Each associated image should be in its own group
                 num_existing = len(self.getAssociatedImagesList())
@@ -671,7 +675,7 @@

     def _writeInternalMetadata(self):
         self._checkEditable()
-        with self._addLock:
+        with self._threadLock and self._processLock:
             name = str(self._tempdir.name).split('/')[-1]
             arrays = dict(self._zarr.arrays())
             channel_axis = self._axes.get('c', self._axes.get('s'))
@@ -683,8 +687,8 @@ def _writeInternalMetadata(self):
             for arr_name in arrays:
                 level = int(arr_name)
                 scale = [1.0 for a in sorted_axes]
-                scale[self._axes.get('x')] = self._mm_x * (2 ** level)
-                scale[self._axes.get('y')] = self._mm_y * (2 ** level)
+                scale[self._axes.get('x')] = (self._mm_x or 0) * (2 ** level)
+                scale[self._axes.get('y')] = (self._mm_y or 0) * (2 ** level)
                 dataset_metadata = {
                     'path': arr_name,
                     'coordinateTransformations': [{
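The `or 0` guard makes metadata writing robust when a pixel size is unset: `_mm_x` and `_mm_y` default to 0 but can presumably be `None` on some instances, and `None * (2 ** level)` raises `TypeError`, whereas:

    (None or 0) * (2 ** 3)   # -> 0, an unset scale degrades to zero
    (0.5 or 0) * (2 ** 3)    # -> 4.0, a real value passes through unchanged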
@@ -957,7 +961,7 @@ def write(
                     **frame_position,
                 )

-        source._writeInternalMetadata()
+        source._validateZarr()

         if suffix in ['.zarr', '.db', '.sqlite', '.zip']:
             if resample is None:
@@ -970,7 +974,7 @@ def write(
             source._writeInternalMetadata()  # rewrite with new level datasets

         if suffix == '.zarr':
-            shutil.copytree(source._tempdir.name, path)
+            shutil.copytree(str(source._tempdir), path)
         elif suffix in ['.db', '.sqlite']:
             sqlite_store = zarr.SQLiteStore(path)
             zarr.copy_store(source._zarr_store, sqlite_store, if_exists='replace')
@@ -983,7 +987,7 @@ def write(
         else:
             from large_image_converter import convert

-            attrs_path = Path(source._tempdir.name) / '.zattrs'
+            attrs_path = source._tempdir / '.zattrs'
             params = {}
             if lossy and self.dtype == np.uint8:
                 params['compression'] = 'jpeg'
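For context on the `.db`/`.sqlite` and `.zip` branches: zarr ships `SQLiteStore` and `ZipStore` alongside `DirectoryStore`, and `zarr.copy_store` copies raw keys between any two stores. A minimal sketch with illustrative paths:

    import zarr

    src = zarr.DirectoryStore('/tmp/shared.zarr')
    zarr.open(src, mode='a').zeros('0', shape=(10, 10), chunks=(5, 5))

    dst = zarr.SQLiteStore('/tmp/shared.db')
    zarr.copy_store(src, dst, if_exists='replace')   # key-for-key copy
    dst.close()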
test/test_pickle.py (8 changes: 6 additions & 2 deletions)
@@ -1,5 +1,6 @@
 import pickle

+import large_image_source_vips
 import pytest

 import large_image
@@ -61,6 +62,9 @@ def testPickleTile():


 def testPickleNew():
-    ts = large_image.new()
+    ts_zarr = large_image.new()
+    pickle.dumps(ts_zarr)
+
+    ts_vips = large_image_source_vips.new()
     with pytest.raises(pickle.PicklingError):
-        pickle.dumps(ts)
+        pickle.dumps(ts_vips)
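With the unpickleable handles gone, a zarr sink now survives a pickle round trip, which is exactly what `multiprocessing` needs in order to ship the sink to worker processes. A sketch of the expectation the test encodes, assuming (per this PR's design) that an unpickled copy reattaches to the same on-disk store:

    import pickle

    import large_image_source_zarr

    sink = large_image_source_zarr.new()
    clone = pickle.loads(pickle.dumps(sink))   # no PicklingError anymore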
test/test_sink.py (54 changes: 54 additions & 0 deletions)
@@ -1,4 +1,5 @@
 import math
+from multiprocessing.pool import Pool, ThreadPool

 import large_image_source_test
 import large_image_source_zarr
@@ -481,3 +482,56 @@ def testAddAssociatedImages(tmp_path):
     assert isinstance(retrieved, Image.Image)
     # PIL Image size doesn't include bands and swaps x & y
     assert retrieved.size == (expected_size[1], expected_size[0])
+
+
+def _add_tile_from_seed_data(sink, seed_data, position):
+    tile = seed_data[
+        position['z'],
+        position['y'],
+        position['x'],
+        position['s'],
+    ]
+    sink.addTile(
+        tile,
+        position['x'].start,
+        position['y'].start,
+        z=position['z'],
+    )
+
+
+@pytest.mark.singular()
+def testConcurrency(tmp_path):
+    output_file = tmp_path / 'test.db'
+    max_workers = 5
+    tile_size = (100, 100)
+    target_shape = (4, 1000, 1000, 5)
+    tile_positions = []
+    seed_data = np.random.random(target_shape)
+
+    for z in range(target_shape[0]):
+        for y in range(int(target_shape[1] / tile_size[0])):
+            for x in range(int(target_shape[2] / tile_size[1])):
+                tile_positions.append({
+                    'z': z,
+                    'y': slice(y * tile_size[0], (y + 1) * tile_size[0]),
+                    'x': slice(x * tile_size[1], (x + 1) * tile_size[1]),
+                    's': slice(0, target_shape[3]),
+                })
+
+    for pool_class in [Pool, ThreadPool]:
+        sink = large_image_source_zarr.new()
+        # allocate space by adding last tile first
+        _add_tile_from_seed_data(sink, seed_data, tile_positions[-1])
+        with pool_class(max_workers) as pool:
+            pool.starmap(_add_tile_from_seed_data, [
+                (sink, seed_data, position)
+                for position in tile_positions[:-1]
+            ])
+        sink.write(output_file)
+        written = large_image_source_zarr.open(output_file)
+        written_arrays = dict(written._zarr.arrays())
+        data = np.array(written_arrays.get('0'))
+        assert len(written_arrays) == written.levels
+        assert data is not None
+        assert data.shape == seed_data.shape
+        assert np.allclose(data, seed_data)
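The one ordering constraint in the test, adding the last tile before the pool starts, pre-sizes the array to its full extent so that concurrent workers only write chunk data and never race to resize. The same idea in plain zarr, with illustrative shapes:

    import numpy as np
    import zarr

    arr = zarr.open('/tmp/prealloc.zarr', mode='a', shape=(1000, 1000),
                    chunks=(100, 100), dtype='f8')
    arr[900:, 900:] = np.random.random((100, 100))   # far corner first

    # Workers can now fill the remaining disjoint 100x100 tiles without
    # ever triggering a resize.
    arr[0:100, 0:100] = np.random.random((100, 100))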