Adding support for MergeDims and Split Variables to FilePatternToChunks transform. #39

Closed · wants to merge 5 commits · Changes from 2 commits
78 changes: 60 additions & 18 deletions xarray_beam/_src/pangeo_forge.py
@@ -78,32 +78,60 @@ def __init__(
pattern: 'FilePattern',
chunks: Optional[Mapping[str, int]] = None,
local_copy: bool = False,
split_vars: bool = False,
num_threads: Optional[int] = None,
xarray_open_kwargs: Optional[Dict] = None
):
"""Initialize FilePatternToChunks.

TODO(#29): Currently, `MergeDim`s are not supported.

Args:
pattern: a `FilePattern` describing a dataset.
chunks: split each open dataset into smaller chunks. If not set, the
transform will emit each file as a single chunk.
local_copy: Open files from the pattern with local copies instead of a
buffered reader.
split_vars: whether to split the dataset into separate records for each
data variable or to keep all data variables together. If the pattern
has merge dimensions (and this flag is False), data will still be split
according to the pattern.
num_threads: optional number of Dataset chunks to load in parallel per
worker. More threads can increase throughput, but also increases memory
usage and makes it harder for Beam runners to shard work. Note that each
variable in a Dataset is already loaded in parallel, so this is most
useful for Datasets with a small number of variables.
xarray_open_kwargs: keyword arguments to pass to `xarray.open_dataset()`.
"""
self.pattern = pattern
self.chunks = chunks
self.local_copy = local_copy
self.split_vars = split_vars
self.num_threads = num_threads
self.xarray_open_kwargs = xarray_open_kwargs or {}
self._max_size_idx = {}

if pattern.merge_dims:
raise ValueError("patterns with `MergeDim`s are not supported.")
# cache values so they don't have to be re-computed.
self._max_sizes = {}
self._concat_dims = pattern.concat_dims
self._merge_dims = pattern.merge_dims
self._dim_keys_by_name = {
dim.name: dim.keys for dim in pattern.combine_dims
}
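
For context, a minimal usage sketch of the new `split_vars` and `num_threads` options. The pattern setup, file paths, and the pangeo-forge import path below are assumptions for illustration, not part of this PR:

```python
# Hypothetical usage sketch; paths and the pangeo-forge import path are assumed.
import apache_beam as beam
from pangeo_forge_recipes.patterns import ConcatDim, FilePattern  # import path may vary by version

from xarray_beam._src.pangeo_forge import FilePatternToChunks

time_dim = ConcatDim('time', list(range(0, 360 * 4, 360)))

def make_path(time: int) -> str:
  return f'/tmp/era5-{time}.nc'  # hypothetical file layout

pattern = FilePattern(make_path, time_dim)

with beam.Pipeline() as p:
  _ = (
      p
      | FilePatternToChunks(pattern, chunks={'time': 36}, split_vars=True, num_threads=2)
      # Emits (xarray_beam.Key, xarray.Dataset) pairs; with split_vars=True
      # each pair holds a single data variable.
      | beam.Map(print)
  )
```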

def _maybe_split_vars(
self,
key: core.Key,
dataset: xarray.Dataset
) -> Iterator[Tuple[core.Key, xarray.Dataset]]:
"""If 'split_vars' is enabled, produce a chunk for every variable."""
if not self.split_vars:
yield key, dataset
return

for k in dataset:
yield key.replace(vars={k}), dataset[[k]]
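
A quick sketch of what this splitting produces, with made-up data (not part of this diff):

```python
# Sketch of the split_vars=True behaviour; variable names and data are assumed.
import numpy as np
import xarray
import xarray_beam as xbeam

ds = xarray.Dataset({
    'asn': ('time', np.zeros(4)),
    'd2m': ('time', np.ones(4)),
})
key = xbeam.Key({'time': 0})

# One record per data variable, each keyed by the variable it contains.
split = [(key.replace(vars={name}), ds[[name]]) for name in ds]
# -> [(Key(offsets={'time': 0}, vars={'asn'}), <Dataset with only 'asn'>),
#     (Key(offsets={'time': 0}, vars={'d2m'}), <Dataset with only 'd2m'>)]
```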

@contextlib.contextmanager
def _open_dataset(self, path: str) -> xarray.Dataset:
"""Open as an XArray Dataset, sometimes with local caching."""
"""Open as an XArray Dataset, optionally with local caching."""
if self.local_copy:
with tempfile.TemporaryDirectory() as tmpdir:
local_file = fsspec.open_local(
@@ -113,7 +141,7 @@ def _open_dataset(self, path: str) -> xarray.Dataset:
yield xarray.open_dataset(local_file, **self.xarray_open_kwargs)
else:
with FileSystems().open(path) as file:
yield xarray.open_dataset(file, **self.xarray_open_kwargs)
yield xarray.open_dataset(file, **self.xarray_open_kwargs)

def _open_chunks(
self,
@@ -123,25 +151,39 @@ def _open_chunks(
"""Open datasets into chunks with XArray."""
with self._open_dataset(path) as dataset:

dataset = _expand_dimensions_by_key(dataset, index, self.pattern)
# We only want to expand the concat dimensions of the dataset.
dataset = _expand_dimensions_by_key(
dataset,
tuple((dim for dim in index if dim.name in self._concat_dims)),
self.pattern
)

if not self._max_sizes:
self._max_sizes = dataset.sizes
Comment on lines +160 to +161
Member:
I missed this before, but why is this caching needed/useful? In general, I would guess it's probably a bad idea to make stateful Beam transforms, since that breaks one of the underlying assumptions of Beam's data model.

Contributor Author:
Thanks, that's helpful to know that stateful transforms are discouraged; I'll keep that in mind for the future.

This caching was the simplest way I could think of to calculate the correct offsets for the keys. See this code-block for context.

When I calculated the offsets using the current dataset's sizes, it would always fail to compute the last offsets correctly (please see the code-block linked above). The simplest way I could think of to calculate the right starting offset was to cache the first dataset's size, and let the 0-indexed dim.index handle the rest.

From what I can tell, this data is safe to cache. Those are, however, famous last words in parallel programming.
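
As a concrete sketch of the offset arithmetic discussed here (numbers assumed, roughly matching the test fixtures, where each file covers a fixed number of steps along 'time' and only the last file may be shorter):

```python
# Why the *first* dataset's sizes are cached: offsets are global positions,
# so every file's starting offset is first_file_size * concat_index,
# regardless of how large the current file happens to be.
max_sizes = {'time': 360}            # cached from the first file opened
for concat_index in range(4):        # dim.index of each file along 'time'
  offset = max_sizes['time'] * concat_index
  print(concat_index, offset)        # -> 0 0, 1 360, 2 720, 3 1080
# Using the current file's own size instead would mis-place the final,
# shorter file, which is the failure mode described above.
```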


if not self._max_size_idx:
self._max_size_idx = dataset.sizes
variables = {self._dim_keys_by_name[dim.name][dim.index]
for dim in index if dim.name in self._merge_dims}
if not variables:
variables = None

base_key = core.Key(_zero_dimensions(dataset)).with_offsets(
**{dim.name: self._max_size_idx[dim.name] * dim.index for dim in index}
key = core.Key(_zero_dimensions(dataset), variables).with_offsets(
**{dim.name: self._max_sizes[dim.name] * dim.index
for dim in index if dim.name in self._concat_dims}
)

num_threads = len(dataset.data_vars)
num_threads = self.num_threads or len(dataset.data_vars)

# If chunks is not set by the user, treat the dataset as a single chunk.
# If 'chunks' is not set by the user, treat the dataset as a single chunk.
if self.chunks is None:
yield base_key, dataset.compute(num_workers=num_threads)
yield from self._maybe_split_vars(
key, dataset.compute(num_workers=num_threads)
)
return

for new_key, chunk in rechunk.split_chunks(base_key, dataset,
self.chunks):
yield new_key, chunk.compute(num_workers=num_threads)
for new_key, chunk in rechunk.split_chunks(key, dataset, self.chunks):
yield from self._maybe_split_vars(
new_key, chunk.compute(num_workers=num_threads)
)
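
To make the key construction concrete, a sketch of the key that would be emitted for one file when the pattern has concat dims 'time' and 'longitude' plus a merge dim selecting 'd2m' (all values assumed):

```python
# Assumed values for illustration; mirrors how _open_chunks combines cached
# max sizes (for concat dims) with the merge-dim variable selection.
import xarray_beam as xbeam

max_sizes = {'time': 360, 'longitude': 36}
concat_index = {'time': 1, 'longitude': 0}   # dim.index per concat dimension
variables = {'d2m'}                          # from the MergeDim entry (or split_vars)

key = xbeam.Key(
    offsets={name: max_sizes[name] * idx for name, idx in concat_index.items()},
    vars=variables,
)
# -> Key(offsets={'time': 360, 'longitude': 0}, vars={'d2m'})
```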

def expand(self, pcoll):
return (
...
143 changes: 139 additions & 4 deletions xarray_beam/_src/pangeo_forge_test.py
@@ -24,7 +24,8 @@
FilePattern,
ConcatDim,
DimIndex,
CombineOp
CombineOp,
MergeDim,
)

from xarray_beam import split_chunks
@@ -99,6 +100,32 @@ def make_path(time: int, longitude: int) -> str:
chunk.to_netcdf(make_path(time, long))
yield FilePattern(make_path, time_dim, longitude_dim)

@contextlib.contextmanager
def multivariable_pattern(
self,
time_step: int = 360,
longitude_step: int = 36,
) -> FilePattern:
"""Produces a FilePattern for a test NetCDF data with a merge dimension."""
var_names = ['asn', 'd2m', 'e', 'mn2t']
time_dim = ConcatDim('time', list(range(0, 360 * 4, time_step)))
longitude_dim = ConcatDim('longitude', list(range(0, 144, longitude_step)))
var_dim = MergeDim('variable', list(var_names))

with tempfile.TemporaryDirectory() as tmpdir:
def make_path(time: int, longitude: int, variable: str) -> str:
return f'{tmpdir}/era5-{time}-{longitude}-{variable}.nc'

for time in time_dim.keys:
for long in longitude_dim.keys:
for var in var_names:
chunk = self.test_data.isel(
time=slice(time, time + time_step),
longitude=slice(long, long + longitude_step)
)[[var]]
chunk.to_netcdf(make_path(time, long, var))
yield FilePattern(make_path, time_dim, longitude_dim, var_dim)
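
For reference, a small sketch of how a `MergeDim` shows up in a pattern index and how the transform maps it back to a variable name. The `DimIndex` constructor signature and the import path are assumptions based on the pangeo-forge pattern classes imported at the top of this test module:

```python
# Assumed DimIndex signature: DimIndex(name, index, sequence_len, operation).
from pangeo_forge_recipes.patterns import CombineOp, DimIndex, MergeDim  # path may vary

var_dim = MergeDim('variable', ['asn', 'd2m'])
dim_keys_by_name = {var_dim.name: var_dim.keys}          # as in _dim_keys_by_name
entry = DimIndex('variable', 1, len(var_dim.keys), CombineOp.MERGE)

# The transform recovers the selected variable from the index position.
assert dim_keys_by_name[entry.name][entry.index] == 'd2m'
```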

def test_returns_single_dataset(self):
expected = [
(core.Key({"time": 0, "latitude": 0, "longitude": 0}), self.test_data)
@@ -108,7 +135,20 @@ def test_returns_single_dataset(self):

self.assertAllCloseChunks(actual, expected)

def test_single_subchunks_returns_multiple_datasets(self):
def test_single_dataset_with_split_vars_returns_multiple_datasets(self):
expected = [
(core.Key({"time": 0, "latitude": 0, "longitude": 0}, {'asn'}),
self.test_data[['asn']]),
(core.Key({"time": 0, "latitude": 0, "longitude": 0}, {'d2m'}),
self.test_data[['d2m']]),
]
with self.pattern_from_testdata() as pattern:
actual = test_util.EagerPipeline() | FilePatternToChunks(pattern,
split_vars=True)

self.assertAllCloseChunks(actual, expected)

def test_single_chunks_returns_multiple_datasets(self):
with self.pattern_from_testdata() as pattern:
result = (
test_util.EagerPipeline()
@@ -124,7 +164,26 @@ def test_single_subchunks_returns_multiple_datasets(self):
]
self.assertAllCloseChunks(result, expected)

def test_multiple_subchunks_returns_multiple_datasets(self):
def test_single_chunks_with_splitvars_returns_multiple_datasets(self):
with self.pattern_from_testdata() as pattern:
result = (
test_util.EagerPipeline()
| FilePatternToChunks(pattern,
chunks={"longitude": 48},
split_vars=True)
)

expected = [
(
core.Key({"time": 0, "latitude": 0, "longitude": i}, {var}),
self.test_data.isel(longitude=slice(i, i + 48))[[var]]
)
for i in range(0, 144, 48)
for var in ['asn', 'd2m']
]
self.assertAllCloseChunks(result, expected)

def test_multiple_chunks_returns_multiple_datasets(self):
with self.pattern_from_testdata() as pattern:
result = (
test_util.EagerPipeline()
@@ -177,7 +236,7 @@ def test_multiple_datasets_returns_multiple_datasets(
dict(time_step=365, longitude_step=72,
chunks={"longitude": 36, "latitude": 66}),
)
def test_multiple_datasets_with_subchunks_returns_multiple_datasets(
def test_multiple_datasets_with_chunks_returns_multiple_datasets(
self,
time_step: int,
longitude_step: int,
@@ -203,3 +262,79 @@ def test_multiple_datasets_with_subchunks_returns_multiple_datasets(
)

self.assertAllCloseChunks(actual, expected)

def test_multiple_datasets_with_chunks_and_splitvars_returns_datasets(self):
time_step = 365
longitude_step = 72
chunks = {"latitude": 36}

expected = []
for t, o in itertools.product(range(0, 360 * 4, time_step),
range(0, 144, longitude_step)):

for key, ds in split_chunks(
core.Key({"latitude": 0, "longitude": o, "time": t}),
self.test_data.isel(
time=slice(t, t + time_step),
longitude=slice(o, o + longitude_step)
),
chunks
):
for var in ['asn', 'd2m']:
expected.append((key.replace(vars={var}), ds[[var]]))

with self.multifile_pattern(time_step, longitude_step) as pattern:
actual = test_util.EagerPipeline() | FilePatternToChunks(
pattern,
chunks=chunks,
split_vars=True,
)

self.assertAllCloseChunks(actual, expected)

def test_multivar_datasets_returns_multiple_datasets(self):
self.test_data = test_util.dummy_era5_surface_dataset(4)
time_step = 360
longitude_step = 36
expected = [
(
core.Key({"time": t, "latitude": 0, "longitude": o}, {var}),
self.test_data.isel(
time=slice(t, t + time_step),
longitude=slice(o, o + longitude_step)
)[[var]]
) for t, o in itertools.product(
range(0, 360 * 4, time_step),
range(0, 144, longitude_step)
) for var in ['asn', 'd2m', 'e', 'mn2t']
]
with self.multivariable_pattern(time_step, longitude_step) as pattern:
actual = test_util.EagerPipeline() | FilePatternToChunks(pattern)

self.assertAllCloseChunks(actual, expected)

def test_multivar_datasets_with_chunks_returns_multiple_datasets(self):
self.test_data = test_util.dummy_era5_surface_dataset(4)
time_step = 360
longitude_step = 36
chunks = {"latitude": 36, "time": 36}

expected = []
for t, o in itertools.product(range(0, 360 * 4, time_step),
range(0, 144, longitude_step)):
for var in ['asn', 'd2m', 'e', 'mn2t']:
for key, ds in split_chunks(
core.Key({"latitude": 0, "longitude": o, "time": t}),
self.test_data.isel(
time=slice(t, t + time_step),
longitude=slice(o, o + longitude_step)
),
chunks
):
expected.append((key.replace(vars={var}), ds[[var]]))

with self.multivariable_pattern(time_step, longitude_step) as pattern:
actual = test_util.EagerPipeline() | FilePatternToChunks(pattern,
chunks=chunks)

self.assertAllCloseChunks(actual, expected)