Skip to content
This repository has been archived by the owner on Apr 30, 2021. It is now read-only.

Avoid setting the .data attribute #156

Merged
merged 5 commits into from
Oct 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 11 additions & 22 deletions esmlab/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,10 @@ def compute_time(self):
else:
groupby_coord = self.get_time_decoded(midpoint=False)

ds[self.time_coord_name].data = groupby_coord.data
ds[self.time_coord_name] = groupby_coord.data

if self.time_bound is not None:
ds[self.tb_name].data = self.time_bound.data
self.time_bound[self.time_coord_name].data = groupby_coord.data
self.time_bound[self.time_coord_name] = groupby_coord.data
self.time_bound_diff = self.compute_time_bound_diff(ds)

self._ds_time_computed = ds
Expand All @@ -101,6 +100,7 @@ def compute_time_bound_diff(self, ds):
if self.time_bound is not None:
time_bound_diff.name = self.tb_name + '_diff'
time_bound_diff.attrs = {}
# Compute
time_bound_diff.data = self.time_bound.diff(dim=self.tb_dim)[:, 0]
if self.tb_dim in time_bound_diff.coords:
time_bound_diff = time_bound_diff.drop(self.tb_dim)
Expand Down Expand Up @@ -142,7 +142,7 @@ def get_original_metadata(self):
v: {
key: val
for key, val in self._ds[v].encoding.items()
if key in ['dtype', '_FillValue', 'missing_value']
if key in ['dtype', '_FillValue', 'missing_value', 'units', 'calendar']
}
for v in self._ds.variables
}
Expand Down Expand Up @@ -440,7 +440,7 @@ def compute_resample_times(self, ds, temporary_time_coord_name, time_dot, method
else:
time_values = self._ds_time_computed[self.time_coord_name].groupby(time_dot).mean()

ds[self.time_coord_name].data = time_values.data
ds[self.time_coord_name] = time_values.data

return ds

Expand All @@ -453,24 +453,13 @@ def compute_mon_climatology(self):
self._ds_time_computed.drop(self.static_variables)
.groupby(time_dot_month)
.mean(self.time_coord_name)
.rename({'month': self.time_coord_name})
)
computed_dset['month'] = computed_dset[self.time_coord_name].copy()
attrs = {'month': {'long_name': 'Month', 'units': 'month'}}
encoding = {
'month': {'dtype': 'int32', '_FillValue': None},
self.time_coord_name: {'dtype': 'float', '_FillValue': None},
}

if self.time_bound is not None:
time_data = computed_dset[self.tb_name] - computed_dset[self.tb_name][0, 0]
computed_dset[self.tb_name] = time_data
computed_dset[self.time_coord_name].data = (
computed_dset[self.tb_name].mean(self.tb_dim).data
)
encoding[self.tb_name] = {'dtype': 'float', '_FillValue': None}
encoding = {'month': {'dtype': 'int32', '_FillValue': None}}

return self.restore_dataset(computed_dset, attrs=attrs, encoding=encoding)
if self.tb_name in computed_dset.data_vars:
computed_dset = computed_dset.drop(self.tb_name)
return self.update_metadata(computed_dset, attrs, encoding)

@esmlab_xr_set_options(arithmetic_join='exact')
def compute_mon_anomaly(self, slice_mon_clim_time=None):
Expand Down Expand Up @@ -504,9 +493,9 @@ def compute_mon_anomaly(self, slice_mon_clim_time=None):
# reset month to become a variable
computed_dset = computed_dset.reset_coords('month')

computed_dset[self.time_coord_name].data = self.time.data
computed_dset[self.time_coord_name] = self.time

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we keep `.data` on the RHS here? It seems like you're rewriting the attributes later, though, in which case this might not matter.

if self.time_bound is not None:
computed_dset[self.tb_name].data = self.time_bound.data
computed_dset[self.tb_name] = self.time_bound

attrs = {'month': {'long_name': 'Month'}}
return self.restore_dataset(computed_dset, attrs=attrs)
Expand Down
24 changes: 10 additions & 14 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def test_esmlab_accessor():
attrs = {'calendar': 'noleap', 'units': 'days since 2000-01-01 00:00:00'}
ds.time.attrs = attrs
esm = ds.esmlab.set_time(time_coord_name='time')
xr.testing._assert_internal_invariants(esm._ds_time_computed)
# Time and Time bound Attributes
expected = dict(esm.time_attrs)
attrs['bounds'] = None
Expand Down Expand Up @@ -96,6 +97,7 @@ def test_datetime_cftime_exception():

