Skip to content

Commit

Permalink
Prototype version of xarray_beam.Dataset
Browse files Browse the repository at this point in the history
xbeam.Dataset implements a high-level API for distributed dataset operations at the Dataset level rather than the beam.PCollection level.

PiperOrigin-RevId: 560407095
  • Loading branch information
shoyer authored and Xarray-Beam authors committed Aug 29, 2023
1 parent 4f4fcb9 commit aee8967
Show file tree
Hide file tree
Showing 3 changed files with 215 additions and 0 deletions.
3 changes: 3 additions & 0 deletions xarray_beam/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
offsets_to_slices,
validate_chunk
)
from xarray_beam._src.dataset import (
Dataset,
)
from xarray_beam._src.rechunk import (
ConsolidateChunks,
ConsolidateVariables,
Expand Down
121 changes: 121 additions & 0 deletions xarray_beam/_src/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A high-level interface for Xarray-Beam datasets.

Usage example (not fully implemented yet!):

  import xarray_beam as xbeam

  transform = (
      xbeam.Dataset.from_zarr(input_path)
      .rechunk({'time': -1, 'latitude': 10, 'longitude': 10})
      .map_blocks(lambda x: x.median('time'))
      .to_zarr(output_path)
  )
  with beam.Pipeline() as p:
    p | transform
"""
from __future__ import annotations

import collections
from collections import abc
import dataclasses
import itertools
import os.path
import tempfile

import apache_beam as beam
import xarray
from xarray_beam._src import core
from xarray_beam._src import zarr


class _CountNamer:
  """Generates labels of the form `<name>_<k>` with a separate k per name."""

  def __init__(self):
    # Each distinct base name lazily gets its own counter starting at 0.
    self._counts = collections.defaultdict(itertools.count)

  def apply(self, name: str) -> str:
    """Return `name` suffixed with the next unused index for that name."""
    counter = self._counts[name]
    index = next(counter)
    return f'{name}_{index}'


# Module-wide label generator: a single _CountNamer instance is shared by every
# caller in this module, so repeated requests for the same base name yield
# `name_0`, `name_1`, ... rather than the same label twice.
_get_label = _CountNamer().apply


@dataclasses.dataclass
class Dataset:
  """Experimental high-level representation of an Xarray-Beam dataset.

  Attributes:
    template: xarray.Dataset describing the overall dataset; it backs `sizes`
      and `repr()` and is handed to `zarr.ChunksToZarr` when writing.
    chunks: chunk size for each chunked dimension, e.g. `{'x': 5}`.
    split_vars: flag forwarded to `core.DatasetToChunks` when the dataset is
      created from an xarray.Dataset.
    ptransform: Beam transform that produces the chunks of this dataset.
  """

  template: xarray.Dataset
  chunks: dict[str, int]
  split_vars: bool
  ptransform: beam.PTransform

  @classmethod
  def from_xarray(
      cls,
      source: xarray.Dataset,
      chunks: abc.Mapping[str, int],
      split_vars: bool = False,
  ) -> Dataset:
    """Create an xarray_beam.Dataset from an xarray.Dataset.

    Args:
      source: dataset to split into chunks.
      chunks: chunk sizes, forwarded to `core.DatasetToChunks` and stored on
        the result as a plain dict copy.
      split_vars: forwarded to `core.DatasetToChunks`.

    Returns:
      A new Dataset whose transform is labeled `from_xarray_<n>`.
    """
    template = zarr.make_template(source)
    ptransform = _get_label('from_xarray') >> core.DatasetToChunks(
        source, chunks, split_vars
    )
    return cls(template, dict(chunks), split_vars, ptransform)

  @classmethod
  def from_zarr(cls, path: str, split_vars: bool = False) -> Dataset:
    """Create an xarray_beam.Dataset from a zarr file.

    Args:
      path: location of the Zarr store, passed to `zarr.open_zarr`.
      split_vars: forwarded to `from_xarray`.

    Returns:
      A new Dataset whose transform is labeled `from_zarr_<n>`.
    """
    source, chunks = zarr.open_zarr(path)
    result = cls.from_xarray(source, chunks, split_vars)
    # Re-label so the stage is reported as `from_zarr_<n>` instead of the
    # `from_xarray_<n>` label assigned by from_xarray().
    result.ptransform = _get_label('from_zarr') >> result.ptransform
    return result

  def to_zarr(self, path: str) -> beam.PTransform:
    """Write to a Zarr file.

    Args:
      path: destination passed to `zarr.ChunksToZarr`.

    Returns:
      This dataset's transform chained with a `to_zarr_<n>` write stage.
      Nothing is written until the result is applied to a pipeline.
    """
    return self.ptransform | _get_label('to_zarr') >> zarr.ChunksToZarr(
        path, self.template, self.chunks
    )

  def collect_with_direct_runner(self) -> xarray.Dataset:
    """Collect a dataset in memory by writing it to a temp file."""
    # TODO(shoyer): generalize this function to support alternative runners.
    # Can we figure out a suitable temp file location for distributed runners?

    with tempfile.TemporaryDirectory() as temp_dir:
      temp_path = os.path.join(temp_dir, 'tmp.zarr')
      # Leaving the pipeline context runs the pipeline, so the store has been
      # written by the time it is opened below.
      with beam.Pipeline(runner='DirectRunner') as pipeline:
        pipeline |= self.to_zarr(temp_path)
      # compute() must happen before the temporary directory is removed.
      return xarray.open_zarr(temp_path).compute()

  # TODO(shoyer): implement map_blocks, rechunking, merge, rename, mean, etc

  @property
  def sizes(self) -> dict[str, int]:
    """Size of each dimension on this dataset."""
    return dict(self.template.sizes)

  def pipe(self, func, *args, **kwargs):
    """Apply `func` to this dataset and return its result.

    Args:
      func: callable invoked as `func(self, *args, **kwargs)`.
      *args: extra positional arguments for `func`.
      **kwargs: extra keyword arguments for `func`.

    Returns:
      Whatever `func` returns.
    """
    # The dataset itself must be the first argument; without it `func` has no
    # way to see the object being piped.
    return func(self, *args, **kwargs)

  def __repr__(self):
    """Return the template's repr with an xarray_beam-specific header."""
    base = repr(self.template)
    chunks_str = ', '.join(f'{k}: {v}' for k, v in self.chunks.items())
    return base.replace(
        '<xarray.Dataset>',
        f'<xarray_beam.Dataset[{chunks_str}][split_vars={self.split_vars}]>',
    )
91 changes: 91 additions & 0 deletions xarray_beam/_src/dataset_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import textwrap

from absl.testing import absltest
from absl.testing import parameterized
import apache_beam as beam
import numpy as np
import xarray
import xarray_beam as xbeam
from xarray_beam._src import test_util


class DatasetTest(test_util.TestCase):
  """Tests for the experimental xbeam.Dataset wrapper."""

  def test_from_xarray(self):
    """from_xarray() records metadata and emits the expected chunks."""
    ds = xarray.Dataset({'foo': ('x', np.arange(10))})
    beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 5})
    self.assertIsInstance(beam_ds, xbeam.Dataset)
    self.assertEqual(beam_ds.sizes, {'x': 10})
    self.assertEqual(beam_ds.template.keys(), {'foo'})
    self.assertEqual(beam_ds.chunks, {'x': 5})
    self.assertFalse(beam_ds.split_vars)
    self.assertRegex(beam_ds.ptransform.label, r'^from_xarray_\d+$')
    # NOTE(review): the expected text below is compared exactly; confirm its
    # internal spacing against the repr produced by the installed xarray.
    self.assertEqual(
        repr(beam_ds),
        textwrap.dedent("""
            <xarray_beam.Dataset[x: 5][split_vars=False]>
            Dimensions: (x: 10)
            Dimensions without coordinates: x
            Data variables:
            foo (x) int64 dask.array<chunksize=(10,), meta=np.ndarray>
        """).strip(),
    )
    expected = [
        (xbeam.Key({'x': 0}), ds.head(x=5)),
        (xbeam.Key({'x': 5}), ds.tail(x=5)),
    ]
    actual = test_util.EagerPipeline() | beam_ds.ptransform
    self.assertIdenticalChunks(expected, actual)

  def test_collect_with_direct_runner(self):
    """Collecting a dataset round-trips the original xarray.Dataset."""
    ds = xarray.Dataset({'foo': ('x', np.arange(10))})
    beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 5})
    collected = beam_ds.collect_with_direct_runner()
    xarray.testing.assert_identical(ds, collected)

  @parameterized.parameters(
      dict(split_vars=False),
      dict(split_vars=True),
  )
  def test_from_zarr(self, split_vars):
    """from_zarr() picks up on-disk chunks and round-trips the data."""
    temp_dir = self.create_tempdir().full_path
    ds = xarray.Dataset({'foo': ('x', np.arange(10))})
    ds.chunk({'x': 5}).to_zarr(temp_dir)

    beam_ds = xbeam.Dataset.from_zarr(temp_dir, split_vars)

    self.assertRegex(beam_ds.ptransform.label, r'^from_zarr_\d+$')
    self.assertEqual(beam_ds.chunks, {'x': 5})
    self.assertEqual(beam_ds.split_vars, split_vars)

    collected = beam_ds.collect_with_direct_runner()
    xarray.testing.assert_identical(ds, collected)

  def test_to_zarr(self):
    """to_zarr() returns a chained transform that writes the full dataset."""
    temp_dir = self.create_tempdir().full_path
    ds = xarray.Dataset({'foo': ('x', np.arange(10))})
    beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 5})
    to_zarr = beam_ds.to_zarr(temp_dir)

    # The `|` joining the two stage labels is escaped so it is matched
    # literally; unescaped it is regex alternation and the assertion would
    # pass when only one of the two stage labels is present.
    self.assertRegex(to_zarr.label, r'^from_xarray_\d+\|to_zarr_\d+$')
    with beam.Pipeline() as p:
      p |= to_zarr
    opened = xarray.open_zarr(temp_dir).compute()
    xarray.testing.assert_identical(ds, opened)


# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
absltest.main()

0 comments on commit aee8967

Please sign in to comment.