diff --git a/_tsinfermodule.c b/_tsinfermodule.c index 4ba59d7f..a887f962 100644 --- a/_tsinfermodule.c +++ b/_tsinfermodule.c @@ -1283,21 +1283,21 @@ AncestorMatcher_init(AncestorMatcher *self, PyObject *args, PyObject *kwds) int err; int extended_checks = 0; static char *kwlist[] = {"tree_sequence_builder", "recombination", - "mismatch", "precision", "extended_checks", NULL}; + "mismatch", "likelihood_threshold", "extended_checks", NULL}; TreeSequenceBuilder *tree_sequence_builder = NULL; PyObject *recombination = NULL; PyObject *mismatch = NULL; PyArrayObject *recombination_array = NULL; PyArrayObject *mismatch_array = NULL; npy_intp *shape; - unsigned int precision = 22; + double likelihood_threshold = DBL_MIN; int flags = 0; self->ancestor_matcher = NULL; self->tree_sequence_builder = NULL; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!OO|Ii", kwlist, + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!OO|di", kwlist, &TreeSequenceBuilderType, &tree_sequence_builder, - &recombination, &mismatch, &precision, + &recombination, &mismatch, &likelihood_threshold, &extended_checks)) { goto out; } @@ -1343,7 +1343,7 @@ AncestorMatcher_init(AncestorMatcher *self, PyObject *args, PyObject *kwds) self->tree_sequence_builder->tree_sequence_builder, PyArray_DATA(recombination_array), PyArray_DATA(mismatch_array), - precision, flags); + likelihood_threshold, flags); if (err != 0) { handle_library_error(err); goto out; diff --git a/tsinfer/algorithm.py b/tsinfer/algorithm.py index 2254b475..30486979 100644 --- a/tsinfer/algorithm.py +++ b/tsinfer/algorithm.py @@ -636,13 +636,13 @@ def __init__( tree_sequence_builder, recombination=None, mismatch=None, - precision=None, + likelihood_threshold=None, extended_checks=False, ): self.tree_sequence_builder = tree_sequence_builder self.mismatch = mismatch self.recombination = recombination - self.precision = precision + self.likelihood_threshold = likelihood_threshold self.extended_checks = extended_checks self.num_sites = tree_sequence_builder.num_sites self.parent = None @@ -705,7 +705,8 @@ def unset_allelic_state(self, site): assert np.all(self.allelic_state == -1) def update_site(self, site, haplotype_state): - n = self.tree_sequence_builder.num_match_nodes + # n = self.tree_sequence_builder.num_match_nodes + n = 1 rho = self.recombination[site] mu = self.mismatch[site] num_alleles = self.tree_sequence_builder.num_alleles[site] @@ -763,13 +764,13 @@ def update_site(self, site, haplotype_state): elif rho == 0: raise _tsinfer.MatchImpossible( "Matching failed with recombination=0, potentially due to " - "rounding issues. Try increasing the precision value" + "rounding issues. Try increasing the likelihood_threshold value" ) raise AssertionError("Unexpected matching failure") for u in self.likelihood_nodes: x = self.likelihood[u] / max_L - self.likelihood[u] = round(x, self.precision) + self.likelihood[u] = max(x, self.likelihood_threshold) self.max_likelihood_node[site] = max_L_node self.unset_allelic_state(site) diff --git a/tsinfer/inference.py b/tsinfer/inference.py index 6f51670b..284af8f9 100644 --- a/tsinfer/inference.py +++ b/tsinfer/inference.py @@ -269,6 +269,7 @@ def infer( num_threads=0, # Deliberately undocumented parameters below precision=None, + likelihood_threshold=None, engine=constants.C_ENGINE, progress_monitor=None, time_units=None, @@ -349,6 +350,7 @@ def infer( recombination_rate=recombination_rate, mismatch_ratio=mismatch_ratio, precision=precision, + likelihood_threshold=likelihood_threshold, path_compression=path_compression, progress_monitor=progress_monitor, time_units=time_units, @@ -362,6 +364,7 @@ def infer( recombination_rate=recombination_rate, mismatch_ratio=mismatch_ratio, precision=precision, + likelihood_threshold=likelihood_threshold, post_process=post_process, path_compression=path_compression, progress_monitor=progress_monitor, @@ -457,6 +460,7 @@ def match_ancestors( recombination=None, # See :class:`Matcher` mismatch=None, # See :class:`Matcher` precision=None, + likelihood_threshold=None, engine=constants.C_ENGINE, progress_monitor=None, extended_checks=False, @@ -514,6 +518,7 @@ def match_ancestors( path_compression=path_compression, num_threads=num_threads, precision=precision, + likelihood_threshold=likelihood_threshold, extended_checks=extended_checks, engine=engine, progress_monitor=progress_monitor, @@ -639,6 +644,7 @@ def match_samples( recombination=None, # See :class:`Matcher` mismatch=None, # See :class:`Matcher` precision=None, + likelihood_threshold=None, extended_checks=False, engine=constants.C_ENGINE, progress_monitor=None, @@ -723,6 +729,7 @@ def match_samples( path_compression=path_compression, num_threads=num_threads, precision=precision, + likelihood_threshold=likelihood_threshold, extended_checks=extended_checks, engine=engine, progress_monitor=progress_monitor, @@ -1141,6 +1148,7 @@ def __init__( recombination=None, mismatch=None, precision=None, + likelihood_threshold=None, extended_checks=False, engine=constants.C_ENGINE, progress_monitor=None, @@ -1233,11 +1241,16 @@ def __init__( if not (np.all(mismatch >= 0) and np.all(mismatch <= 1)): raise ValueError("Underlying mismatch probabilities not between 0 & 1") - if precision is None: - precision = 13 + if precision is not None and likelihood_threshold is not None: + raise ValueError("Cannot specify likelihood_threshold and precision") + if precision is not None: + likelihood_threshold = pow(10, -precision) + if likelihood_threshold is None: + likelihood_threshold = 1e-13 # ~Same as previous precision default. + self.recombination[1:] = recombination self.mismatch[:] = mismatch - self.precision = precision + self.likelihood_threshold = likelihood_threshold if len(recombination) == 0: logger.info("Fewer than two inference sites: no recombination possible") @@ -1261,7 +1274,7 @@ def __init__( f"mean={np.mean(mismatch):.5g}" ) logger.info( - f"Matching using {precision} digits of precision in likelihood calcs" + f"Matching using likelihood_threshold of {likelihood_threshold:.5g}" ) if engine == constants.C_ENGINE: @@ -1303,7 +1316,7 @@ def __init__( self.tree_sequence_builder, recombination=self.recombination, mismatch=self.mismatch, - precision=precision, + likelihood_threshold=likelihood_threshold, extended_checks=self.extended_checks, ) for _ in range(num_threads)