Skip to content
This repository has been archived by the owner on Apr 30, 2021. It is now read-only.

Commit

Permalink
Merge pull request #156 from andersy005/master
Browse files Browse the repository at this point in the history
Avoid setting the `.data` attribute
  • Loading branch information
andersy005 authored Oct 28, 2019
2 parents 35ce76b + e5822a6 commit 532e940
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 36 deletions.
33 changes: 11 additions & 22 deletions esmlab/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,10 @@ def compute_time(self):
else:
groupby_coord = self.get_time_decoded(midpoint=False)

ds[self.time_coord_name].data = groupby_coord.data
ds[self.time_coord_name] = groupby_coord.data

if self.time_bound is not None:
ds[self.tb_name].data = self.time_bound.data
self.time_bound[self.time_coord_name].data = groupby_coord.data
self.time_bound[self.time_coord_name] = groupby_coord.data
self.time_bound_diff = self.compute_time_bound_diff(ds)

self._ds_time_computed = ds
Expand All @@ -101,6 +100,7 @@ def compute_time_bound_diff(self, ds):
if self.time_bound is not None:
time_bound_diff.name = self.tb_name + '_diff'
time_bound_diff.attrs = {}
# Compute
time_bound_diff.data = self.time_bound.diff(dim=self.tb_dim)[:, 0]
if self.tb_dim in time_bound_diff.coords:
time_bound_diff = time_bound_diff.drop(self.tb_dim)
Expand Down Expand Up @@ -142,7 +142,7 @@ def get_original_metadata(self):
v: {
key: val
for key, val in self._ds[v].encoding.items()
if key in ['dtype', '_FillValue', 'missing_value']
if key in ['dtype', '_FillValue', 'missing_value', 'units', 'calendar']
}
for v in self._ds.variables
}
Expand Down Expand Up @@ -440,7 +440,7 @@ def compute_resample_times(self, ds, temporary_time_coord_name, time_dot, method
else:
time_values = self._ds_time_computed[self.time_coord_name].groupby(time_dot).mean()

ds[self.time_coord_name].data = time_values.data
ds[self.time_coord_name] = time_values.data

return ds

Expand All @@ -453,24 +453,13 @@ def compute_mon_climatology(self):
self._ds_time_computed.drop(self.static_variables)
.groupby(time_dot_month)
.mean(self.time_coord_name)
.rename({'month': self.time_coord_name})
)
computed_dset['month'] = computed_dset[self.time_coord_name].copy()
attrs = {'month': {'long_name': 'Month', 'units': 'month'}}
encoding = {
'month': {'dtype': 'int32', '_FillValue': None},
self.time_coord_name: {'dtype': 'float', '_FillValue': None},
}

if self.time_bound is not None:
time_data = computed_dset[self.tb_name] - computed_dset[self.tb_name][0, 0]
computed_dset[self.tb_name] = time_data
computed_dset[self.time_coord_name].data = (
computed_dset[self.tb_name].mean(self.tb_dim).data
)
encoding[self.tb_name] = {'dtype': 'float', '_FillValue': None}
encoding = {'month': {'dtype': 'int32', '_FillValue': None}}

return self.restore_dataset(computed_dset, attrs=attrs, encoding=encoding)
if self.tb_name in computed_dset.data_vars:
computed_dset = computed_dset.drop(self.tb_name)
return self.update_metadata(computed_dset, attrs, encoding)

@esmlab_xr_set_options(arithmetic_join='exact')
def compute_mon_anomaly(self, slice_mon_clim_time=None):
Expand Down Expand Up @@ -504,9 +493,9 @@ def compute_mon_anomaly(self, slice_mon_clim_time=None):
# reset month to become a variable
computed_dset = computed_dset.reset_coords('month')

computed_dset[self.time_coord_name].data = self.time.data
computed_dset[self.time_coord_name] = self.time
if self.time_bound is not None:
computed_dset[self.tb_name].data = self.time_bound.data
computed_dset[self.tb_name] = self.time_bound

attrs = {'month': {'long_name': 'Month'}}
return self.restore_dataset(computed_dset, attrs=attrs)
Expand Down
24 changes: 10 additions & 14 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def test_esmlab_accessor():
attrs = {'calendar': 'noleap', 'units': 'days since 2000-01-01 00:00:00'}
ds.time.attrs = attrs
esm = ds.esmlab.set_time(time_coord_name='time')
xr.testing._assert_internal_invariants(esm._ds_time_computed)
# Time and Time bound Attributes
expected = dict(esm.time_attrs)
attrs['bounds'] = None
Expand Down Expand Up @@ -96,6 +97,7 @@ def test_datetime_cftime_exception():

