Skip to content

Commit

Permalink
classmethod thin functions
Browse files Browse the repository at this point in the history
  • Loading branch information
cmbant committed Apr 13, 2021
1 parent 3c41f8e commit 7ff2a53
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
22 changes: 19 additions & 3 deletions getdist/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,22 @@ def thin_indices(self, factor, weights=None):
"""
if weights is None:
weights = self.weights
return WeightedSamples.thin_indices_single_samples(factor, weights)

@staticmethod
def thin_indices_and_weights(factor, weights):
"""
Returns indices and new weights for use when thinning samples.
:param factor: thin factor
:param weights: initial weight (counts) per sample point
:return: (unique index, counts) tuple of sample index values to keep and new weights
"""
thin_ix = WeightedSamples.thin_indices_single_samples(factor, weights)
return np.unique(thin_ix, return_counts=True)

@staticmethod
def thin_indices_single_samples(factor, weights):
numrows = len(weights)
norm1 = np.sum(weights)
weights = weights.astype(int)
Expand Down Expand Up @@ -914,8 +930,7 @@ def weighted_thin(self, factor):
This function also preserves separate chains.
:param factor: The (integer) factor to thin by
"""
thin_ix = self.thin_indices(factor)
unique, counts = np.unique(thin_ix, return_counts=True)
unique, counts = self.thin_indices_and_weights(factor, self.weights)
self.setSamples(self.samples[unique, :],
loglikes=None if self.loglikes is None
else self.loglikes[unique],
Expand Down Expand Up @@ -1137,7 +1152,8 @@ def _getParamIndices(self):
:return: A dict mapping the param name to the parameter index.
"""
if self.samples is not None and len(self.paramNames.names) != self.n:
raise WeightedSampleError("paramNames size does not match number of parameters in samples")
raise WeightedSampleError("paramNames size (%s) does not match number of "
"parameters in samples (%s)" % (len(self.paramNames.names), self.n))
index = dict()
for i, name in enumerate(self.paramNames.names):
index[name.name] = i
Expand Down
3 changes: 3 additions & 0 deletions getdist/paramnames.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,9 @@ def numberOfName(self, name):
return i
return -1

def hasParam(self, name):
return self.numberOfName(name) != -1

def parsWithNames(self, names, error=False, renames=None):
"""
gets the list of :class:`ParamInfo` instances for given list of name strings.
Expand Down

0 comments on commit 7ff2a53

Please sign in to comment.