From 0dd6946a5b7642bdd4f6da1859f5d5bfa26aa14e Mon Sep 17 00:00:00 2001
From: Ingmar Schoegl
Date: Fri, 9 Aug 2019 17:03:37 -0500
Subject: [PATCH] [Thermo] recreation of SolutionArray objects from stored data

This commit implements new methods for SolutionArray:
* `restore_data` can restore data previously exported by `collect_data`
* `read_csv` restores data previously saved by `write_csv`
* unit tests are added
---
 interfaces/cython/cantera/composite.py        | 135 ++++++++++++++++
 interfaces/cython/cantera/test/test_thermo.py | 146 ++++++++++++++++++
 2 files changed, 281 insertions(+)

diff --git a/interfaces/cython/cantera/composite.py b/interfaces/cython/cantera/composite.py
index 089b56bb409..6a8d48ae022 100644
--- a/interfaces/cython/cantera/composite.py
+++ b/interfaces/cython/cantera/composite.py
@@ -565,6 +565,127 @@ def equilibrate(self, *args, **kwargs):
             self._phase.equilibrate(*args, **kwargs)
             self._states[index][:] = self._phase.state
 
+    def restore_data(self, data, labels):
+        """
+        Restores a SolutionArray based on *data* specified in a single
+        2D Numpy array and a list of corresponding column *labels*. Thus,
+        this method allows restoring data exported by `collect_data`.
+
+        :param data: a 2D Numpy array holding data to be restored.
+        :param labels: a list of labels corresponding to SolutionArray entries.
+
+        The receiving SolutionArray either has to be empty or should have
+        matching dimensions. Essential state properties and extra entries
+        are detected automatically whereas stored information of calculated
+        properties is omitted. If the receiving SolutionArray has extra
+        entries already specified, only those will be imported; if *labels* does
+        not contain those entries, an error is raised.
+        """
+
+        # check arguments
+        if not isinstance(data, np.ndarray) or data.ndim != 2:
+            raise TypeError("restore_data only works for 2D ndarrays")
+        elif len(labels) != data.shape[1]:
+            raise ValueError("inconsistent data and label dimensions")
+        rows = data.shape[0]
+        if self._shape != (0,) and self._shape != (rows,):
+            raise ValueError('incompatible dimensions.')
+
+        # get full state information (may differ depending on type of ThermoPhase)
+        full_states = list(self._phase._full_states.values())
+        if isinstance(self._phase, PureFluid):
+            # make sure that potentially non-unique state definitions are checked last
+            last = ['TP', 'TX', 'PX']
+            full_states = [fs for fs in full_states
+                           if fs not in last] + ['TPX'] + last
+
+        # determine whether complete concentration is available (mass or mole)
+        # assumes that `X` or `Y` is always in last place
+        mode = ''
+        for prefix in ['X_', 'Y_']:
+            spc = ['{}{}'.format(prefix, s) for s in self.species_names]
+            valid_species = {s[2:]: labels.index(s) for s in spc
+                             if s in labels}
+            all_species = [l for l in labels if l[:2] == prefix]
+            if len(valid_species):
+                mode = prefix[0]
+                full_states = [v[:-1] for v in full_states if mode in v]
+                break
+        if len(valid_species) != len(all_species):
+            raise ValueError('incompatible species information.')
+        if mode == '':
+            # no composition given: reduce to property pairs; a list (rather
+            # than a set) keeps the search order deterministic
+            pairs = []
+            for v in full_states:
+                if v[:2] not in pairs:
+                    pairs.append(v[:2])
+            full_states = pairs
+
+        # determine suitable thermo properties for reconstruction
+        basis = {'molar': 'mole', 'mass': 'mass'}[self.basis]
+        prop = {'T': ('T',), 'P': ('P',),
+                'D': ('density', 'density_{}'.format(basis)),
+                'U': ('u', 'int_energy_{}'.format(basis)),
+                'V': ('v', 'volume_{}'.format(basis)),
+                'H': ('h', 'enthalpy_{}'.format(basis)),
+                'S': ('s', 'entropy_{}'.format(basis))}
+        for fs in full_states:
+            # one {letter: column} dict per state letter; letters without a
+            # known property label (empty dict) cannot be matched
+            state = [{c: labels.index(p)
+                      for p in prop.get(c, ()) if p in labels}
+                     for c in fs]
+            if all(state):
+                mode = fs + mode
+                break
+        else:
+            # no candidate state is fully covered by the labels
+            raise ValueError('invalid/incomplete state information.')
+
+        # raise warning if state is potentially not uniquely defined
+        if isinstance(self._phase, PureFluid) and mode in last:
+            # note: adding a setter for PureFluid.TPX would be beneficial
+            warnings.warn('Using mode `{}` to restore data: may not '
+                          'be sufficient to define unique state '
+                          'for a PureFluid phase'.format(mode),
+                          UserWarning)
+
+        # assemble and restore state information
+        state_data = tuple(data[:, st[c]] for st, c in zip(state, mode))
+        if len(valid_species):
+            state_data += (np.zeros((rows, self.n_species)),)
+            for i, s in enumerate(self.species_names):
+                if s in valid_species:
+                    state_data[-1][:, i] = data[:, valid_species[s]]
+
+        # labels may include calculated properties that must not be restored
+        # NOTE(review): this is a substring match, so an extra label that
+        # merely contains a property name is dropped as well -- confirm
+        calculated = self._scalar + self._n_species + self._n_reactions
+        exclude = [l for l in labels
+                   if any(v in l for v in calculated)]
+        extra = {l: list(data[:, i]) for i, l in enumerate(labels)
+                 if l not in exclude}
+        if len(self._extra_lists):
+            extra_lists = {k: extra[k] for k in self._extra_arrays}
+        else:
+            extra_lists = extra
+
+        # ensure that SolutionArray accommodates dimensions
+        if self._shape == (0,):
+            self._states = [self._phase.state] * rows
+            self._indices = range(rows)
+            self._output_dummy = self._indices
+            self._shape = (rows,)
+
+        # restore data
+        for i in self._indices:
+            setattr(self._phase, mode, [st[i, ...] for st in state_data])
+            self._states[i] = self._phase.state
+        self._extra_lists = extra_lists
+        self._extra_arrays = {l: np.array(v) for l, v in extra_lists.items()}
+
     def set_equivalence_ratio(self, phi, *args, **kwargs):
         """
         See `ThermoPhase.set_equivalence_ratio`
@@ -670,6 +791,20 @@ def write_csv(self, filename, cols=('extra','T','density','Y'),
             for row in data:
                 writer.writerow(row)
 
+    def read_csv(self, filename):
+        """
+        Read a CSV file named *filename* and restore data to the SolutionArray
+        using `restore_data`. This method allows for recreation of data
+        previously exported by `write_csv`.
+        """
+        # read header and data block separately
+        labels = np.atleast_1d(np.genfromtxt(filename, max_rows=1,
+                                             delimiter=',', dtype=str))
+        data = np.genfromtxt(filename, skip_header=1, delimiter=',')
+
+        # genfromtxt squeezes single-row/single-column files; restore 2D shape
+        self.restore_data(data.reshape(-1, len(labels)), labels.tolist())
+
 
 def _make_functions():
     # this is wrapped in a function to avoid polluting the module namespace
diff --git a/interfaces/cython/cantera/test/test_thermo.py b/interfaces/cython/cantera/test/test_thermo.py
index b91a988fe53..3f0e5617c81 100644
--- a/interfaces/cython/cantera/test/test_thermo.py
+++ b/interfaces/cython/cantera/test/test_thermo.py
@@ -3,6 +3,7 @@ import os
 
 import numpy as np
 import gc
+import warnings
 
 import cantera as ct
 from . import utilities
@@ -1520,3 +1521,148 @@ def test_write_csv(self):
         self.assertEqual(len(data), 7)
         self.assertEqual(len(data.dtype), self.gas.n_species + 2)
         self.assertIn('Y_H2', data.dtype.fields)
+
+        b = ct.SolutionArray(self.gas)
+        b.read_csv(outfile)
+        self.assertTrue(np.allclose(states.T, b.T))
+        self.assertTrue(np.allclose(states.P, b.P))
+        self.assertTrue(np.allclose(states.X, b.X))
+
+    def test_restore(self):
+
+        def check(a, b, atol=None):
+            if atol is None:
+                fcn = lambda c, d: np.allclose(c, d)
+            else:
+                fcn = lambda c, d: np.allclose(c, d, atol=atol)
+            ok = fcn(a.T, b.T)
+            ok &= fcn(a.P, b.P)
+            ok &= fcn(a.X, b.X)
+            return ok
+
+        # test ThermoPhase
+        a = ct.SolutionArray(self.gas)
+        for i in range(10):
+            T = 300 + 1800*np.random.random()
+            P = ct.one_atm*(1 + 10*np.random.random())
+            X = np.random.random(self.gas.n_species)
+            X[-1] = 0.
+            X /= X.sum()
+            a.append(T=T, P=P, X=X)
+
+        data, labels = a.collect_data()
+
+        # basic restore
+        b = ct.SolutionArray(self.gas)
+        b.restore_data(data, labels)
+        self.assertTrue(check(a, b))
+
+        # skip concentrations
+        b = ct.SolutionArray(self.gas)
+        b.restore_data(data[:, :2], labels[:2])
+        self.assertTrue(np.allclose(a.T, b.T))
+        self.assertTrue(np.allclose(a.density, b.density))
+        self.assertFalse(np.allclose(a.X, b.X))
+
+        # wrong data shape
+        b = ct.SolutionArray(self.gas)
+        with self.assertRaises(TypeError):
+            b.restore_data(data.ravel(), labels)
+
+        # inconsistent data
+        b = ct.SolutionArray(self.gas)
+        with self.assertRaises(ValueError):
+            b.restore_data(data, labels[:-2])
+
+        # inconsistent shape of receiving SolutionArray
+        b = ct.SolutionArray(self.gas, 9)
+        with self.assertRaises(ValueError):
+            b.restore_data(data, labels)
+
+        # incomplete state
+        b = ct.SolutionArray(self.gas)
+        with self.assertRaises(ValueError):
+            b.restore_data(data[:, 1:], labels[1:])
+
+        # incomplete state without species information
+        b = ct.SolutionArray(self.gas)
+        with self.assertRaises(ValueError):
+            b.restore_data(data[:, :1], labels[:1])
+
+        # add extra column
+        t = np.arange(10, dtype=float)[:, np.newaxis]
+
+        # auto-detection of extra
+        b = ct.SolutionArray(self.gas)
+        b.restore_data(np.hstack([t, data]), ['time'] + labels)
+        self.assertTrue(check(a, b))
+
+        # explicit extra
+        b = ct.SolutionArray(self.gas, extra=('time',))
+        b.restore_data(np.hstack([t, data]), ['time'] + labels)
+        self.assertTrue(check(a, b))
+        self.assertTrue((b.time == t.ravel()).all())
+
+        # wrong extra
+        b = ct.SolutionArray(self.gas, extra=('xyz',))
+        with self.assertRaises(KeyError):
+            b.restore_data(np.hstack([t, data]), ['time'] + labels)
+
+        # missing extra
+        b = ct.SolutionArray(self.gas, extra=('time',))
+        with self.assertRaises(KeyError):
+            b.restore_data(data, labels)
+
+        # inconsistent species
+        labels[-1] = 'Y_invalid'
+        b = ct.SolutionArray(self.gas)
+        with self.assertRaises(ValueError):
+            b.restore_data(data, labels)
+
+        # incomplete species info (using threshold)
+        data, labels = a.collect_data(threshold=1e-6)
+
+        # basic restore
+        b = ct.SolutionArray(self.gas)
+        b.restore_data(data, labels)
+        self.assertTrue(check(a, b, atol=1e-6))
+
+        # skip calculated properties
+        cols = ('T', 'P', 'X', 'gibbs_mass', 'forward_rates_of_progress')
+        data, labels = a.collect_data(cols=cols, threshold=1e-6)
+
+        b = ct.SolutionArray(self.gas)
+        b.restore_data(data, labels)
+        self.assertTrue(check(a, b))
+        self.assertTrue(len(b._extra_arrays) == 0)
+
+        # test PureFluid
+        w = ct.Water()
+        a = ct.SolutionArray(w, 10)
+        a.TX = 373.15, np.linspace(0., 1., 10)
+
+        # complete data
+        cols = ('T', 'P', 'X')
+        data, labels = a.collect_data(cols=cols)
+
+        b = ct.SolutionArray(w)
+        b.restore_data(data, labels)
+        self.assertTrue(check(a, b))
+
+        # partial data
+        cols = ('T', 'X')
+        data, labels = a.collect_data(cols=cols)
+
+        with warnings.catch_warnings(record=True) as warn:
+
+            # cause all warnings to always be triggered.
+            warnings.simplefilter("always")
+
+            b = ct.SolutionArray(w)
+            b.restore_data(data, labels)
+            self.assertTrue(check(a, b))
+
+            self.assertTrue(len(warn) == 1)
+            self.assertTrue(issubclass(warn[-1].category, UserWarning))
+            self.assertTrue("may not be sufficient to define unique state"
+                            in str(warn[-1].message))