Introducing FilePatternToChunks: IO with Pangeo-Forge's FilePattern interface. #31

Merged · 20 commits · Sep 22, 2021
Commits (20)
fd2019a  Introducing FilePatternToChunks: IO with Pangeo-Forge's FilePattern i… (alxmrs, Aug 9, 2021)
48162d2  Including scipy as test dependency (so we can write NetCDF files with… (alxmrs, Aug 9, 2021)
f211072  Initial PR feedback, WIP (alxmrs, Sep 14, 2021)
c125364  Better translation of FilePatternIndex to Key. (alxmrs, Sep 15, 2021)
5896d45  Initial cleanup; fixing broken CI. (alxmrs, Sep 15, 2021)
e0f22df  `expand` step is splittable (only uses a `create` fn). (alxmrs, Sep 15, 2021)
83fe582  Clean up file whitespace (alxmrs, Sep 15, 2021)
8cf7117  Revert create strategy. (alxmrs, Sep 15, 2021)
8d39374  Simplified FilePatternToChunks transform -- no split_chunks / sub_chu… (alxmrs, Sep 20, 2021)
bb4675b  Fixed broken unit tests. (alxmrs, Sep 20, 2021)
1df6e81  Using an "all close" instead of an "identical" assert. (alxmrs, Sep 20, 2021)
b231673  Single chunks are also now using "all close". (alxmrs, Sep 20, 2021)
c9d9f8c  Updating file open capability to support grib files. (alxmrs, Sep 21, 2021)
dd14005  Added back sub-chunks; `Create` + `FlatMap` is now splittable. (alxmrs, Sep 22, 2021)
73671d2  Renaming 'sub_chunks' to 'chunks'. (alxmrs, Sep 22, 2021)
209632d  _open_dataset() has error handling for open_local call. (alxmrs, Sep 22, 2021)
d1e91a8  Merge branch 'pangeo-fp' of github.com:alxmrs/xarray-beam into pangeo-fp (alxmrs, Sep 22, 2021)
c3a668a  Added 'local_copy' option. (alxmrs, Sep 22, 2021)
d7f284b  Imperative 'local_copy' flag instead of a fallback. (alxmrs, Sep 22, 2021)
0ee8980  `open_local` downloads file to a temporary directory. (alxmrs, Sep 22, 2021)
3 changes: 3 additions & 0 deletions setup.py
@@ -35,6 +35,9 @@
    'absl-py',
    'pandas',
    'pytest',
    'pangeo-forge-recipes',
    'scipy',
    'h5netcdf'
]

setuptools.setup(
151 changes: 151 additions & 0 deletions xarray_beam/_src/pangeo_forge.py
@@ -0,0 +1,151 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""IO with Pangeo-Forge."""
import contextlib
import tempfile
from typing import (
    Dict,
    Iterator,
    Optional,
    Mapping,
    Tuple,
)

import apache_beam as beam
import fsspec
import xarray
from apache_beam.io.filesystems import FileSystems

from xarray_beam._src import core, rechunk


def _zero_dimensions(dataset: xarray.Dataset) -> Mapping[str, int]:
  return {dim: 0 for dim in dataset.dims.keys()}


def _expand_dimensions_by_key(
    dataset: xarray.Dataset,
    index: 'FilePatternIndex',
    pattern: 'FilePattern'
) -> xarray.Dataset:
  """Expand the dimensions of the `Dataset` by offsets found in the `Key`."""
  combine_dims_by_name = {
      combine_dim.name: combine_dim for combine_dim in pattern.combine_dims
  }
  index_by_name = {
      idx.name: idx for idx in index
  }

  if not combine_dims_by_name:
    return dataset

  for dim_key in index_by_name.keys():
    # skip expanding dimensions if they already exist
    if dim_key in dataset.dims:
      continue

    try:
      combine_dim = combine_dims_by_name[dim_key]
    except KeyError:
      raise ValueError(
          f"could not find CombineDim named {dim_key!r} in pattern {pattern!r}."
      )

    dim_val = combine_dim.keys[index_by_name[dim_key].index]
    dataset = dataset.expand_dims(**{dim_key: [dim_val]})

  return dataset


class FilePatternToChunks(beam.PTransform):
  """Open data described by a Pangeo-Forge `FilePattern` into keyed chunks."""

  from pangeo_forge_recipes.patterns import FilePattern, FilePatternIndex

  def __init__(
      self,
      pattern: 'FilePattern',
      chunks: Optional[Mapping[str, int]] = None,
      local_copy: bool = False,
      xarray_open_kwargs: Optional[Dict] = None
  ):
"""Initialize FilePatternToChunks.

TODO(#29): Currently, `MergeDim`s are not supported.

Args:
pattern: a `FilePattern` describing a dataset.
chunks: split each open dataset into smaller chunks. If not set, the
transform will return one file per chunk.
local_copy: Open files from the pattern with local copies instead of a
buffered reader.
xarray_open_kwargs: keyword arguments to pass to `xarray.open_dataset()`.
"""
self.pattern = pattern
self.chunks = chunks
self.local_copy = local_copy
self.xarray_open_kwargs = xarray_open_kwargs or {}
self._max_size_idx = {}

if pattern.merge_dims:
raise ValueError("patterns with `MergeDim`s are not supported.")

  @contextlib.contextmanager
  def _open_dataset(self, path: str) -> Iterator[xarray.Dataset]:
    """Open `path` as an xarray `Dataset`, optionally via a local copy."""
    if self.local_copy:
      with tempfile.TemporaryDirectory() as tmpdir:
        local_file = fsspec.open_local(
            f"simplecache::{path}",
            simplecache={'cache_storage': tmpdir}
        )
        yield xarray.open_dataset(local_file, **self.xarray_open_kwargs)
    else:
      with FileSystems().open(path) as file:
        yield xarray.open_dataset(file, **self.xarray_open_kwargs)

  def _open_chunks(
      self,
      index: 'FilePatternIndex',
      path: str
  ) -> Iterator[Tuple[core.Key, xarray.Dataset]]:
    """Open datasets into chunks with XArray."""
    with self._open_dataset(path) as dataset:

      dataset = _expand_dimensions_by_key(dataset, index, self.pattern)

      if not self._max_size_idx:
        self._max_size_idx = dataset.sizes

      base_key = core.Key(_zero_dimensions(dataset)).with_offsets(
          **{dim.name: self._max_size_idx[dim.name] * dim.index for dim in index}
      )

      num_threads = len(dataset.data_vars)

      # If chunks is not set by the user, treat the dataset as a single chunk.
      if self.chunks is None:
        yield base_key, dataset.compute(num_workers=num_threads)
        return

      for new_key, chunk in rechunk.split_chunks(base_key, dataset,
                                                 self.chunks):
        yield new_key, chunk.compute(num_workers=num_threads)

  def expand(self, pcoll):
    return (
        pcoll
        | beam.Create(list(self.pattern.items()))
        | beam.FlatMapTuple(self._open_chunks)
    )
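A note on the `local_copy` flag added late in this PR: it pairs naturally with file formats whose readers need a real local path rather than a buffered stream (the commit history mentions GRIB support). The sketch below is not taken from the PR; it assumes `cfgrib` is installed and uses an illustrative bucket path.

```python
from pangeo_forge_recipes.patterns import ConcatDim, FilePattern
from xarray_beam._src.pangeo_forge import FilePatternToChunks

# Illustrative pattern: one GRIB file per forecast step.
pattern = FilePattern(
    lambda step: f'gs://my-bucket/forecast-{step}.grib',
    ConcatDim('step', [0, 6, 12, 18]),
)

# With local_copy=True, each file is first downloaded into a temporary
# directory (via fsspec's simplecache), so engines like cfgrib that require a
# real local file can open it.
transform = FilePatternToChunks(
    pattern,
    local_copy=True,
    xarray_open_kwargs={'engine': 'cfgrib'},
)
```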
205 changes: 205 additions & 0 deletions xarray_beam/_src/pangeo_forge_test.py
@@ -0,0 +1,205 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for xarray_beam._src.pangeo."""

import contextlib
import itertools
import tempfile
from typing import Dict

import numpy as np
from absl.testing import parameterized
from pangeo_forge_recipes.patterns import (
    FilePattern,
    ConcatDim,
    DimIndex,
    CombineOp
)

from xarray_beam import split_chunks
from xarray_beam._src import core
from xarray_beam._src import test_util
from xarray_beam._src.pangeo_forge import (
    FilePatternToChunks,
    _expand_dimensions_by_key
)


class ExpandDimensionsByKeyTest(test_util.TestCase):

  def setUp(self):
    self.test_data = test_util.dummy_era5_surface_dataset()
    self.level = ConcatDim("level", list(range(91, 100)))
    self.pattern = FilePattern(lambda level: f"gs://dir/{level}.nc", self.level)

  def test_expands_dimensions(self):
    for i, (index, _) in enumerate(self.pattern.items()):
      actual = _expand_dimensions_by_key(
          self.test_data, index, self.pattern
      )

      expected_dims = dict(self.test_data.dims)
      expected_dims.update({"level": 1})

      self.assertEqual(expected_dims, dict(actual.dims))
      self.assertEqual(np.array([self.level.keys[i]]), actual["level"])

  def test_raises_error_when_dataset_is_not_found(self):
    index = (DimIndex('boat', 0, 1, CombineOp.CONCAT),)
    with self.assertRaisesRegex(ValueError, "boat"):
      _expand_dimensions_by_key(
          self.test_data, index, self.pattern
      )


class FilePatternToChunksTest(test_util.TestCase):

  def setUp(self):
    self.test_data = test_util.dummy_era5_surface_dataset()

  @contextlib.contextmanager
  def pattern_from_testdata(self) -> FilePattern:
    """Produces a FilePattern for a temporary NetCDF file of test data."""
    with tempfile.TemporaryDirectory() as tmpdir:
      target = f'{tmpdir}/era5.nc'
      self.test_data.to_netcdf(target)
      yield FilePattern(lambda: target)

  @contextlib.contextmanager
  def multifile_pattern(
      self,
      time_step: int = 479,
      longitude_step: int = 47
  ) -> FilePattern:
    """Produces a FilePattern for temporary NetCDF files of chunked test data."""
    time_dim = ConcatDim('time', list(range(0, 360 * 4, time_step)))
    longitude_dim = ConcatDim('longitude', list(range(0, 144, longitude_step)))

    with tempfile.TemporaryDirectory() as tmpdir:
      def make_path(time: int, longitude: int) -> str:
        return f'{tmpdir}/era5-{time}-{longitude}.nc'

      for time in time_dim.keys:
        for long in longitude_dim.keys:
          chunk = self.test_data.isel(
              time=slice(time, time + time_step),
              longitude=slice(long, long + longitude_step)
          )
          chunk.to_netcdf(make_path(time, long))
      yield FilePattern(make_path, time_dim, longitude_dim)

  def test_returns_single_dataset(self):
    expected = [
        (core.Key({"time": 0, "latitude": 0, "longitude": 0}), self.test_data)
    ]
    with self.pattern_from_testdata() as pattern:
      actual = test_util.EagerPipeline() | FilePatternToChunks(pattern)

    self.assertAllCloseChunks(actual, expected)

  def test_single_subchunks_returns_multiple_datasets(self):
    with self.pattern_from_testdata() as pattern:
      result = (
          test_util.EagerPipeline()
          | FilePatternToChunks(pattern, chunks={"longitude": 48})
      )

    expected = [
        (
            core.Key({"time": 0, "latitude": 0, "longitude": i}),
            self.test_data.isel(longitude=slice(i, i + 48))
        )
        for i in range(0, 144, 48)
    ]
    self.assertAllCloseChunks(result, expected)

  def test_multiple_subchunks_returns_multiple_datasets(self):
    with self.pattern_from_testdata() as pattern:
      result = (
          test_util.EagerPipeline()
          | FilePatternToChunks(pattern,
                                chunks={"longitude": 48, "latitude": 24})
      )

    expected = [
        (
            core.Key({"time": 0, "longitude": o, "latitude": a}),
            self.test_data.isel(longitude=slice(o, o + 48),
                                latitude=slice(a, a + 24))
        )
        for o, a in itertools.product(range(0, 144, 48), range(0, 73, 24))
    ]

    self.assertAllCloseChunks(result, expected)

  @parameterized.parameters(
      dict(time_step=479, longitude_step=47),
      dict(time_step=365, longitude_step=72),
      dict(time_step=292, longitude_step=71),
      dict(time_step=291, longitude_step=48),
  )
  def test_multiple_datasets_returns_multiple_datasets(
      self,
      time_step: int,
      longitude_step: int
  ):
    expected = [
        (
            core.Key({"time": t, "latitude": 0, "longitude": o}),
            self.test_data.isel(
                time=slice(t, t + time_step),
                longitude=slice(o, o + longitude_step)
            )
        ) for t, o in itertools.product(
            range(0, 360 * 4, time_step),
            range(0, 144, longitude_step)
        )
    ]
    with self.multifile_pattern(time_step, longitude_step) as pattern:
      actual = test_util.EagerPipeline() | FilePatternToChunks(pattern)

    self.assertAllCloseChunks(actual, expected)

  @parameterized.parameters(
      dict(time_step=365, longitude_step=72, chunks={"latitude": 36}),
      dict(time_step=365, longitude_step=72, chunks={"longitude": 36}),
      dict(time_step=365, longitude_step=72,
           chunks={"longitude": 36, "latitude": 66}),
  )
  def test_multiple_datasets_with_subchunks_returns_multiple_datasets(
      self,
      time_step: int,
      longitude_step: int,
      chunks: Dict[str, int],
  ):
    expected = []
    for t, o in itertools.product(range(0, 360 * 4, time_step),
                                  range(0, 144, longitude_step)):
      expected.extend(
          split_chunks(
              core.Key({"latitude": 0, "longitude": o, "time": t}),
              self.test_data.isel(
                  time=slice(t, t + time_step),
                  longitude=slice(o, o + longitude_step)
              ),
              chunks)
      )
    with self.multifile_pattern(time_step, longitude_step) as pattern:
      actual = test_util.EagerPipeline() | FilePatternToChunks(
          pattern,
          chunks=chunks
      )

    self.assertAllCloseChunks(actual, expected)
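As the tests above suggest, the key attached to each emitted chunk is the file's pattern index scaled by the dimension sizes of the first file opened, optionally subdivided further by `split_chunks`. Below is a small sketch of that offset arithmetic; the per-file sizes are made up for illustration and are not part of the PR.

```python
from xarray_beam._src import core

# Assume (hypothetically) that every file in a pattern holds 365 time steps
# and 48 longitude points; FilePatternToChunks records these sizes from the
# first file it opens.
first_file_sizes = {'time': 365, 'longitude': 48}

# The file at pattern index (time=1, longitude=2) is keyed at index * size:
base_key = core.Key({'time': 0, 'latitude': 0, 'longitude': 0}).with_offsets(
    time=first_file_sizes['time'] * 1,
    longitude=first_file_sizes['longitude'] * 2,
)
assert base_key == core.Key({'time': 365, 'latitude': 0, 'longitude': 96})
```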