diff --git a/setup.py b/setup.py
index baaf8d7..663f0b3 100644
--- a/setup.py
+++ b/setup.py
@@ -35,6 +35,9 @@
     'absl-py',
     'pandas',
     'pytest',
+    'pangeo-forge-recipes',
+    'scipy',
+    'h5netcdf'
 ]
 
 setuptools.setup(
diff --git a/xarray_beam/_src/pangeo_forge.py b/xarray_beam/_src/pangeo_forge.py
new file mode 100644
index 0000000..d065a4a
--- /dev/null
+++ b/xarray_beam/_src/pangeo_forge.py
@@ -0,0 +1,151 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""IO with Pangeo-Forge."""
+import contextlib
+import tempfile
+from typing import (
+    Dict,
+    Iterator,
+    Optional,
+    Mapping,
+    Tuple,
+)
+
+import apache_beam as beam
+import fsspec
+import xarray
+from apache_beam.io.filesystems import FileSystems
+
+from xarray_beam._src import core, rechunk
+
+
+def _zero_dimensions(dataset: xarray.Dataset) -> Mapping[str, int]:
+  return {dim: 0 for dim in dataset.dims.keys()}
+
+
+def _expand_dimensions_by_key(
+    dataset: xarray.Dataset,
+    index: 'FilePatternIndex',
+    pattern: 'FilePattern'
+) -> xarray.Dataset:
+  """Expand the `Dataset` along any combine dimensions named in the index."""
+  combine_dims_by_name = {
+      combine_dim.name: combine_dim for combine_dim in pattern.combine_dims
+  }
+  index_by_name = {
+      idx.name: idx for idx in index
+  }
+
+  if not combine_dims_by_name:
+    return dataset
+
+  for dim_key in index_by_name.keys():
+    # Skip expanding dimensions that already exist.
+    if dim_key in dataset.dims:
+      continue
+
+    try:
+      combine_dim = combine_dims_by_name[dim_key]
+    except KeyError:
+      raise ValueError(
+          f"could not find CombineDim named {dim_key!r} in pattern {pattern!r}."
+      )
+
+    dim_val = combine_dim.keys[index_by_name[dim_key].index]
+    dataset = dataset.expand_dims(**{dim_key: [dim_val]})
+
+  return dataset
+
+
+class FilePatternToChunks(beam.PTransform):
+  """Open data described by a Pangeo-Forge `FilePattern` into keyed chunks."""
+
+  from pangeo_forge_recipes.patterns import FilePattern, FilePatternIndex
+
+  def __init__(
+      self,
+      pattern: 'FilePattern',
+      chunks: Optional[Mapping[str, int]] = None,
+      local_copy: bool = False,
+      xarray_open_kwargs: Optional[Dict] = None
+  ):
+    """Initialize FilePatternToChunks.
+
+    TODO(#29): Currently, `MergeDim`s are not supported.
+
+    Args:
+      pattern: a `FilePattern` describing a dataset.
+      chunks: split each open dataset into smaller chunks. If not set, each
+        file is emitted as a single chunk.
+      local_copy: Open files from the pattern with local copies instead of a
+        buffered reader.
+      xarray_open_kwargs: keyword arguments to pass to `xarray.open_dataset()`.
+ """ + self.pattern = pattern + self.chunks = chunks + self.local_copy = local_copy + self.xarray_open_kwargs = xarray_open_kwargs or {} + self._max_size_idx = {} + + if pattern.merge_dims: + raise ValueError("patterns with `MergeDim`s are not supported.") + + @contextlib.contextmanager + def _open_dataset(self, path: str) -> xarray.Dataset: + """Open as an XArray Dataset, sometimes with local caching.""" + if self.local_copy: + with tempfile.TemporaryDirectory() as tmpdir: + local_file = fsspec.open_local( + f"simplecache::{path}", + simplecache={'cache_storage': tmpdir} + ) + yield xarray.open_dataset(local_file, **self.xarray_open_kwargs) + else: + with FileSystems().open(path) as file: + yield xarray.open_dataset(file, **self.xarray_open_kwargs) + + def _open_chunks( + self, + index: 'FilePatternIndex', + path: str + ) -> Iterator[Tuple[core.Key, xarray.Dataset]]: + """Open datasets into chunks with XArray.""" + with self._open_dataset(path) as dataset: + + dataset = _expand_dimensions_by_key(dataset, index, self.pattern) + + if not self._max_size_idx: + self._max_size_idx = dataset.sizes + + base_key = core.Key(_zero_dimensions(dataset)).with_offsets( + **{dim.name: self._max_size_idx[dim.name] * dim.index for dim in index} + ) + + num_threads = len(dataset.data_vars) + + # If chunks is not set by the user, treat the dataset as a single chunk. + if self.chunks is None: + yield base_key, dataset.compute(num_workers=num_threads) + return + + for new_key, chunk in rechunk.split_chunks(base_key, dataset, + self.chunks): + yield new_key, chunk.compute(num_workers=num_threads) + + def expand(self, pcoll): + return ( + pcoll + | beam.Create(list(self.pattern.items())) + | beam.FlatMapTuple(self._open_chunks) + ) diff --git a/xarray_beam/_src/pangeo_forge_test.py b/xarray_beam/_src/pangeo_forge_test.py new file mode 100644 index 0000000..eb3b7d4 --- /dev/null +++ b/xarray_beam/_src/pangeo_forge_test.py @@ -0,0 +1,205 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for xarray_beam._src.pangeo.""" + +import contextlib +import itertools +import tempfile +from typing import Dict + +import numpy as np +from absl.testing import parameterized +from pangeo_forge_recipes.patterns import ( + FilePattern, + ConcatDim, + DimIndex, + CombineOp +) + +from xarray_beam import split_chunks +from xarray_beam._src import core +from xarray_beam._src import test_util +from xarray_beam._src.pangeo_forge import ( + FilePatternToChunks, + _expand_dimensions_by_key +) + + +class ExpandDimensionsByKeyTest(test_util.TestCase): + + def setUp(self): + self.test_data = test_util.dummy_era5_surface_dataset() + self.level = ConcatDim("level", list(range(91, 100))) + self.pattern = FilePattern(lambda level: f"gs://dir/{level}.nc", self.level) + + def test_expands_dimensions(self): + for i, (index, _) in enumerate(self.pattern.items()): + actual = _expand_dimensions_by_key( + self.test_data, index, self.pattern + ) + + expected_dims = dict(self.test_data.dims) + expected_dims.update({"level": 1}) + + self.assertEqual(expected_dims, dict(actual.dims)) + self.assertEqual(np.array([self.level.keys[i]]), actual["level"]) + + def test_raises_error_when_dataset_is_not_found(self): + index = (DimIndex('boat', 0, 1, CombineOp.CONCAT),) + with self.assertRaisesRegex(ValueError, "boat"): + _expand_dimensions_by_key( + self.test_data, index, self.pattern + ) + + +class FilePatternToChunksTest(test_util.TestCase): + + def setUp(self): + self.test_data = test_util.dummy_era5_surface_dataset() + + @contextlib.contextmanager + def pattern_from_testdata(self) -> FilePattern: + """Produces a FilePattern for a temporary NetCDF file of test data.""" + with tempfile.TemporaryDirectory() as tmpdir: + target = f'{tmpdir}/era5.nc' + self.test_data.to_netcdf(target) + yield FilePattern(lambda: target) + + @contextlib.contextmanager + def multifile_pattern( + self, + time_step: int = 479, + longitude_step: int = 47 + ) -> FilePattern: + """Produces a FilePattern for a temporary NetCDF file of test data.""" + time_dim = ConcatDim('time', list(range(0, 360 * 4, time_step))) + longitude_dim = ConcatDim('longitude', list(range(0, 144, longitude_step))) + + with tempfile.TemporaryDirectory() as tmpdir: + def make_path(time: int, longitude: int) -> str: + return f'{tmpdir}/era5-{time}-{longitude}.nc' + + for time in time_dim.keys: + for long in longitude_dim.keys: + chunk = self.test_data.isel( + time=slice(time, time + time_step), + longitude=slice(long, long + longitude_step) + ) + chunk.to_netcdf(make_path(time, long)) + yield FilePattern(make_path, time_dim, longitude_dim) + + def test_returns_single_dataset(self): + expected = [ + (core.Key({"time": 0, "latitude": 0, "longitude": 0}), self.test_data) + ] + with self.pattern_from_testdata() as pattern: + actual = test_util.EagerPipeline() | FilePatternToChunks(pattern) + + self.assertAllCloseChunks(actual, expected) + + def test_single_subchunks_returns_multiple_datasets(self): + with self.pattern_from_testdata() as pattern: + result = ( + test_util.EagerPipeline() + | FilePatternToChunks(pattern, chunks={"longitude": 48}) + ) + + expected = [ + ( + core.Key({"time": 0, "latitude": 0, "longitude": i}), + self.test_data.isel(longitude=slice(i, i + 48)) + ) + for i in range(0, 144, 48) + ] + self.assertAllCloseChunks(result, expected) + + def test_multiple_subchunks_returns_multiple_datasets(self): + with self.pattern_from_testdata() as pattern: + result = ( + test_util.EagerPipeline() + | FilePatternToChunks(pattern, + chunks={"longitude": 48, 
"latitude": 24}) + ) + + expected = [ + ( + core.Key({"time": 0, "longitude": o, "latitude": a}), + self.test_data.isel(longitude=slice(o, o + 48), + latitude=slice(a, a + 24)) + ) + for o, a in itertools.product(range(0, 144, 48), range(0, 73, 24)) + ] + + self.assertAllCloseChunks(result, expected) + + @parameterized.parameters( + dict(time_step=479, longitude_step=47), + dict(time_step=365, longitude_step=72), + dict(time_step=292, longitude_step=71), + dict(time_step=291, longitude_step=48), + ) + def test_multiple_datasets_returns_multiple_datasets( + self, + time_step: int, + longitude_step: int + ): + expected = [ + ( + core.Key({"time": t, "latitude": 0, "longitude": o}), + self.test_data.isel( + time=slice(t, t + time_step), + longitude=slice(o, o + longitude_step) + ) + ) for t, o in itertools.product( + range(0, 360 * 4, time_step), + range(0, 144, longitude_step) + ) + ] + with self.multifile_pattern(time_step, longitude_step) as pattern: + actual = test_util.EagerPipeline() | FilePatternToChunks(pattern) + + self.assertAllCloseChunks(actual, expected) + + @parameterized.parameters( + dict(time_step=365, longitude_step=72, chunks={"latitude": 36}), + dict(time_step=365, longitude_step=72, chunks={"longitude": 36}), + dict(time_step=365, longitude_step=72, + chunks={"longitude": 36, "latitude": 66}), + ) + def test_multiple_datasets_with_subchunks_returns_multiple_datasets( + self, + time_step: int, + longitude_step: int, + chunks: Dict[str, int], + ): + + expected = [] + for t, o in itertools.product(range(0, 360 * 4, time_step), + range(0, 144, longitude_step)): + expected.extend( + split_chunks( + core.Key({"latitude": 0, "longitude": o, "time": t}), + self.test_data.isel( + time=slice(t, t + time_step), + longitude=slice(o, o + longitude_step) + ), + chunks) + ) + with self.multifile_pattern(time_step, longitude_step) as pattern: + actual = test_util.EagerPipeline() | FilePatternToChunks( + pattern, + chunks=chunks + ) + + self.assertAllCloseChunks(actual, expected)