def test_time_bound_var(dset, time_coord_name='time'):
esm = dset.esmlab.set_time(time_coord_name=time_coord_name)
xr.testing._assert_internal_invariants(esm._ds_time_computed)
results = esm.tb_name, esm.tb_dim
expected = ('time_bound', 'd2')
assert results == expected
Expand All @@ -114,6 +116,7 @@ def test_get_time_attrs(dset, time_coord_name='time'):
'bounds': 'time_bound',
}
esm = dset.esmlab.set_time(time_coord_name=time_coord_name)
xr.testing._assert_internal_invariants(esm._ds_time_computed)
results = esm.time_attrs
assert results == expected

Expand All @@ -122,12 +125,14 @@ def test_compute_time_var(dset, time_coord_name='time'):
idx = dset.indexes[time_coord_name]
assert isinstance(idx, pd.core.indexes.numeric.Index)
esm = dset.esmlab.set_time(time_coord_name=time_coord_name)
xr.testing._assert_internal_invariants(esm._ds_time_computed)
results = esm.get_time_decoded()
assert isinstance(results, xr.DataArray)


def test_uncompute_time_var(dset, time_coord_name='time'):
esm = dset.esmlab.set_time(time_coord_name=time_coord_name)
xr.testing._assert_internal_invariants(esm._ds_time_computed)
ds = esm.compute_time_var()
assert ds[time_coord_name].dtype == np.dtype('O')
dset_with_uncomputed_time = esm.uncompute_time_var()
Expand All @@ -137,6 +142,7 @@ def test_uncompute_time_var(dset, time_coord_name='time'):
# For some strange reason, this case fails when using pytest parametrization
def test_sel_time_(dset):
esm = dset.esmlab.set_time()
xr.testing._assert_internal_invariants(esm._ds_time_computed)
dset = esm.sel_time(indexer_val=slice('1850-01-01', '1850-12-31'), year_offset=1850)
assert len(dset.time) == 12

Expand All @@ -155,10 +161,10 @@ def test_sel_time_(dset):
def test_mon_climatology(ds_name, decoded, variables, time_coord_name):
ds = esmlab.datasets.open_dataset(ds_name, decode_times=decoded)
esm = ds.esmlab.set_time(time_coord_name=time_coord_name)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is `esm` an xarray object? If so, use `xarray.testing._assert_internal_invariants(esm)` to make sure you aren't breaking any assumptions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

esm here is an instance of esmlab.core.EsmlabAccessor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so check esm._ds?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so check esm._ds?

Done.

xr.testing._assert_internal_invariants(esm._ds_time_computed)
computed_dset = esmlab.climatology(ds, freq='mon')
esmlab_res = computed_dset.drop(esm.static_variables).to_dataframe()
esmlab_res = computed_dset.to_dataframe()
esmlab_res = esmlab_res.groupby('month').mean()[variables]

df = (
esm._ds_time_computed.drop(esm.static_variables)
.to_dataframe()
Expand All @@ -171,11 +177,6 @@ def test_mon_climatology(ds_name, decoded, variables, time_coord_name):

assert_both_frames_equal(esmlab_res, pd_res)

assert computed_dset[esm.time_coord_name].dtype == ds[esm.time_coord_name].dtype
for key, value in ds[esm.time_coord_name].attrs.items():
assert key in computed_dset[esm.time_coord_name].attrs
assert value == computed_dset[esm.time_coord_name].attrs[key]


@pytest.mark.parametrize(
'ds_name, decoded, variables, time_coord_name, time_bound_name',
Expand All @@ -196,9 +197,10 @@ def test_mon_climatology_drop_time_bounds(
ds = ds.drop(ds_time_bound)
del ds[time_coord_name].attrs[time_bound_name]
esm = ds.esmlab.set_time(time_coord_name=time_coord_name)
xr.testing._assert_internal_invariants(esm._ds_time_computed)
computed_dset = esmlab.climatology(ds, freq='mon')

esmlab_res = computed_dset.drop(esm.static_variables).to_dataframe()
esmlab_res = computed_dset.to_dataframe()
esmlab_res = esmlab_res.groupby('month').mean()[variables]

df = (
Expand All @@ -213,11 +215,6 @@ def test_mon_climatology_drop_time_bounds(

assert_both_frames_equal(esmlab_res, pd_res)

assert computed_dset[esm.time_coord_name].dtype == ds[esm.time_coord_name].dtype
for key, value in ds[esm.time_coord_name].attrs.items():
assert key in computed_dset[esm.time_coord_name].attrs
assert value == computed_dset[esm.time_coord_name].attrs[key]


def test_anomaly_with_monthly_clim(dset):
computed_dset = esmlab.anomaly(dset, clim_freq='mon')
Expand Down Expand Up @@ -276,7 +273,6 @@ def test_resample_mon_mean(dset):
for method in {'left', 'right'}:
ds = esmlab.resample(dset, freq='mon', method=method)
assert len(ds.time) == 12

computed_dset = esmlab.resample(dset, freq='mon')
res = computed_dset.variable_1.data
expected = np.full(shape=(12, 2, 2), fill_value=0.5, dtype=np.float32)
Expand Down