From 6519516a084fd8d334daa6c43cd020a0d73fec54 Mon Sep 17 00:00:00 2001 From: Ingmar Schoegl Date: Thu, 8 Aug 2019 09:51:35 -0500 Subject: [PATCH] [Thermo] implement recreation of SolutionArray objects from stored data This commit implements methods for SolutionArray: * `restore_data` can restore data previously exported by `collect_data` * `read_csv` restores data previously saved by `write_csv` --- interfaces/cython/cantera/composite.py | 105 +++++++++++++++++ interfaces/cython/cantera/test/test_thermo.py | 108 ++++++++++++++++++ 2 files changed, 213 insertions(+) diff --git a/interfaces/cython/cantera/composite.py b/interfaces/cython/cantera/composite.py index 938a3eb6c8f..3109c58a8a3 100644 --- a/interfaces/cython/cantera/composite.py +++ b/interfaces/cython/cantera/composite.py @@ -565,6 +565,98 @@ def equilibrate(self, *args, **kwargs): self._phase.equilibrate(*args, **kwargs) self._states[index][:] = self._phase.state + def restore_data(self, data, labels): + """ + Restores a SolutionArray based on *data* specified in a single + 2D Numpy array and a list of corresponding column *labels*. Thus, + this method allows restoring data exported by `collect_data`. + + :param data: a 2D Numpy array holding data to be restored. + :param labels: a list of labels corresponding to SolutionArray entries. + + The receiving SolutionArray either has to be empty or should have + matching dimensions. Essential state properties and extra entries + are detected automatically whereas stored information of calculated + properties is omitted. If the receiving SolutionArray has extra + entries already specified, only those will be imported; if *labels* does + not contain those entries, an error is raised. 
+ """ + + # check arguments + if not isinstance(data, np.ndarray) or data.ndim != 2: + raise TypeError("restore_data only works for 2D ndarrays") + elif len(labels) != data.shape[1]: + raise ValueError("inconsistent data and label dimensions") + rows = data.shape[0] + if self._shape!=(0,) and self._shape!=(rows,): + raise ValueError('incompatible dimensions.') + + # determine whether complete concentration is available (mass or mole) + mode = '' + has_species = False + for prefix in ['X_', 'Y_']: + spc = ['{}{}'.format(prefix, s) for s in self.species_names] + valid_species = {s[2:]: labels.index(s) for s in spc + if s in labels} + all_species = [l for l in labels if l[:2] == prefix] + if len(valid_species)>0: + mode = prefix[0] + break + if len(valid_species) != len(all_species): + raise ValueError('incompatible species information.') + + # determine suitable thermo properties for reconstruction + basis = {'molar': 'mole', 'mass': 'mass'}[self.basis] + prop = {'T': ('T'), 'P': ('P'), + 'D': ('density', 'density_{}'.format(basis)), + 'U': ('u', 'int_energy_{}'.format(basis)), + 'V': ('v', 'volume_{}'.format(basis)), + 'H': ('h', 'enthalpy_{}'.format(basis)), + 'S': ('s', 'entropy_{}'.format(basis))} + for st2 in self._state2: + state = [{st2[i]: labels.index(p) for p in prop[st2[i]] if p in labels} + for i in range(2)] + if len(state[0]) and len(state[1]): + mode = st2 + mode + break + if len(mode) == 1: + raise ValueError('invalid/incomplete state information.') + + # assemble and restore state information + state_data = data[:, state[0][mode[0]]], data[:, state[1][mode[1]]] + if len(mode)>2: + + # add species data + state_data += (np.zeros((rows, self.n_species)),) + for i, s in enumerate(self.species_names): + if s in valid_species: + state_data[2][:, i] = data[:, valid_species[s]] + + # labels may include calculated properties that must not be restored + calculated = self._scalar + self._n_species + self._n_reactions + exclude = [l for l in labels + if any([v 
in l for v in calculated])] + extra = {l: data[:, i] for i, l in enumerate(labels) + if l not in exclude} + if len(self._extra_arrays): + extra_arrays = {k: extra[k] for k in self._extra_arrays} + else: + extra_arrays = extra + + # ensure that SolutionArray accommodates dimensions + if self._shape == (0,): + self._states = [self._phase.state] * rows + self._indices = range(rows) + self._output_dummy = self._indices + self._shape = (rows,) + + # restore data + for i in self._indices: + setattr(self._phase, mode, [st[i, ...] for st in state_data]) + self._states[i] = self._phase.state + self._extra_arrays = extra_arrays + self._extra_lists = {l: list(v) for l, v in extra_arrays.items()} + def collect_data(self, cols=('extra','T','density','Y'), threshold=0, species='Y'): """ @@ -653,6 +745,19 @@ def write_csv(self, filename, cols=('extra','T','density','Y'), for row in data: writer.writerow(row) + def read_csv(self, filename): + """ + Read a CSV file named *filename* and restore data to the SolutionArray + using `restore_data`. This method allows for recreation of data + previously exported by `write_csv`. 
+ """ + # read data block and header separately + data = np.genfromtxt(filename, skip_header=1, delimiter=',') + labels = np.genfromtxt(filename, + max_rows=1, delimiter=',', dtype=str) + + self.restore_data(data, list(labels)) + def _make_functions(): # this is wrapped in a function to avoid polluting the module namespace diff --git a/interfaces/cython/cantera/test/test_thermo.py b/interfaces/cython/cantera/test/test_thermo.py index a870581a518..5690a22c93d 100644 --- a/interfaces/cython/cantera/test/test_thermo.py +++ b/interfaces/cython/cantera/test/test_thermo.py @@ -1497,3 +1497,111 @@ def test_write_csv(self): self.assertEqual(len(data), 7) self.assertEqual(len(data.dtype), self.gas.n_species + 2) self.assertIn('Y_H2', data.dtype.fields) + + b = ct.SolutionArray(self.gas) + b.read_csv(outfile) + self.assertTrue(np.allclose(states.T, b.T)) + self.assertTrue(np.allclose(states.P, b.P)) + self.assertTrue(np.allclose(states.X, b.X)) + + def test_restore(self): + + a = ct.SolutionArray(self.gas) + for i in range(10): + T = 300 + 1800*np.random.random() + P = ct.one_atm*(1 + 10*np.random.random()) + X = np.random.random(self.gas.n_species) + X[-1] = 0. 
+ X /= X.sum() + a.append(T=T, P=P, X=X) + + def check(a, b, atol=None): + if atol is None: + fcn = lambda c, d: np.allclose(c, d) + else: + fcn = lambda c, d: np.allclose(c, d, atol=atol) + check = fcn(a.T, b.T) + check &= fcn(a.P, b.P) + check &= fcn(a.X, b.X) + return check + + data, labels = a.collect_data() + + # basic restore + b = ct.SolutionArray(self.gas) + b.restore_data(data, labels) + self.assertTrue(check(a, b)) + + # skip concentrations + b = ct.SolutionArray(self.gas) + b.restore_data(data[:, :2], labels[:2]) + self.assertTrue(np.allclose(a.T, b.T)) + self.assertTrue(np.allclose(a.density, b.density)) + self.assertFalse(np.allclose(a.X, b.X)) + + # wrong data shape + b = ct.SolutionArray(self.gas) + with self.assertRaises(TypeError): + b.restore_data(data.ravel(), labels) + + # inconsistent data + b = ct.SolutionArray(self.gas) + with self.assertRaises(ValueError): + b.restore_data(data, labels[:-2]) + + # inconsistent shape of receiving SolutionArray + b = ct.SolutionArray(self.gas, 9) + with self.assertRaises(ValueError): + b.restore_data(data, labels) + + # incomplete state + b = ct.SolutionArray(self.gas) + with self.assertRaises(ValueError): + b.restore_data(data[:,1:], labels[1:]) + + # add extra column + t = np.arange(10, dtype=float)[:, np.newaxis] + + # auto-detection of extra + b = ct.SolutionArray(self.gas) + b.restore_data(np.hstack([t, data]), ['time'] + labels) + self.assertTrue(check(a, b)) + + # explicit extra + b = ct.SolutionArray(self.gas, extra=('time',)) + b.restore_data(np.hstack([t, data]), ['time'] + labels) + self.assertTrue(check(a, b)) + self.assertTrue((b.time == t.ravel()).all()) + + # wrong extra + b = ct.SolutionArray(self.gas, extra=('xyz',)) + with self.assertRaises(KeyError): + b.restore_data(np.hstack([t, data]), ['time'] + labels) + + # missing extra + b = ct.SolutionArray(self.gas, extra=('time')) + with self.assertRaises(KeyError): + b.restore_data(data, labels) + + # inconsistent species + labels[-1] = 
'Y_invalid' + b = ct.SolutionArray(self.gas) + with self.assertRaises(ValueError): + b.restore_data(data, labels) + + # incomplete species info (using threshold) + data, labels = a.collect_data(threshold=1e-6) + + # basic restore + b = ct.SolutionArray(self.gas) + b.restore_data(data, labels) + self.assertTrue(check(a, b, atol=1e-6)) + + # skip calculated properties + cols = ('T', 'P', 'X', 'gibbs_mass', 'forward_rates_of_progress') + data, labels = a.collect_data(cols=cols, threshold=1e-6) + + b = ct.SolutionArray(self.gas) + b.restore_data(data, labels) + self.assertTrue(check(a, b)) + self.assertTrue(len(b._extra_arrays) == 0)