Skip to content

Commit

Permalink
Python changes for precision->likelihood_thresold
Browse files Browse the repository at this point in the history
Also update the Python implementation to match the C one
  • Loading branch information
jeromekelleher committed Sep 4, 2024
1 parent 2aa5e27 commit d575342
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 15 deletions.
10 changes: 5 additions & 5 deletions _tsinfermodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -1283,21 +1283,21 @@ AncestorMatcher_init(AncestorMatcher *self, PyObject *args, PyObject *kwds)
int err;
int extended_checks = 0;
static char *kwlist[] = {"tree_sequence_builder", "recombination",
"mismatch", "precision", "extended_checks", NULL};
"mismatch", "likelihood_threshold", "extended_checks", NULL};
TreeSequenceBuilder *tree_sequence_builder = NULL;
PyObject *recombination = NULL;
PyObject *mismatch = NULL;
PyArrayObject *recombination_array = NULL;
PyArrayObject *mismatch_array = NULL;
npy_intp *shape;
unsigned int precision = 22;
double likelihood_threshold = DBL_MIN;
int flags = 0;

self->ancestor_matcher = NULL;
self->tree_sequence_builder = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!OO|Ii", kwlist,
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!OO|di", kwlist,
&TreeSequenceBuilderType, &tree_sequence_builder,
&recombination, &mismatch, &precision,
&recombination, &mismatch, &likelihood_threshold,
&extended_checks)) {
goto out;
}
Expand Down Expand Up @@ -1343,7 +1343,7 @@ AncestorMatcher_init(AncestorMatcher *self, PyObject *args, PyObject *kwds)
self->tree_sequence_builder->tree_sequence_builder,
PyArray_DATA(recombination_array),
PyArray_DATA(mismatch_array),
precision, flags);
likelihood_threshold, flags);
if (err != 0) {
handle_library_error(err);
goto out;
Expand Down
11 changes: 6 additions & 5 deletions tsinfer/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,13 +636,13 @@ def __init__(
tree_sequence_builder,
recombination=None,
mismatch=None,
precision=None,
likelihood_threshold=None,
extended_checks=False,
):
self.tree_sequence_builder = tree_sequence_builder
self.mismatch = mismatch
self.recombination = recombination
self.precision = precision
self.likelihood_threshold = likelihood_threshold
self.extended_checks = extended_checks
self.num_sites = tree_sequence_builder.num_sites
self.parent = None
Expand Down Expand Up @@ -705,7 +705,8 @@ def unset_allelic_state(self, site):
assert np.all(self.allelic_state == -1)

def update_site(self, site, haplotype_state):
n = self.tree_sequence_builder.num_match_nodes
# n = self.tree_sequence_builder.num_match_nodes
n = 1
rho = self.recombination[site]
mu = self.mismatch[site]
num_alleles = self.tree_sequence_builder.num_alleles[site]
Expand Down Expand Up @@ -763,13 +764,13 @@ def update_site(self, site, haplotype_state):
elif rho == 0:
raise _tsinfer.MatchImpossible(
"Matching failed with recombination=0, potentially due to "
"rounding issues. Try increasing the precision value"
"rounding issues. Try increasing the likelihood_threshold value"
)
raise AssertionError("Unexpected matching failure")

for u in self.likelihood_nodes:
x = self.likelihood[u] / max_L
self.likelihood[u] = round(x, self.precision)
self.likelihood[u] = max(x, self.likelihood_threshold)

self.max_likelihood_node[site] = max_L_node
self.unset_allelic_state(site)
Expand Down
23 changes: 18 additions & 5 deletions tsinfer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def infer(
num_threads=0,
# Deliberately undocumented parameters below
precision=None,
likelihood_threshold=None,
engine=constants.C_ENGINE,
progress_monitor=None,
time_units=None,
Expand Down Expand Up @@ -349,6 +350,7 @@ def infer(
recombination_rate=recombination_rate,
mismatch_ratio=mismatch_ratio,
precision=precision,
likelihood_threshold=likelihood_threshold,
path_compression=path_compression,
progress_monitor=progress_monitor,
time_units=time_units,
Expand All @@ -362,6 +364,7 @@ def infer(
recombination_rate=recombination_rate,
mismatch_ratio=mismatch_ratio,
precision=precision,
likelihood_threshold=likelihood_threshold,
post_process=post_process,
path_compression=path_compression,
progress_monitor=progress_monitor,
Expand Down Expand Up @@ -457,6 +460,7 @@ def match_ancestors(
recombination=None, # See :class:`Matcher`
mismatch=None, # See :class:`Matcher`
precision=None,
likelihood_threshold=None,
engine=constants.C_ENGINE,
progress_monitor=None,
extended_checks=False,
Expand Down Expand Up @@ -514,6 +518,7 @@ def match_ancestors(
path_compression=path_compression,
num_threads=num_threads,
precision=precision,
likelihood_threshold=likelihood_threshold,
extended_checks=extended_checks,
engine=engine,
progress_monitor=progress_monitor,
Expand Down Expand Up @@ -639,6 +644,7 @@ def match_samples(
recombination=None, # See :class:`Matcher`
mismatch=None, # See :class:`Matcher`
precision=None,
likelihood_threshold=None,
extended_checks=False,
engine=constants.C_ENGINE,
progress_monitor=None,
Expand Down Expand Up @@ -723,6 +729,7 @@ def match_samples(
path_compression=path_compression,
num_threads=num_threads,
precision=precision,
likelihood_threshold=likelihood_threshold,
extended_checks=extended_checks,
engine=engine,
progress_monitor=progress_monitor,
Expand Down Expand Up @@ -1141,6 +1148,7 @@ def __init__(
recombination=None,
mismatch=None,
precision=None,
likelihood_threshold=None,
extended_checks=False,
engine=constants.C_ENGINE,
progress_monitor=None,
Expand Down Expand Up @@ -1233,11 +1241,16 @@ def __init__(
if not (np.all(mismatch >= 0) and np.all(mismatch <= 1)):
raise ValueError("Underlying mismatch probabilities not between 0 & 1")

if precision is None:
precision = 13
if precision is not None and likelihood_threshold is not None:
raise ValueError("Cannot specify likelihood_threshold and precision")
if precision is not None:
likelihood_threshold = pow(10, -precision)
if likelihood_threshold is None:
likelihood_threshold = 1e-13 # ~Same as previous precision default.

self.recombination[1:] = recombination
self.mismatch[:] = mismatch
self.precision = precision
self.likelihood_threshold = likelihood_threshold

if len(recombination) == 0:
logger.info("Fewer than two inference sites: no recombination possible")
Expand All @@ -1261,7 +1274,7 @@ def __init__(
f"mean={np.mean(mismatch):.5g}"
)
logger.info(
f"Matching using {precision} digits of precision in likelihood calcs"
f"Matching using likelihood_threshold of {likelihood_threshold:.5g}"
)

if engine == constants.C_ENGINE:
Expand Down Expand Up @@ -1303,7 +1316,7 @@ def __init__(
self.tree_sequence_builder,
recombination=self.recombination,
mismatch=self.mismatch,
precision=precision,
likelihood_threshold=likelihood_threshold,
extended_checks=self.extended_checks,
)
for _ in range(num_threads)
Expand Down

0 comments on commit d575342

Please sign in to comment.