Merge pull request #5334 from jenshnielsen/export_multigrid
Export non gridded data to MultiIndex Xarray and compressed netcdf file
jenshnielsen authored Sep 26, 2023
2 parents 37ffe7e + ac66224 commit c8085b5
Showing 6 changed files with 1,545 additions and 284 deletions.
3 changes: 3 additions & 0 deletions docs/changes/newsfragments/5334.new
@@ -0,0 +1,3 @@
+QCoDeS now exports data that isn't measured on a grid to xarray using a `MultiIndex`.
+Support for exporting these datasets to NetCDF has also been implemented.
+See `this notebook <../examples/DataSet/Working-With-Pandas-and-XArray.ipynb>`__ for additional details.
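
The newsfragment above is the user-facing summary. As a minimal, self-contained sketch of the structure it describes (plain pandas/xarray, no QCoDeS API; the `multi_index` coordinate name mirrors the code below, the data itself is made up):

```python
import pandas as pd
import xarray as xr

# setpoints swept along a diagonal -- they do not form a rectangular grid
midx = pd.MultiIndex.from_arrays(
    [[0.0, 0.5, 1.0], [0.0, 1.0, 3.0]], names=["x", "y"]
)
ds = xr.Dataset(
    {"signal": ("multi_index", [1.0, 2.0, 3.0])},
    coords={"multi_index": midx},
)

# selection by an individual MultiIndex level works as usual
print(ds.sel(x=0.5))
# note: very recent xarray versions prefer building the coords via
# xr.Coordinates.from_pandas_multiindex(midx, "multi_index")
```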
1,318 changes: 1,068 additions & 250 deletions docs/examples/DataSet/Working-With-Pandas-and-XArray.ipynb

Large diffs are not rendered by default.

19 changes: 16 additions & 3 deletions pyproject.toml
@@ -35,7 +35,7 @@ dependencies = [
"opencensus>=0.7.10",
"opencensus-ext-azure>=1.0.4, <2.0.0",
"packaging>=20.0",
"pandas>=1.1.3",
"pandas>=1.2.0",
"pyvisa>=1.11.0, <1.14.0",
"ruamel.yaml>=0.16.0,!=0.16.6",
"tabulate>=0.8.0",
@@ -45,7 +45,8 @@ dependencies = [
"versioningit>=2.0.1",
"websockets>=9.1",
"wrapt>=1.13.2",
"xarray>=0.18.0",
"xarray>=2022.06.0",
"cf_xarray>=0.8.4",
"opentelemetry-api>=1.15.0",
    # transitive dependencies. We list these explicitly to
    # ensure that we always use versions that do not have
@@ -234,7 +235,8 @@ filterwarnings = [
# RUF200 validate pyproject.toml
# I isort
# ISC flake8-implicit-str-concat
select = ["E", "F", "PT025", "UP", "RUF200", "I", "G", "ISC"]
# TID253 banned-module-level-imports
select = ["E", "F", "PT025", "UP", "RUF200", "I", "G", "ISC", "TID253"]
# darker will fix this as code is
# reformatted when it is changed.
# We have a lot of use of f strings in log messages
@@ -283,6 +285,17 @@ extend-exclude = ["typings"]
# This triggers in notebooks even with a md cell at the top
"*.ipynb" = ["E402"]

+# these imports are fine at module level
+# in tests and examples
+"docs/*" = ["TID253"]
+"qcodes/tests/*" = ["TID253"]
+
+[tool.ruff.flake8-tidy-imports]
+# These modules are relatively slow to import
+# and only required in specific places so
+# don't import them at module level
+banned-module-level-imports = ["xarray", "pandas", "opencensus"]

[tool.setuptools]
zip-safe = false
include-package-data = false
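The `banned-module-level-imports` setting above feeds ruff's TID253 rule added to `select`: the listed modules may only be imported inside functions, except in `docs/` and `qcodes/tests/`. A minimal sketch of the pattern this enforces (the function name is illustrative):

```python
# import xarray as xr  # at module level this would be flagged by TID253

def export_data(data):
    # deferred import: the comparatively slow xarray import is only
    # paid for when the function is actually called
    import xarray as xr

    return xr.DataArray(data)
```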
34 changes: 27 additions & 7 deletions qcodes/dataset/data_set_in_memory.py
@@ -227,11 +227,13 @@ def _load_from_netcdf(
        # in the code below floats and ints loaded from attributes are explicitly cast
        # this is due to some older versions of qcodes writing them with a different backend
        # reading them back results in a numpy array of one element

+        import cf_xarray as cfxr
import xarray as xr

loaded_data = xr.load_dataset(path, engine="h5netcdf")

+        loaded_data = cfxr.coding.decode_compress_to_multi_index(loaded_data)

parent_dataset_links = str_to_links(
loaded_data.attrs.get("parent_dataset_links", "[]")
)
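
For reference, a round-trip sketch of the cf_xarray helpers used here. NetCDF has no native `MultiIndex`, so the index is written using the CF "compression by gathering" convention and reconstructed on load; the example dataset is made up, and recent xarray versions may require `xr.Coordinates.from_pandas_multiindex` instead of passing the index directly:

```python
import cf_xarray as cfxr
import pandas as pd
import xarray as xr

midx = pd.MultiIndex.from_tuples([(0, "a"), (1, "b")], names=["num", "tag"])
ds = xr.Dataset(
    {"v": ("multi_index", [10.0, 20.0])}, coords={"multi_index": midx}
)

encoded = cfxr.coding.encode_multi_index_as_compress(ds)  # netcdf-safe flat index
decoded = cfxr.coding.decode_compress_to_multi_index(encoded)
assert decoded.indexes["multi_index"].equals(midx)
```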
@@ -393,12 +395,30 @@ def _from_xarray_dataset_to_qcodes_raw_data(
output[str(datavar)] = {}
data = xr_data[datavar]
output[str(datavar)][str(datavar)] = data.data
-            coords_unexpanded = []
-            for coord_name in data.dims:
-                coords_unexpanded.append(xr_data[coord_name].data)
-            coords_arrays = np.meshgrid(*coords_unexpanded, indexing="ij")
-            for coord_name, coord_array in zip(data.dims, coords_arrays):
-                output[str(datavar)][str(coord_name)] = coord_array

+            all_coords = []
+            for index_name in data.dims:
+                index = data.indexes[index_name]
+
+                coords = {name: data.coords[name] for name in index.names}
+                all_coords.append(coords)
+
+            if len(all_coords) > 1:
+                # if there is more than one index this cannot be a multiindex dataset
+                # so we can expand the data
+                coords_unexpanded = []
+                for coord_name in data.dims:
+                    coords_unexpanded.append(xr_data[coord_name].data)
+                coords_arrays = np.meshgrid(*coords_unexpanded, indexing="ij")
+                for coord_name, coord_array in zip(data.dims, coords_arrays):
+                    output[str(datavar)][str(coord_name)] = coord_array
+            elif len(all_coords) == 1:
+                # this is either a multiindex or a single regular index;
+                # in both cases we do not need to reshape the data
+                coords = all_coords[0]
+                for coord_name, coord in coords.items():
+                    output[str(datavar)][str(coord_name)] = coord.data

return output

def prepare(
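For context on the gridded branch retained above, a small sketch of what `np.meshgrid(..., indexing="ij")` does: each 1-D setpoint axis is broadcast to the full grid shape so that every data point gets explicit coordinate values.

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])
y = np.array([10.0, 20.0])

xg, yg = np.meshgrid(x, y, indexing="ij")
# both arrays have shape (3, 2): xg[i, j] == x[i] and yg[i, j] == y[j],
# matching the layout of data measured on the x-by-y grid
```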
98 changes: 84 additions & 14 deletions qcodes/dataset/exporters/export_to_xarray.py
@@ -2,6 +2,7 @@

import warnings
from collections.abc import Hashable, Mapping
+from math import prod
from pathlib import Path
from typing import TYPE_CHECKING, cast

@@ -17,11 +18,44 @@
)

if TYPE_CHECKING:
+    import pandas as pd
import xarray as xr

from qcodes.dataset.data_set_protocol import DataSetProtocol, ParameterData


+def _calculate_index_shape(idx: pd.Index | pd.MultiIndex) -> dict[Hashable, int]:
+    # heavily inspired by xarray.core.dataset.from_dataframe
+    import pandas as pd
+    from xarray.core.indexes import PandasIndex, remove_unused_levels_categories
+    from xarray.core.variable import Variable, calculate_dimensions
+
+    idx = remove_unused_levels_categories(idx)
+
+    if isinstance(idx, pd.MultiIndex) and not idx.is_unique:
+        raise ValueError(
+            "cannot convert a DataFrame with a non-unique MultiIndex into xarray"
+        )
+    index_vars: dict[Hashable, Variable] = {}
+
+    if isinstance(idx, pd.MultiIndex):
+        dims = tuple(
+            name if name is not None else f"level_{n}"
+            for n, name in enumerate(idx.names)
+        )
+        for dim, lev in zip(dims, idx.levels):
+            xr_idx = PandasIndex(lev, dim)
+            index_vars.update(xr_idx.create_variables())
+    else:
+        index_name = idx.name if idx.name is not None else "index"
+        dims = (index_name,)
+        xr_idx = PandasIndex(idx, index_name)
+        index_vars.update(xr_idx.create_variables())
+
+    expanded_shape = calculate_dimensions(index_vars)
+    return expanded_shape


def _load_to_xarray_dataarray_dict_no_metadata(
dataset: DataSetProtocol, datadict: Mapping[str, Mapping[str, np.ndarray]]
) -> dict[str, xr.DataArray]:
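
`_calculate_index_shape` above reports the size of each index level, i.e. the shape the data would occupy if expanded onto a full grid; comparing the product of those sizes with the length of the index is what distinguishes gridded from non-gridded data further down. A self-contained sketch of the same quantity without the xarray internals (the helper name is illustrative):

```python
from math import prod

import pandas as pd


def level_shape(idx: pd.MultiIndex) -> dict[str, int]:
    # the same quantity _calculate_index_shape computes, minus xarray internals
    idx = idx.remove_unused_levels()
    return {name: len(level) for name, level in zip(idx.names, idx.levels)}


grid = pd.MultiIndex.from_product([[1, 2, 3], [10, 20]], names=["x", "y"])
sparse = pd.MultiIndex.from_tuples([(1, 10), (2, 20), (3, 10)], names=["x", "y"])

assert prod(level_shape(grid).values()) == len(grid)  # 6 == 6: on a grid
assert prod(level_shape(sparse).values()) != len(sparse)  # 6 != 3: not a grid
```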
@@ -31,15 +65,37 @@ def _load_to_xarray_dataarray_dict_no_metadata(

for name, subdict in datadict.items():
index = _generate_pandas_index(subdict)
-        if index is not None and len(index.unique()) != len(index):
-            for _name in subdict:
-                data_xrdarray_dict[_name] = _data_to_dataframe(
-                    subdict, index).reset_index().to_xarray()[_name]
-        else:

+        if index is None:
xrdarray: xr.DataArray = (
-                _data_to_dataframe(subdict, index).to_xarray().get(name, xr.DataArray())
+                _data_to_dataframe(subdict, index=index)
+                .to_xarray()
+                .get(name, xr.DataArray())
)
data_xrdarray_dict[name] = xrdarray
+        else:
+            index_unique = len(index.unique()) == len(index)
+
+            df = _data_to_dataframe(subdict, index)
+
+            if not index_unique:
+                # index is not unique so we fall back to using a counter as index
+                # and store the index as a variable
+                xrdata_temp = df.reset_index().to_xarray()
+                for _name in subdict:
+                    data_xrdarray_dict[_name] = xrdata_temp[_name]
+            else:
+                calc_index = _calculate_index_shape(index)
+                index_prod = prod(calc_index.values())
+                # if the product of the lengths of the individual index dims
+                # equals len(index) we are on a grid
+                on_grid = index_prod == len(index)
+                if not on_grid:
+                    xrdarray = xr.DataArray(df[name], [("multi_index", df.index)])
+                else:
+                    xrdarray = df.to_xarray().get(name, xr.DataArray())
+
+                data_xrdarray_dict[name] = xrdarray

return data_xrdarray_dict
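
The non-gridded branch above attaches the DataFrame's `MultiIndex` wholesale as a coordinate on a single `multi_index` dimension. A sketch of why this matters (illustrative data):

```python
import pandas as pd
import xarray as xr

midx = pd.MultiIndex.from_tuples([(1, 10), (2, 20), (3, 10)], names=["x", "y"])
df = pd.DataFrame({"signal": [0.1, 0.2, 0.3]}, index=midx)

# non-gridded path: one dimension of length 3, x and y kept as index levels
da = xr.DataArray(df["signal"], [("multi_index", df.index)])

# the gridded path (df.to_xarray()) would instead unstack to a 3x2 grid,
# padding the three unmeasured (x, y) combinations with NaN
```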

@@ -115,7 +171,7 @@ def _add_param_spec_to_xarray_coords(
dataset: DataSetProtocol, xrdataset: xr.Dataset | xr.DataArray
) -> None:
for coord in xrdataset.coords:
if coord != "index":
if coord not in ("index", "multi_index"):
paramspec_dict = _paramspec_dict_with_extras(dataset, str(coord))
xrdataset.coords[str(coord)].attrs.update(paramspec_dict.items())

@@ -143,13 +199,27 @@ def _paramspec_dict_with_extras(
def xarray_to_h5netcdf_with_complex_numbers(
xarray_dataset: xr.Dataset, file_path: str | Path
) -> None:
+    import cf_xarray as cfxr
+    from pandas import MultiIndex
+
+    has_multi_index = any(
+        isinstance(xarray_dataset.indexes[index_name], MultiIndex)
+        for index_name in xarray_dataset.indexes
+    )
+
+    if has_multi_index:
+        # as of xarray 2023.8.0 there is no native support
+        # for multi index so use cf_xarray for that
+        internal_ds = cfxr.coding.encode_multi_index_as_compress(
+            xarray_dataset,
+        )
+    else:
+        internal_ds = xarray_dataset

data_var_kinds = [
-        xarray_dataset.data_vars[data_var].dtype.kind
-        for data_var in xarray_dataset.data_vars
-    ]
-    coord_kinds = [
-        xarray_dataset.coords[coord].dtype.kind for coord in xarray_dataset.coords
+        internal_ds.data_vars[data_var].dtype.kind for data_var in internal_ds.data_vars
]
+    coord_kinds = [internal_ds.coords[coord].dtype.kind for coord in internal_ds.coords]
if "c" in data_var_kinds or "c" in coord_kinds:
# see http://xarray.pydata.org/en/stable/howdoi.html
# for how to export complex numbers
@@ -160,8 +230,8 @@ def xarray_to_h5netcdf_with_complex_numbers(
message="You are writing invalid netcdf features",
category=UserWarning,
)
-            xarray_dataset.to_netcdf(
+            internal_ds.to_netcdf(
path=file_path, engine="h5netcdf", invalid_netcdf=True
)
else:
-        xarray_dataset.to_netcdf(path=file_path, engine="h5netcdf")
+        internal_ds.to_netcdf(path=file_path, engine="h5netcdf")
