Skip to content

Commit b3366cf

Browse files
Progress
1 parent 33300de commit b3366cf

File tree

1 file changed

+65
-46
lines changed

1 file changed

+65
-46
lines changed

python/tests/test_haplotype_matching.py

Lines changed: 65 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,12 @@ def node_values(self):
150150
d[u] = mapping[v]
151151
return d
152152

153+
@property
154+
def matrix_size(self):
155+
if self.match_all_nodes:
156+
return self.ts.num_nodes
157+
return self.ts.num_samples
158+
153159
def print_state(self):
154160
print("LsHMM state")
155161
print("match_all_nodes =", self.match_all_nodes)
@@ -435,12 +441,18 @@ def update_probabilities(self, site, haplotype_state):
435441

436442
def process_site(self, site, haplotype_state):
437443
self.update_probabilities(site, haplotype_state)
438-
# d1 = self.node_values()
444+
d1 = self.node_values()
439445
# print("PRE")
440-
# self.print_state()
446+
# # self.print_state()
441447
self.compress()
442-
# d2 = self.node_values()
443-
# assert d1 == d2
448+
d2 = self.node_values()
449+
if self.match_all_nodes:
450+
# We only get an exact match on all_nodes. For samples we just
451+
# guarantee that the *samples* have the same value
452+
assert d1 == d2
453+
else:
454+
for u in self.ts.samples():
455+
assert d1[u] == d2[u]
444456
# print("AFTER COMPRESS")
445457
# self.print_state()
446458
s = self.compute_normalisation_factor()
@@ -489,7 +501,7 @@ def initialise(self, value):
489501
self.T.append(ValueTransition(tree_node=u, value=value))
490502

491503
def run(self, h):
492-
n = self.ts.num_samples
504+
n = self.matrix_size
493505
self.initialise(1 / n)
494506
while self.tree.next():
495507
self.update_tree()
@@ -553,8 +565,9 @@ def compute_normalisation_factor(self):
553565
return s
554566

555567
def compute_next_probability(self, site_id, p_last, is_match, node):
568+
n = self.matrix_size
569+
# print("NEXT PROBA:", site_id, n)
556570
rho = self.rho[site_id]
557-
n = self.ts.num_samples
558571
p_e = self.compute_emission_proba(site_id, is_match)
559572
p_t = p_last * (1 - rho) + rho / n
560573
return p_t * p_e
@@ -584,7 +597,7 @@ def process_site(self, site, haplotype_state, s):
584597
# compress
585598
self.compress()
586599
b_last_sum = self.compute_normalisation_factor()
587-
n = self.ts.num_samples
600+
n = self.matrix_size
588601
rho = self.rho[site.id]
589602
for st in self.T:
590603
if st.tree_node != tskit.NULL:
@@ -624,7 +637,7 @@ def compute_normalisation_factor(self):
624637

625638
def compute_next_probability(self, site_id, p_last, is_match, node):
626639
rho = self.rho[site_id]
627-
n = self.ts.num_samples
640+
n = self.matrix_size
628641

629642
p_no_recomb = p_last * (1 - rho + rho / n)
630643
p_recomb = rho / n
@@ -668,7 +681,6 @@ class CompressedMatrix:
668681
def __init__(self, ts):
669682
self.ts = ts
670683
self.num_sites = ts.num_sites
671-
self.num_samples = ts.num_samples
672684
self.value_transitions = [None for _ in range(self.num_sites)]
673685
self.normalisation_factor = np.zeros(self.num_sites)
674686

@@ -697,14 +709,14 @@ def num_transitions(self):
697709
def get_site(self, site):
698710
return self.value_transitions[site]
699711

700-
def decode(self):
712+
def decode_samples(self):
701713
"""
702714
Decodes the tree encoding of the values into an explicit
703715
matrix.
704716
"""
705717
sample_index_map = np.zeros(self.ts.num_nodes, dtype=int) - 1
706718
sample_index_map[self.ts.samples()] = np.arange(self.ts.num_samples)
707-
A = np.zeros((self.num_sites, self.num_samples))
719+
A = np.zeros((self.num_sites, self.ts.num_samples))
708720
for tree in self.ts.trees():
709721
for site in tree.sites():
710722
for node, value in self.value_transitions[site.id]:
@@ -713,6 +725,22 @@ def decode(self):
713725
A[site.id, j] = value
714726
return A
715727

728+
def decode_nodes(self):
729+
# print("decode nodes")
730+
A = np.zeros((self.num_sites, self.ts.num_nodes))
731+
for tree in self.ts.trees():
732+
for site in tree.sites():
733+
for node, value in self.value_transitions[site.id]:
734+
# print("Decode:", site.id, node, value)
735+
for u in tree.nodes(node):
736+
A[site.id, u] = value
737+
return A
738+
739+
def decode(self, all_nodes=False):
740+
if all_nodes:
741+
return self.decode_nodes()
742+
return self.decode_samples()
743+
716744

