Remove the references to _file_obj outside low level code paths, change to _close #4809

Merged · 18 commits · Jan 18, 2021
Changes from 6 commits
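Summary of the change: `Dataset` and `DataArray` no longer carry a `_file_obj` attribute (an object whose `.close()` method released the underlying resource). They now carry `_close`, an optional zero-argument callable that backends register with `set_close()` and that `close()` invokes. The sketch below restates that contract in isolation; the class name is illustrative, and only `_close`, `set_close`, and `close` come from this PR.

from typing import Callable, Optional

class ClosableSketch:
    # Illustrative stand-in for xarray's DataWithCoords after this PR.
    _close: Optional[Callable[[], None]] = None

    def set_close(self, close: Optional[Callable[[], None]]) -> None:
        # Backends register the resource-releasing callable here
        # (e.g. store.close or manager.close in the diffs below).
        self._close = close

    def close(self) -> None:
        # Invoke the registered callable once, then drop the reference.
        if self._close is not None:
            self._close()
        self._close = None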
21 changes: 9 additions & 12 deletions xarray/backends/api.py
@@ -522,7 +522,6 @@ def maybe_decode_store(store, chunks):

else:
ds2 = ds
ds2._file_obj = ds._file_obj
return ds2

filename_or_obj = _normalize_path(filename_or_obj)
@@ -701,8 +700,6 @@ def open_dataarray(
else:
(data_array,) = dataset.data_vars.values()

data_array._file_obj = dataset._file_obj

# Reset names if they were changed during saving
# to ensure that we can 'roundtrip' perfectly
if DATAARRAY_NAME in dataset.attrs:
@@ -716,14 +713,14 @@


class _MultiFileCloser:
__slots__ = ("file_objs",)
__slots__ = ("closers",)

def __init__(self, file_objs):
self.file_objs = file_objs
def __init__(self, closers):
self.closers = closers

def close(self):
for f in self.file_objs:
f.close()
def __call__(self):
for dataset_close in self.closers:
dataset_close()


def open_mfdataset(
@@ -918,14 +915,14 @@ def open_mfdataset(
getattr_ = getattr

datasets = [open_(p, **open_kwargs) for p in paths]
file_objs = [getattr_(ds, "_file_obj") for ds in datasets]
closers = [getattr_(ds, "_close") for ds in datasets]
if preprocess is not None:
datasets = [preprocess(ds) for ds in datasets]

if parallel:
# calling compute here will return the datasets/file_objs lists,
# the underlying datasets will still be stored as dask arrays
datasets, file_objs = dask.compute(datasets, file_objs)
datasets, closers = dask.compute(datasets, closers)

# Combine all datasets, closing them in case of a ValueError
try:
@@ -963,7 +960,7 @@
ds.close()
raise

combined._file_obj = _MultiFileCloser(file_objs)
combined._close = _MultiFileCloser(closers)

# read global attributes from the attrs_file or from the first dataset
if attrs_file is not None:
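Because `_close` now holds a plain callable rather than an object with a `close()` method, `_MultiFileCloser` switches from a `close()` method to `__call__`, and `open_mfdataset` collects each dataset's `_close` into it. A rough equivalent of that wiring, with `datasets` and `combined` standing in for the objects built inside `open_mfdataset`:

# Illustrative only: aggregate per-dataset closers into a single callable,
# mirroring combined._close = _MultiFileCloser(closers) in the diff above.
closers = [ds._close for ds in datasets]

def close_all() -> None:
    for dataset_close in closers:
        if dataset_close is not None:  # guard added for illustration; the diff assumes non-None
            dataset_close()

combined.set_close(close_all)  # combined.close() now releases every underlying file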
2 changes: 0 additions & 2 deletions xarray/backends/apiv2.py
@@ -90,8 +90,6 @@ def _dataset_from_backend_dataset(
**extra_tokens,
)

ds._file_obj = backend_ds._file_obj

# Ensure source filename always stored in dataset object (GH issue #2550)
if "source" not in ds.encoding:
if isinstance(filename_or_obj, str):
11 changes: 7 additions & 4 deletions xarray/backends/rasterio_.py
@@ -345,7 +345,13 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, loc
if cache and chunks is None:
data = indexing.MemoryCachedArray(data)

result = DataArray(data=data, dims=("band", "y", "x"), coords=coords, attrs=attrs)
result = DataArray(
data=data,
dims=("band", "y", "x"),
coords=coords,
attrs=attrs,
)
result.set_close(manager.close)

if chunks is not None:
from dask.base import tokenize
@@ -360,7 +366,4 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, loc
name_prefix = "open_rasterio-%s" % token
result = result.chunk(chunks, name_prefix=name_prefix, token=token)

# Make the file closeable
result._file_obj = manager

return result
3 changes: 1 addition & 2 deletions xarray/backends/store.py
@@ -19,7 +19,6 @@ def open_backend_dataset_store(
decode_timedelta=None,
):
vars, attrs = store.load()
file_obj = store
encoding = store.get_encoding()

vars, attrs, coord_names = conventions.decode_cf_variables(
@@ -35,8 +34,8 @@
)

ds = Dataset(vars, attrs=attrs)
ds.set_close(store.close)
ds = ds.set_coords(coord_names.intersection(vars))
ds._file_obj = file_obj
ds.encoding = encoding

return ds
6 changes: 3 additions & 3 deletions xarray/conventions.py
@@ -576,12 +576,12 @@ def decode_cf(
vars = obj._variables
attrs = obj.attrs
extra_coords = set(obj.coords)
file_obj = obj._file_obj
close = obj._close
encoding = obj.encoding
elif isinstance(obj, AbstractDataStore):
vars, attrs = obj.load()
extra_coords = set()
file_obj = obj
close = obj.close
encoding = obj.get_encoding()
else:
raise TypeError("can only decode Dataset or DataStore objects")
@@ -598,8 +598,8 @@
decode_timedelta=decode_timedelta,
)
ds = Dataset(vars, attrs=attrs)
ds.set_close(close)
ds = ds.set_coords(coord_names.union(extra_coords).intersection(vars))
ds._file_obj = file_obj
ds.encoding = encoding

return ds
14 changes: 10 additions & 4 deletions xarray/core/common.py
@@ -11,6 +11,7 @@
Iterator,
List,
Mapping,
Optional,
Tuple,
TypeVar,
Union,
@@ -330,7 +331,9 @@ def get_squeeze_dims(
class DataWithCoords(SupportsArithmetic, AttrAccessMixin):
"""Shared base class for Dataset and DataArray."""

__slots__ = ()
_close: Optional[Callable[[], None]]

__slots__ = ("_close",)

_rolling_exp_cls = RollingExp

@@ -1263,11 +1266,14 @@ def where(self, cond, other=dtypes.NA, drop: bool = False):

return ops.where_method(self, cond, other)

def set_close(self, close: Optional[Callable[[], None]]) -> None:
self._close = close

def close(self: Any) -> None:
"""Close any files linked to this object"""
if self._file_obj is not None:
self._file_obj.close()
self._file_obj = None
if self._close is not None:
self._close()
self._close = None

def isnull(self, keep_attrs: bool = None):
"""Test each value in the array for whether it is a missing value.
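`set_close()` is added on the shared `DataWithCoords` base class, so both `Dataset` and `DataArray` expose it. As of this branch it can be exercised directly; the temporary file below is only a stand-in for a real backend resource:

import tempfile
import xarray as xr

ds = xr.Dataset({"t": ("x", [1, 2, 3])})
f = tempfile.TemporaryFile()  # stand-in for a backend file handle
ds.set_close(f.close)         # register the zero-argument closer
ds.close()                    # calls f.close() once, then resets _close to None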
22 changes: 17 additions & 5 deletions xarray/core/dataarray.py
@@ -344,14 +345,15 @@ class DataArray(AbstractArray, DataWithCoords):

_cache: Dict[str, Any]
_coords: Dict[Any, Variable]
_close: Optional[Callable[[], None]]
_indexes: Optional[Dict[Hashable, pd.Index]]
_name: Optional[Hashable]
_variable: Variable

__slots__ = (
"_cache",
"_coords",
"_file_obj",
"_close",
"_indexes",
"_name",
"_variable",
@@ -421,22 +422,29 @@ def __init__(
# public interface.
self._indexes = indexes

self._file_obj = None
self._close = None

def _replace(
self,
variable: Variable = None,
coords=None,
name: Union[Hashable, None, Default] = _default,
indexes=None,
close: Union[Callable[[], None], None, Default] = _default,
) -> "DataArray":
if variable is None:
variable = self.variable
if coords is None:
coords = self._coords
if name is _default:
name = self.name
return type(self)(variable, coords, name=name, fastpath=True, indexes=indexes)
if close is _default:
close = self._close
replaced = type(self)(
variable, coords, name=name, fastpath=True, indexes=indexes
)
replaced.set_close(close)
return replaced

def _replace_maybe_drop_dims(
self, variable: Variable, name: Union[Hashable, None, Default] = _default
@@ -492,7 +500,8 @@ def _from_temp_dataset(
variable = dataset._variables.pop(_THIS_ARRAY)
coords = dataset._variables
indexes = dataset._indexes
return self._replace(variable, coords, name, indexes=indexes)
close = dataset._close
return self._replace(variable, coords, name, indexes=indexes, close=close)

def _to_dataset_split(self, dim: Hashable) -> Dataset:
""" splits dataarray along dimension 'dim' """
@@ -536,7 +545,10 @@ def _to_dataset_whole(
indexes = self._indexes

coord_names = set(self._coords)
dataset = Dataset._construct_direct(variables, coord_names, indexes=indexes)
close = self._close
dataset = Dataset._construct_direct(
variables, coord_names, indexes=indexes, close=close
)
return dataset

def to_dataset(
45 changes: 33 additions & 12 deletions xarray/core/dataset.py
@@ -636,6 +636,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords):
_coord_names: Set[Hashable]
_dims: Dict[Hashable, int]
_encoding: Optional[Dict[Hashable, Any]]
_close: Optional[Callable[[], None]]
_indexes: Optional[Dict[Hashable, pd.Index]]
_variables: Dict[Hashable, Variable]

@@ -645,7 +646,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords):
"_coord_names",
"_dims",
"_encoding",
"_file_obj",
"_close",
"_indexes",
"_variables",
"__weakref__",
@@ -687,7 +688,7 @@ def __init__(
)

self._attrs = dict(attrs) if attrs is not None else None
self._file_obj = None
self._close = None
self._encoding = None
self._variables = variables
self._coord_names = coord_names
@@ -703,7 +704,7 @@ def load_store(cls, store, decoder=None) -> "Dataset":
if decoder:
variables, attributes = decoder(variables, attributes)
obj = cls(variables, attrs=attributes)
obj._file_obj = store
obj._close = store.close
return obj

@property
@@ -876,7 +877,7 @@ def __dask_postcompute__(self):
self._attrs,
self._indexes,
self._encoding,
self._file_obj,
self._close,
)
return self._dask_postcompute, args

@@ -896,7 +897,7 @@ def __dask_postpersist__(self):
self._attrs,
self._indexes,
self._encoding,
self._file_obj,
self._close,
)
return self._dask_postpersist, args

@@ -1007,7 +1008,7 @@ def _construct_direct(
attrs=None,
indexes=None,
encoding=None,
file_obj=None,
close=None,
):
"""Shortcut around __init__ for internal use when we want to skip
costly validation
@@ -1020,7 +1021,7 @@
obj._dims = dims
obj._indexes = indexes
obj._attrs = attrs
obj._file_obj = file_obj
obj._close = close
obj._encoding = encoding
return obj

@@ -1033,6 +1034,7 @@ def _replace(
indexes: Union[Dict[Any, pd.Index], None, Default] = _default,
encoding: Union[dict, None, Default] = _default,
inplace: bool = False,
close: Union[Callable[[], None], None, Default] = _default,
) -> "Dataset":
"""Fastpath constructor for internal use.

@@ -1055,6 +1057,8 @@
self._indexes = indexes
if encoding is not _default:
self._encoding = encoding
if close is not _default:
self._close = close
obj = self
else:
if variables is None:
@@ -1069,8 +1073,10 @@
indexes = copy.copy(self._indexes)
if encoding is _default:
encoding = copy.copy(self._encoding)
if close is _default:
close = self._close
obj = self._construct_direct(
variables, coord_names, dims, attrs, indexes, encoding
variables, coord_names, dims, attrs, indexes, encoding, close
)
return obj

@@ -1330,7 +1336,15 @@ def _construct_dataarray(self, name: Hashable) -> "DataArray":
else:
indexes = {k: v for k, v in self._indexes.items() if k in coords}

return DataArray(variable, coords, name=name, indexes=indexes, fastpath=True)
da = DataArray(
variable,
coords,
name=name,
indexes=indexes,
fastpath=True,
)
da.set_close(self._close)
return da

def __copy__(self) -> "Dataset":
return self.copy(deep=False)
@@ -2122,7 +2136,7 @@ def isel(
attrs=self._attrs,
indexes=indexes,
encoding=self._encoding,
file_obj=self._file_obj,
close=self._close,
)

def _isel_fancy(
Expand Down Expand Up @@ -4786,9 +4800,16 @@ def to_array(self, dim="variable", name=None):

dims = (dim,) + broadcast_vars[0].dims

return DataArray(
data, coords, dims, attrs=self.attrs, name=name, indexes=indexes
da = DataArray(
data,
coords,
dims,
attrs=self.attrs,
name=name,
indexes=indexes,
)
da.set_close(self._close)
return da

def _normalize_dim_order(
self, dim_order: List[Hashable] = None
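The `dataset.py` changes also thread the closer through `_construct_direct`, `_replace`, `isel`, `_construct_dataarray`, and `to_array`, so objects derived from a file-backed dataset keep a reference to the same closer. Expected behaviour, sketched with placeholder file, dimension, and variable names:

import xarray as xr

ds = xr.open_dataset("example.nc")  # placeholder path for any file-backed dataset
sub = ds.isel(time=0)               # _close is passed along by isel()
da = ds["temperature"]              # _construct_dataarray calls set_close(self._close)
da.close()                          # releases the same underlying file handle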