From aee8967da9c3fbb9e14d49519571c6c00d3ca110 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 26 Aug 2023 18:12:18 -0700 Subject: [PATCH] Prototype version of xarray_beam.Dataset xbeam.Dataset implements a high level API for distributed dataset operations at the Dataset rather than beam.PCollection level. PiperOrigin-RevId: 560407095 --- xarray_beam/__init__.py | 3 + xarray_beam/_src/dataset.py | 121 +++++++++++++++++++++++++++++++ xarray_beam/_src/dataset_test.py | 91 +++++++++++++++++++++++ 3 files changed, 215 insertions(+) create mode 100644 xarray_beam/_src/dataset.py create mode 100644 xarray_beam/_src/dataset_test.py diff --git a/xarray_beam/__init__.py b/xarray_beam/__init__.py index 926397c..f0dc6d6 100644 --- a/xarray_beam/__init__.py +++ b/xarray_beam/__init__.py @@ -25,6 +25,9 @@ offsets_to_slices, validate_chunk ) +from xarray_beam._src.dataset import ( + Dataset, +) from xarray_beam._src.rechunk import ( ConsolidateChunks, ConsolidateVariables, diff --git a/xarray_beam/_src/dataset.py b/xarray_beam/_src/dataset.py new file mode 100644 index 0000000..30ef214 --- /dev/null +++ b/xarray_beam/_src/dataset.py @@ -0,0 +1,121 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A high-level interface for Xarray-Beam datasets. + +Usage example (not fully implemented yet!): + + import xarray_beam as xbeam + + transform = ( + xbeam.Dataset.from_zarr(input_path) + .rechunk({'time': -1, 'latitude': 10, 'longitude': 10}) + .map_blocks(lambda x: x.median('time')) + .to_zarr(output_path) + ) + with beam.Pipeline() as p: + p | transform +""" +from __future__ import annotations + +import collections +from collections import abc +import dataclasses +import itertools +import os.path +import tempfile + +import apache_beam as beam +import xarray +from xarray_beam._src import core +from xarray_beam._src import zarr + + +class _CountNamer: + + def __init__(self): + self._counts = collections.defaultdict(itertools.count) + + def apply(self, name: str) -> str: + return f'{name}_{next(self._counts[name])}' + + +_get_label = _CountNamer().apply + + +@dataclasses.dataclass +class Dataset: + """Experimental high-level representation of an Xarray-Beam dataset.""" + + template: xarray.Dataset + chunks: dict[str, int] + split_vars: bool + ptransform: beam.PTransform + + @classmethod + def from_xarray( + cls, + source: xarray.Dataset, + chunks: abc.Mapping[str, int], + split_vars: bool = False, + ) -> Dataset: + """Create an xarray_beam.Dataset from an xarray.Dataset.""" + template = zarr.make_template(source) + ptransform = _get_label('from_xarray') >> core.DatasetToChunks( + source, chunks, split_vars + ) + return cls(template, dict(chunks), split_vars, ptransform) + + @classmethod + def from_zarr(cls, path: str, split_vars: bool = False) -> Dataset: + """Create an xarray_beam.Dataset from a zarr file.""" + source, chunks = zarr.open_zarr(path) + result = cls.from_xarray(source, chunks, split_vars) + result.ptransform = _get_label('from_zarr') >> result.ptransform + return result + + def to_zarr(self, path: str) -> beam.PTransform: + """Write to a Zarr file.""" + return self.ptransform | _get_label('to_zarr') >> zarr.ChunksToZarr( + path, self.template, self.chunks + ) + + def collect_with_direct_runner(self) -> xarray.Dataset: + """Collect a dataset in memory by writing it to a temp file.""" + # TODO(shoyer): generalize this function to something that support + # alternative runners can we figure out a suitable temp file location for + # distributed runners? + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = os.path.join(temp_dir, 'tmp.zarr') + with beam.Pipeline(runner='DirectRunner') as pipeline: + pipeline |= self.to_zarr(temp_path) + return xarray.open_zarr(temp_path).compute() + + # TODO(shoyer): implement map_blocks, rechunking, merge, rename, mean, etc + + @property + def sizes(self) -> dict[str, int]: + """Size of each dimension on this dataset.""" + return dict(self.template.sizes) + + def pipe(self, func, *args, **kwargs): + return func(*args, **kwargs) + + def __repr__(self): + base = repr(self.template) + chunks_str = ', '.join(f'{k}: {v}' for k, v in self.chunks.items()) + return base.replace( + '', + f'', + ) diff --git a/xarray_beam/_src/dataset_test.py b/xarray_beam/_src/dataset_test.py new file mode 100644 index 0000000..156276f --- /dev/null +++ b/xarray_beam/_src/dataset_test.py @@ -0,0 +1,91 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import textwrap + +from absl.testing import absltest +from absl.testing import parameterized +import apache_beam as beam +import numpy as np +import xarray +import xarray_beam as xbeam +from xarray_beam._src import test_util + + +class DatasetTest(test_util.TestCase): + + def test_from_xarray(self): + ds = xarray.Dataset({'foo': ('x', np.arange(10))}) + beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 5}) + self.assertIsInstance(beam_ds, xbeam.Dataset) + self.assertEqual(beam_ds.sizes, {'x': 10}) + self.assertEqual(beam_ds.template.keys(), {'foo'}) + self.assertEqual(beam_ds.chunks, {'x': 5}) + self.assertFalse(beam_ds.split_vars) + self.assertRegex(beam_ds.ptransform.label, r'^from_xarray_\d+$') + self.assertEqual( + repr(beam_ds), + textwrap.dedent(""" + + Dimensions: (x: 10) + Dimensions without coordinates: x + Data variables: + foo (x) int64 dask.array + """).strip(), + ) + expected = [ + (xbeam.Key({'x': 0}), ds.head(x=5)), + (xbeam.Key({'x': 5}), ds.tail(x=5)), + ] + actual = test_util.EagerPipeline() | beam_ds.ptransform + self.assertIdenticalChunks(expected, actual) + + def test_collect_with_direct_runner(self): + ds = xarray.Dataset({'foo': ('x', np.arange(10))}) + beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 5}) + collected = beam_ds.collect_with_direct_runner() + xarray.testing.assert_identical(ds, collected) + + @parameterized.parameters( + dict(split_vars=False), + dict(split_vars=True), + ) + def test_from_zarr(self, split_vars): + temp_dir = self.create_tempdir().full_path + ds = xarray.Dataset({'foo': ('x', np.arange(10))}) + ds.chunk({'x': 5}).to_zarr(temp_dir) + + beam_ds = xbeam.Dataset.from_zarr(temp_dir, split_vars) + + self.assertRegex(beam_ds.ptransform.label, r'^from_zarr_\d+$') + self.assertEqual(beam_ds.chunks, {'x': 5}) + self.assertEqual(beam_ds.split_vars, split_vars) + + collected = beam_ds.collect_with_direct_runner() + xarray.testing.assert_identical(ds, collected) + + def test_to_zarr(self): + temp_dir = self.create_tempdir().full_path + ds = xarray.Dataset({'foo': ('x', np.arange(10))}) + beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 5}) + to_zarr = beam_ds.to_zarr(temp_dir) + + self.assertRegex(to_zarr.label, r'^from_xarray_\d+|to_zarr_\d+$') + with beam.Pipeline() as p: + p |= to_zarr + opened = xarray.open_zarr(temp_dir).compute() + xarray.testing.assert_identical(ds, opened) + + +if __name__ == '__main__': + absltest.main()