def test_time_bound_var(dset, time_coord_name='time'):
esm = dset.esmlab.set_time(time_coord_name=time_coord_name)
xr.testing._assert_internal_invariants(esm._ds_time_computed)
results = esm.tb_name, esm.tb_dim
expected = ('time_bound', 'd2')
assert results == expected
Expand All @@ -114,6 +116,7 @@ def test_get_time_attrs(dset, time_coord_name='time'):
'bounds': 'time_bound',
}
esm = dset.esmlab.set_time(time_coord_name=time_coord_name)
xr.testing._assert_internal_invariants(esm._ds_time_computed)
results = esm.time_attrs
assert results == expected

Expand All @@ -122,12 +125,14 @@ def test_compute_time_var(dset, time_coord_name='time'):
idx = dset.indexes[time_coord_name]
assert isinstance(idx, pd.core.indexes.numeric.Index)
esm = dset.esmlab.set_time(time_coord_name=time_coord_name)
xr.testing._assert_internal_invariants(esm._ds_time_computed)
results = esm.get_time_decoded()
assert isinstance(results, xr.DataArray)


def test_uncompute_time_var(dset, time_coord_name='time'):
esm = dset.esmlab.set_time(time_coord_name=time_coord_name)
xr.testing._assert_internal_invariants(esm._ds_time_computed)
ds = esm.compute_time_var()
assert ds[time_coord_name].dtype == np.dtype('O')
dset_with_uncomputed_time = esm.uncompute_time_var()
Expand All @@ -137,6 +142,7 @@ def test_uncompute_time_var(dset, time_coord_name='time'):
# For some strange reason, this case fails when using pytest parametrization
def test_sel_time_(dset):
esm = dset.esmlab.set_time()
xr.testing._assert_internal_invariants(esm._ds_time_computed)
dset = esm.sel_time(indexer_val=slice('1850-01-01', '1850-12-31'), year_offset=1850)
assert len(dset.time) == 12

Expand All @@ -155,10 +161,10 @@ def test_sel_time_(dset):
def test_mon_climatology(ds_name, decoded, variables, time_coord_name):
ds = esmlab.datasets.open_dataset(ds_name, decode_times=decoded)
esm = ds.esmlab.set_time(time_coord_name=time_coord_name)
xr.testing._assert_internal_invariants(esm._ds_time_computed)
computed_dset = esmlab.climatology(ds, freq='mon')
esmlab_res = computed_dset.drop(esm.static_variables).to_dataframe()
esmlab_res = computed_dset.to_dataframe()
esmlab_res = esmlab_res.groupby('month').mean()[variables]

df = (
esm._ds_time_computed.drop(esm.static_variables)
.to_dataframe()
Expand All @@ -171,11 +177,6 @@ def test_mon_climatology(ds_name, decoded, variables, time_coord_name):

assert_both_frames_equal(esmlab_res, pd_res)

assert computed_dset[esm.time_coord_name].dtype == ds[esm.time_coord_name].dtype
for key, value in ds[esm.time_coord_name].attrs.items():
assert key in computed_dset[esm.time_coord_name].attrs
assert value == computed_dset[esm.time_coord_name].attrs[key]


@pytest.mark.parametrize(
'ds_name, decoded, variables, time_coord_name, time_bound_name',
Expand All @@ -196,9 +197,10 @@ def test_mon_climatology_drop_time_bounds(
ds = ds.drop(ds_time_bound)
del ds[time_coord_name].attrs[time_bound_name]
esm = ds.esmlab.set_time(time_coord_name=time_coord_name)
xr.testing._assert_internal_invariants(esm._ds_time_computed)
computed_dset = esmlab.climatology(ds, freq='mon')

esmlab_res = computed_dset.drop(esm.static_variables).to_dataframe()
esmlab_res = computed_dset.to_dataframe()
esmlab_res = esmlab_res.groupby('month').mean()[variables]

df = (
Expand All @@ -213,11 +215,6 @@ def test_mon_climatology_drop_time_bounds(

assert_both_frames_equal(esmlab_res, pd_res)

assert computed_dset[esm.time_coord_name].dtype == ds[esm.time_coord_name].dtype
for key, value in ds[esm.time_coord_name].attrs.items():
assert key in computed_dset[esm.time_coord_name].attrs
assert value == computed_dset[esm.time_coord_name].attrs[key]


def test_anomaly_with_monthly_clim(dset):
computed_dset = esmlab.anomaly(dset, clim_freq='mon')
Expand Down Expand Up @@ -276,7 +273,6 @@ def test_resample_mon_mean(dset):
for method in {'left', 'right'}:
ds = esmlab.resample(dset, freq='mon', method=method)
assert len(ds.time) == 12

computed_dset = esmlab.resample(dset, freq='mon')
res = computed_dset.variable_1.data
expected = np.full(shape=(12, 2, 2), fill_value=0.5, dtype=np.float32)
Expand Down

0 comments on commit 532e940

Please sign in to comment.