Skip to content

Commit

Permalink
Eliminate SolutionArray._extra_arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
Ingmar Schoegl committed Aug 11, 2019
1 parent 2f45956 commit 2728a6f
Showing 1 changed file with 5 additions and 16 deletions.
21 changes: 5 additions & 16 deletions interfaces/cython/cantera/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,23 +471,19 @@ def __init__(self, phase, shape=(0,), states=None, extra=None):
self._output_dummy = self._states[..., 0]

self._extra_lists = {}
self._extra_arrays = {}
if isinstance(extra, dict):
for name, v in extra.items():
if not np.shape(v):
self._extra_lists[name] = [v]*self._shape[0]
self._extra_arrays[name] = np.array(self._extra_lists[name])
elif len(v) == self._shape[0]:
self._extra_lists[name] = list(v)
else:
raise ValueError("Unable to map extra SolutionArray"
"input for named {!r}".format(name))
self._extra_arrays[name] = np.array(self._extra_lists[name])

elif extra and self._shape == (0,):
for name in extra:
self._extra_lists[name] = []
self._extra_arrays[name] = np.array(())

elif extra:
raise ValueError("Initial values for extra properties must be"
Expand All @@ -503,12 +499,7 @@ def __getattr__(self, name):
if name not in self._extra_lists:
raise AttributeError("'{}' object has no attribute '{}'".format(
self.__class__.__name__, name))
L = self._extra_lists[name]
A = self._extra_arrays[name]
if len(L) != len(A):
A = np.array(L)
self._extra_arrays[name] = A
return A
return np.array(self._extra_lists[name])

def __call__(self, *species):
return SolutionArray(self._phase[species], states=self._states,
Expand Down Expand Up @@ -577,10 +568,8 @@ def sort(self, col, reverse=False):
if reverse:
indices = indices[::-1]
self._states = [self._states[ix] for ix in indices]
for k, v in self._extra_arrays.items():
new = v[indices]
self._extra_arrays[k] = new
self._extra_lists[k] = list(new)
for k, v in self._extra_lists.items():
self._extra_lists[k] = list(np.array(v)[indices])

def equilibrate(self, *args, **kwargs):
""" See `ThermoPhase.equilibrate` """
Expand Down Expand Up @@ -635,15 +624,15 @@ def collect_data(self, cols=('extra','T','density','Y'), threshold=0,
expanded_cols = []
for c in cols:
if c == 'extra':
expanded_cols.extend(self._extra_arrays)
expanded_cols.extend(self._extra_lists)
else:
expanded_cols.append(c)

species_names = set(self.species_names)
for c in expanded_cols:
single_species = False
# Determine labels for the items in the current group of columns
if c in self._extra_arrays:
if c in self._extra_lists:
collabels = [c]
elif c in self._scalar:
collabels = [c]
Expand Down

0 comments on commit 2728a6f

Please sign in to comment.