Skip to content

Commit

Permalink
XArrayInterface now supports overriding of key and value dimensions (…
Browse files Browse the repository at this point in the history
…again) (#2542)
  • Loading branch information
drs251 authored and philippjfr committed Apr 10, 2018
1 parent bcf6d4e commit c33eb6c
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 11 deletions.
30 changes: 19 additions & 11 deletions holoviews/core/data/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,24 @@ def init(cls, eltype, data, kdims, vdims):
kdim_param = element_params['kdims']
vdim_param = element_params['vdims']

if isinstance (data, xr.DataArray):
if data.name:
vdim = Dimension(data.name)
elif vdims:
def retrieve_unit_and_label(dim):
if isinstance(dim, str):
dim = Dimension(dim)
dim.unit = data[dim.name].attrs.get('units')
label = data[dim.name].attrs.get('long_name')
if label is not None:
dim.label = label
return dim

if isinstance(data, xr.DataArray):
if vdims:
vdim = vdims[0]
elif data.name:
vdim = Dimension(data.name)
vdim.unit = data.attrs.get('units')
label = data.attrs.get('long_name')
if label is not None:
vdim.label = label
elif len(vdim_param.default) == 1:
vdim = vdim_param.default[0]
if vdim.name in data.dims:
Expand Down Expand Up @@ -111,6 +124,7 @@ def init(cls, eltype, data, kdims, vdims):
else:
if vdims is None:
vdims = list(data.data_vars.keys())
vdims = [retrieve_unit_and_label(vd) for vd in vdims]
if kdims is None:
xrdims = list(data.dims)
kdims = [name for name in data.indexes.keys()
Expand All @@ -121,6 +135,7 @@ def init(cls, eltype, data, kdims, vdims):
for c in data.coords:
if c not in kdims and set(data[c].dims) == set(virtual_dims):
kdims.append(c)
kdims = [retrieve_unit_and_label(kd) for kd in kdims]
vdims = [vd if isinstance(vd, Dimension) else Dimension(vd) for vd in vdims]
kdims = [kd if isinstance(kd, Dimension) else Dimension(kd) for kd in kdims]

Expand All @@ -136,13 +151,6 @@ def init(cls, eltype, data, kdims, vdims):
"for all defined kdims, %s coordinates not found."
% not_found, cls)

# retrieve units and labels from Dataset:
for d in kdims + vdims:
d.unit = data[d.name].attrs.get('units')
label = data[d.name].attrs.get('long_name')
if label is not None:
d.label = label

return data, {'kdims': kdims, 'vdims': vdims}, {}


Expand Down
23 changes: 23 additions & 0 deletions tests/core/data/testdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1827,6 +1827,29 @@ def test_xarray_dataset_dataarray_vs_dataset(self):
dataset_from_ds_rev = Dataset(ds_rev)
self.assertEqual(dataset_from_da_rev, dataset_from_ds_rev)

def test_xarray_override_dims(self):
import xarray as xr
xs = [0.1, 0.2, 0.3]
ys = [0, 1]
zs = np.array([[0, 1], [2, 3], [4, 5]])
da = xr.DataArray(zs, coords=[('x_dim', xs), ('y_dim', ys)], name="data_name", dims=['y_dim', 'x_dim'])
da.attrs['long_name'] = "data long name"
da.attrs['units'] = "array_unit"
da.x_dim.attrs['units'] = "x_unit"
da.y_dim.attrs['long_name'] = "y axis long name"
ds = Dataset(da, kdims=["x_dim", "y_dim"], vdims=["z_dim"])
x_dim = Dimension("x_dim")
y_dim = Dimension("y_dim")
z_dim = Dimension("z_dim")
self.assertEqual(ds.kdims[0], x_dim)
self.assertEqual(ds.kdims[1], y_dim)
self.assertEqual(ds.vdims[0], z_dim)
ds_from_ds = Dataset(da.to_dataset(), kdims=["x_dim", "y_dim"], vdims=["data_name"])
self.assertEqual(ds_from_ds.kdims[0], x_dim)
self.assertEqual(ds_from_ds.kdims[1], y_dim)
data_dim = Dimension("data_name")
self.assertEqual(ds_from_ds.vdims[0], data_dim)

def test_dataset_array_init_hm(self):
"Tests support for arrays (homogeneous)"
raise SkipTest("Not supported")
Expand Down

0 comments on commit c33eb6c

Please sign in to comment.