Skip to content

Commit

Permalink
[Thermo] add ability to sort SolutionArray objects
Browse files Browse the repository at this point in the history
  • Loading branch information
ischoegl authored and Ingmar Schoegl committed Aug 10, 2019
1 parent 97356a4 commit e9a4629
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
19 changes: 18 additions & 1 deletion interfaces/cython/cantera/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,12 @@ def __init__(self, phase, shape=(0,), states=None, extra=None):
shape = (shape,)

if states is not None:
self._shape = np.shape(states)[:-1]
if isinstance(states, list):
self._shape = (len(states),)
elif isinstance(states, np.ndarray):
self._shape = np.shape(states)[:-1]
else:
raise TypeError('invalid type')
self._states = states
else:
self._shape = tuple(shape)
Expand Down Expand Up @@ -558,6 +563,18 @@ def append(self, state=None, **kwargs):
self._indices.append(len(self._indices))
self._shape = (len(self._indices),)

def sort(self, col):
""" Sort SolutionArray by column *col*. """
if len(self._shape) != 1:
raise TypeError("sort only works for 1D SolutionArray objects")

indices = np.argsort(getattr(self, col))
self._states = [self._states[ix] for ix in indices]
for k, v in self._extra_arrays.items():
new = v[indices]
self._extra_arrays[k] = new
self._extra_lists[k] = list(new)

def equilibrate(self, *args, **kwargs):
""" See `ThermoPhase.equilibrate` """
for index in self._indices:
Expand Down
19 changes: 19 additions & 0 deletions interfaces/cython/cantera/test/test_thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,6 +1478,25 @@ def test_purefluid(self):
states.TP = np.linspace(400, 500, 5), 101325
self.assertArrayNear(states.X.squeeze(), np.ones(5))

def test_sort(self):
np.random.seed(0)
t = np.random.random(101)
T = np.linspace(300.,1000.,101)
P = ct.one_atm * (1. + 10.*np.random.random(101))

states = ct.SolutionArray(self.gas, 101, extra={'t': t})
states.TP = T, P

states.sort('t')
self.assertTrue((states.t[1:]-states.t[:-1]>0).all())
self.assertFalse((states.T[1:]-states.T[:-1]>0).all())
self.assertFalse(np.allclose(states.P,P))

states.sort('T')
self.assertFalse((states.t[1:]-states.t[:-1]>0).all())
self.assertTrue((states.T[1:]-states.T[:-1]>0).all())
self.assertTrue(np.allclose(states.P,P))

def test_set_equivalence_ratio(self):
states = ct.SolutionArray(self.gas, 8)
phi = np.linspace(.5, 2., 8)
Expand Down

0 comments on commit e9a4629

Please sign in to comment.