717745
class ViterbiMatrix(CompressedMatrix):
718746
"""
@@ -1330,7 +1358,7 @@ def check_forward_matrix(
13301358
scale_mutation_based_on_n_alleles=False,
13311359
match_all_nodes=match_all_nodes,
13321360
)
1333-
F2 = cm.decode()
1361+
F2 = cm.decode(match_all_nodes)
13341362
ll_tree = np.sum(np.log10(cm.normalisation_factor))
13351363

13361364
if compare_lshmm:
@@ -1549,6 +1577,7 @@ def test_match_sample(self, u, h):
15491577
ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=True
15501578
)
15511579
nt.assert_array_equal([u] * 7, path)
1580+
15521581
fm = check_forward_matrix(
15531582
ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=True
15541583
)
@@ -1558,45 +1587,36 @@ def test_match_sample(self, u, h):
15581587
check_fb_matrix_integrity(fm, bm)
15591588

15601589

1561-
def check_fb_matrix_integrity(fm, bm):
1590+
def check_fb_matrix_integrity(fm, bm, match_all_nodes=False):
15621591
"""
15631592
Validate properties of the forward and backward matrices.
15641593
"""
1565-
F = fm.decode()
1566-
B = bm.decode()
1594+
F = fm.decode(match_all_nodes)
1595+
B = bm.decode(match_all_nodes)
15671596
assert F.shape == B.shape
15681597
for j in range(len(F)):
15691598
s = np.sum(B[j] * F[j])
1599+
# print(j, s)
15701600
np.testing.assert_allclose(s, 1)
15711601

15721602

1573-
def check_fb_matrices(ts, h):
1574-
fm = check_forward_matrix(ts, h)
1575-
bm = check_backward_matrix(ts, h, fm)
1576-
check_fb_matrix_integrity(fm, bm)
1603+
def check_fb_matrices(ts, h, match_all_nodes=False, **kwargs):
1604+
fm = check_forward_matrix(ts, h, match_all_nodes=match_all_nodes, **kwargs)
1605+
bm = check_backward_matrix(ts, h, fm, match_all_nodes=match_all_nodes, **kwargs)
1606+
check_fb_matrix_integrity(fm, bm, match_all_nodes=match_all_nodes)
15771607

15781608

15791609
def validate_match_all_nodes(ts, h, expected_path):
1580-
# path = check_viterbi(
1581-
# ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False
1582-
# )
1583-
# nt.assert_array_equal(expected_path, path)
1584-
fm = check_forward_matrix(
1610+
# START HERE: most of this is working except for Viterbi
1611+
path = check_viterbi(
15851612
ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False
15861613
)
1587-
F = fm.decode()
1588-
# print(cm.decode())
1589-
# cm.print_state()
1590-
bm = check_backward_matrix(
1591-
ts, h, fm, match_all_nodes=True, compare_lib=False, compare_lshmm=False
1592-
)
1593-
print("sites = ", ts.num_sites)
1594-
B = bm.decode()
1595-
print(F)
1596-
for j in range(ts.num_sites):
1597-
print(j, np.sum(B[j] * F[j]))
1614+
# print("Path = ", path)
1615+
nt.assert_array_equal(expected_path, path)
15981616

1599-
# sum(B[variant,:] * F[variant,:]) = 1
1617+
check_fb_matrices(
1618+
ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False
1619+
)
16001620

16011621

16021622
class TestSingleBalancedTreeAllNodesExample:
@@ -1692,19 +1712,18 @@ def ts():
16921712
("h", "expected_path"),
16931713
[
16941714
# Just samples
1695-
([1, 0, 0, 0, 0, 1, 1], [0] * 7),
1696-
# ([0, 1, 0, 0, 1, 1, 0], [1] * 7),
1697-
# ([0, 0, 1, 0, 1, 1, 0], [2] * 7),
1698-
# ([0, 0, 0, 1, 0, 0, 1], [3] * 7),
1699-
# # Match root
1700-
# ([0, 0, 0, 0, 0, 0, 0], [7] * 7),
1715+
# fails on viterbi
1716+
# ([1, 0, 0, 0, 0, 1, 1], [0] * 7),
1717+
([0, 1, 0, 0, 1, 1, 0], [1] * 7),
1718+
([0, 0, 1, 0, 1, 1, 0], [2] * 7),
1719+
([0, 0, 0, 1, 0, 0, 1], [3] * 7),
1720+
# Match single internal node
1721+
([0, 0, 0, 0, 1, 1, 0], [4] * 7),
1722+
# Match root
1723+
([0, 0, 0, 0, 0, 0, 0], [7] * 7),
17011724
],
17021725
)
17031726
def test_match_all_nodes(self, h, expected_path):
1704-
# print()
1705-
# print(self.ts().draw_text())
1706-
# with open("tmp.svg", "w") as f:
1707-
# f.write(self.ts().draw_svg())
17081727
validate_match_all_nodes(self.ts(), h, expected_path)
17091728

17101729
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)