diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index db10ec653c5..5465cb7761e 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -103,6 +103,8 @@ Internal Changes
   By `Maximilian Roos <https://github.com/max-sixty>`_.
 - Speed up attribute style access (e.g. ``ds.somevar`` instead of ``ds["somevar"]``) and
   tab completion in ipython (:issue:`4741`, :pull:`4742`).
   By `Richard Kleijn <https://github.com/rhkleijn>`_.
+- Added the ``set_close`` method to ``Dataset`` and ``DataArray`` for backends to specify how to
+  voluntarily release all resources (:pull:`4809`). By `Alessandro Amici <https://github.com/alexamici>`_.
 
 .. _whats-new.0.16.2:

diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index faa7e6cf3d3..30f9532b29a 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -522,7 +522,7 @@ def maybe_decode_store(store, chunks):
         else:
             ds2 = ds
 
-        ds2._file_obj = ds._file_obj
+        ds2.set_close(ds._close)
         return ds2
 
     filename_or_obj = _normalize_path(filename_or_obj)
@@ -701,7 +701,7 @@ def open_dataarray(
     else:
         (data_array,) = dataset.data_vars.values()
 
-    data_array._file_obj = dataset._file_obj
+    data_array.set_close(dataset._close)
 
     # Reset names if they were changed during saving
     # to ensure that we can 'roundtrip' perfectly
@@ -715,17 +715,6 @@
     return data_array
 
 
-class _MultiFileCloser:
-    __slots__ = ("file_objs",)
-
-    def __init__(self, file_objs):
-        self.file_objs = file_objs
-
-    def close(self):
-        for f in self.file_objs:
-            f.close()
-
-
 def open_mfdataset(
     paths,
     chunks=None,
@@ -918,14 +907,14 @@
         getattr_ = getattr
 
     datasets = [open_(p, **open_kwargs) for p in paths]
-    file_objs = [getattr_(ds, "_file_obj") for ds in datasets]
+    closers = [getattr_(ds, "_close") for ds in datasets]
     if preprocess is not None:
         datasets = [preprocess(ds) for ds in datasets]
 
     if parallel:
         # calling compute here will return the datasets/file_objs lists,
         # the underlying datasets will still be stored as dask arrays
-        datasets, file_objs = dask.compute(datasets, file_objs)
+        datasets, closers = dask.compute(datasets, closers)
 
     # Combine all datasets, closing them in case of a ValueError
     try:
@@ -963,7 +952,11 @@
             ds.close()
         raise
 
-    combined._file_obj = _MultiFileCloser(file_objs)
+    def multi_file_closer():
+        for closer in closers:
+            closer()
+
+    combined.set_close(multi_file_closer)
 
     # read global attributes from the attrs_file or from the first dataset
     if attrs_file is not None:

diff --git a/xarray/backends/apiv2.py b/xarray/backends/apiv2.py
index 0f98291983d..d31fc9ea773 100644
--- a/xarray/backends/apiv2.py
+++ b/xarray/backends/apiv2.py
@@ -90,7 +90,7 @@ def _dataset_from_backend_dataset(
         **extra_tokens,
     )
 
-    ds._file_obj = backend_ds._file_obj
+    ds.set_close(backend_ds._close)
 
     # Ensure source filename always stored in dataset object (GH issue #2550)
     if "source" not in ds.encoding:

diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py
index a0500c7e1c2..c689c1e99d7 100644
--- a/xarray/backends/rasterio_.py
+++ b/xarray/backends/rasterio_.py
@@ -361,6 +361,6 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, loc
         result = result.chunk(chunks, name_prefix=name_prefix, token=token)
 
     # Make the file closeable
-    result._file_obj = manager
+    result.set_close(manager.close)
 
     return result

diff --git a/xarray/backends/store.py b/xarray/backends/store.py
index d314a9c3ca9..20fa13af202 100644
--- a/xarray/backends/store.py
+++ b/xarray/backends/store.py
@@ -19,7 +19,6 @@ def open_backend_dataset_store(
     decode_timedelta=None,
 ):
     vars, attrs = store.load()
-    file_obj = store
     encoding = store.get_encoding()
 
     vars, attrs, coord_names = conventions.decode_cf_variables(
@@ -36,7 +35,7 @@
     ds = Dataset(vars, attrs=attrs)
     ds = ds.set_coords(coord_names.intersection(vars))
 
-    ds._file_obj = file_obj
+    ds.set_close(store.close)
     ds.encoding = encoding
 
     return ds

diff --git a/xarray/conventions.py b/xarray/conventions.py
index bb0b92c77a1..e33ae53b31d 100644
--- a/xarray/conventions.py
+++ b/xarray/conventions.py
@@ -576,12 +576,12 @@
         vars = obj._variables
         attrs = obj.attrs
         extra_coords = set(obj.coords)
-        file_obj = obj._file_obj
+        close = obj._close
         encoding = obj.encoding
     elif isinstance(obj, AbstractDataStore):
         vars, attrs = obj.load()
         extra_coords = set()
-        file_obj = obj
+        close = obj.close
         encoding = obj.get_encoding()
     else:
         raise TypeError("can only decode Dataset or DataStore objects")
@@ -599,7 +599,7 @@
     )
     ds = Dataset(vars, attrs=attrs)
     ds = ds.set_coords(coord_names.union(extra_coords).intersection(vars))
-    ds._file_obj = file_obj
+    ds.set_close(close)
     ds.encoding = encoding
 
     return ds

diff --git a/xarray/core/common.py b/xarray/core/common.py
index 283114770cf..a69ba03a7a4 100644
--- a/xarray/core/common.py
+++ b/xarray/core/common.py
@@ -11,6 +11,7 @@
     Iterator,
     List,
     Mapping,
+    Optional,
     Tuple,
     TypeVar,
     Union,
@@ -330,7 +331,9 @@ def get_squeeze_dims(
 class DataWithCoords(SupportsArithmetic, AttrAccessMixin):
     """Shared base class for Dataset and DataArray."""
 
-    __slots__ = ()
+    _close: Optional[Callable[[], None]]
+
+    __slots__ = ("_close",)
 
     _rolling_exp_cls = RollingExp
 
@@ -1263,11 +1266,27 @@ def where(self, cond, other=dtypes.NA, drop: bool = False):
         """
         return ops.where_method(self, cond, other)
 
+    def set_close(self, close: Optional[Callable[[], None]]) -> None:
+        """Register the function that releases any resources linked to this object.
+
+        This method controls how xarray cleans up resources associated
+        with this object when the ``.close()`` method is called. It is mostly
+        intended for backend developers and is rarely needed by regular
+        end-users.
+
+        Parameters
+        ----------
+        close : callable
+            The function that, when called like ``close()``, releases
+            any resources linked to this object.
+        """
+        self._close = close
+
     def close(self: Any) -> None:
-        """Close any files linked to this object"""
-        if self._file_obj is not None:
-            self._file_obj.close()
-        self._file_obj = None
+        """Release any resources linked to this object."""
+        if self._close is not None:
+            self._close()
+        self._close = None
 
     def isnull(self, keep_attrs: bool = None):
         """Test each value in the array for whether it is a missing value.

diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 6fdda8fc418..e13ea44baad 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -344,6 +344,7 @@ class DataArray(AbstractArray, DataWithCoords):
 
     _cache: Dict[str, Any]
     _coords: Dict[Any, Variable]
+    _close: Optional[Callable[[], None]]
    _indexes: Optional[Dict[Hashable, pd.Index]]
     _name: Optional[Hashable]
     _variable: Variable
@@ -351,7 +352,7 @@
     __slots__ = (
         "_cache",
         "_coords",
-        "_file_obj",
+        "_close",
         "_indexes",
         "_name",
         "_variable",
@@ -421,7 +422,7 @@ def __init__(
         # public interface.
         self._indexes = indexes
 
-        self._file_obj = None
+        self._close = None
 
     def _replace(
         self,

diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 7edc2fab067..136edffb202 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -636,6 +636,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords):
     _coord_names: Set[Hashable]
     _dims: Dict[Hashable, int]
     _encoding: Optional[Dict[Hashable, Any]]
+    _close: Optional[Callable[[], None]]
     _indexes: Optional[Dict[Hashable, pd.Index]]
     _variables: Dict[Hashable, Variable]
 
@@ -645,7 +646,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords):
         "_coord_names",
         "_dims",
         "_encoding",
-        "_file_obj",
+        "_close",
         "_indexes",
         "_variables",
         "__weakref__",
@@ -687,7 +688,7 @@ def __init__(
         )
 
         self._attrs = dict(attrs) if attrs is not None else None
-        self._file_obj = None
+        self._close = None
         self._encoding = None
         self._variables = variables
         self._coord_names = coord_names
@@ -703,7 +704,7 @@ def load_store(cls, store, decoder=None) -> "Dataset":
         if decoder:
             variables, attributes = decoder(variables, attributes)
         obj = cls(variables, attrs=attributes)
-        obj._file_obj = store
+        obj.set_close(store.close)
         return obj
 
     @property
@@ -876,7 +877,7 @@ def __dask_postcompute__(self):
             self._attrs,
             self._indexes,
             self._encoding,
-            self._file_obj,
+            self._close,
         )
         return self._dask_postcompute, args
 
@@ -896,7 +897,7 @@ def __dask_postpersist__(self):
             self._attrs,
             self._indexes,
             self._encoding,
-            self._file_obj,
+            self._close,
         )
         return self._dask_postpersist, args
 
@@ -1007,7 +1008,7 @@ def _construct_direct(
         attrs=None,
         indexes=None,
         encoding=None,
-        file_obj=None,
+        close=None,
     ):
         """Shortcut around __init__ for internal use when we want to skip
         costly validation
@@ -1020,7 +1021,7 @@ def _construct_direct(
         obj._dims = dims
         obj._indexes = indexes
         obj._attrs = attrs
-        obj._file_obj = file_obj
+        obj._close = close
         obj._encoding = encoding
         return obj
 
@@ -2122,7 +2123,7 @@ def isel(
             attrs=self._attrs,
             indexes=indexes,
             encoding=self._encoding,
-            file_obj=self._file_obj,
+            close=self._close,
         )
 
     def _isel_fancy(
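
For illustration, the pattern this patch enables looks roughly as follows. This is a minimal sketch, not part of the patch: ``DummyResource`` and ``open_my_backend_dataset`` are hypothetical names, and only ``set_close``/``close`` come from the diff above.

import numpy as np
import xarray as xr


class DummyResource:
    """Hypothetical stand-in for a backend file handle."""

    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True


def open_my_backend_dataset(resource):
    # Hypothetical backend entry point: build the Dataset however the
    # backend sees fit, then register a zero-argument closer instead of
    # assigning a file-like object to the old ``_file_obj`` slot.
    ds = xr.Dataset({"var": ("x", np.arange(3))})
    # ``ds.close()`` calls the registered closer at most once and then
    # resets it to None (see DataWithCoords.close in the diff above).
    ds.set_close(resource.close)
    return ds


resource = DummyResource()
ds = open_my_backend_dataset(resource)
ds.close()
assert resource.closed

Because the closer is just a zero-argument callable rather than an object with a ``close`` method, ``open_mfdataset`` can compose many of them into a single ``multi_file_closer`` function, which is what lets this patch delete the ``_MultiFileCloser`` helper class.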