Skip to content

Commit

Permalink
handle inferred internal nodes in coar
Browse files Browse the repository at this point in the history
  • Loading branch information
psathyrella committed Mar 8, 2024
1 parent b8ffed2 commit 4cb0cb3
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 22 deletions.
38 changes: 17 additions & 21 deletions python/coar.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,13 @@ def reconstruct_lineage(tree, node):
node = node.parent_node

# ----------------------------------------------------------------------------------------
def find_node(tree, seq, uid):
nodes = [n for n in tree.leaf_node_iter() if n.seq == seq]
if len(nodes) > 1:
nodes = [n for n in nodes if n.taxon.label == uid]
dbgpair = ('seq', seq) if seq is not None else ('uid', uid)
if len(nodes) == 0:
raise Exception('couldn\'t find node with %s %s' % dbgpair)
elif len(nodes) > 1:
raise Exception('found multiple nodes with %s %s' % dbgpair)
return nodes[0]
def find_node(inf_nodes, seq, uid):
if seq not in inf_nodes or uid not in inf_nodes[seq]:
raise Exception('couldn\'t find node: \'%s\' %s' % (uid, seq))
return inf_nodes[seq][uid]

# ----------------------------------------------------------------------------------------
def align_lineages(node_t, tree_t, tree_i, gap_penalty_pct=0, known_root=True, allow_double_gap=False, test=False, debug=False):
def align_lineages(node_t, tree_t, tree_i, inf_nodes, gap_penalty_pct=0, known_root=True, allow_double_gap=False, test=False, debug=False):
'''
Standard implementation of a Needleman-Wunsch algorithm as described here:
http://telliott99.blogspot.com/2009/08/alignment-needleman-wunsch.html
Expand Down Expand Up @@ -99,7 +93,9 @@ def check_test_values(alignment_score, max_penalty, align_t, align_i):
uids_t = ['naive', 'a1', 'a2', 'leaf']
uids_i = ['naive', 'a1', 'leaf']
else:
node_i = find_node(tree_i, node_t.seq, node_t.taxon.label) # looks first by seq, then disambuguates with uid (if you only look by uid, if the inference swaps two nearby internal/leaf nodes with the same seq it'll crash)
node_i = find_node(inf_nodes, node_t.seq, node_t.taxon.label) # looks first by seq, then disambuguates with uid (if you only look by uid, if the inference swaps two nearby internal/leaf nodes with the same seq it'll crash)
if not node_i.is_leaf():
print(' %s inferred node %s is internal' % (utils.wrnstr(), node_t.taxon.label))
(lin_t, uids_t), (lin_i, uids_i) = [reconstruct_lineage(t, n) for t, n in [(tree_t, node_t), (tree_i, node_i)]]
# lineages must be longer than just the root and the terminal node
if len(lin_t) <= 2 and len(lin_i) <= 2:
Expand Down Expand Up @@ -203,7 +199,7 @@ def check_test_values(alignment_score, max_penalty, align_t, align_i):
max_seq_len = max(len(s) for slist in [align_t, align_i] for s in slist)
max_uid_len = max(len(u) for u in uids_i + uids_t)
print(' aligned lineages:')
print(' hdist %s %s %s' % (utils.wfmt('true ids', max_uid_len), utils.wfmt('inf ids', max_uid_len), utils.wfmt('inferred seqs', max_seq_len, jfmt='-')))
print(' hdist %s %s %s' % (utils.wfmt('true id', max_uid_len), utils.wfmt('inf id', max_uid_len), utils.wfmt('inferred seq', max_seq_len, jfmt='-')))
for uid_t, uid_i, seq_t, seq_i in zip(aln_ids['true'], aln_ids['inf'], align_t, align_i):
def cfn(s): return utils.color('blue', s, width=max_seq_len, padside='right') if s == gap_seq else s
str_t, str_i = cfn(seq_t), cfn(seq_i)
Expand All @@ -223,17 +219,17 @@ def ustr(u): return utils.color('blue' if u=='gap' else None, utils.wfmt(u, max_

# ----------------------------------------------------------------------------------------
def COAR(true_tree, inferred_tree, known_root=True, allow_double_gap=False, debug=False):
lineage_dists, n_skipped = list(), collections.OrderedDict([('inferred-internal', []), ('missing-leaf', []), ('too-short', [])])
inf_leaf_nodes = [n.taxon.label for n in inferred_tree.leaf_node_iter()] # just so we can skip true leaf nodes that were inferred to be internal
lineage_dists, n_skipped = list(), collections.OrderedDict([('missing-leaf', []), ('too-short', [])])
inf_nodes = {} # inferred nodes, indexed by seq (for use by find_node() above)
for inode in inferred_tree.preorder_node_iter():
if inode.seq not in inf_nodes:
inf_nodes[inode.seq] = {}
inf_nodes[inode.seq][inode.taxon.label] = inode
for node_t in true_tree.leaf_node_iter():
nlabel = node_t.taxon.label
if debug:
print('%s %3d %s' % (nlabel, len(node_t.seq), node_t.seq))
if nlabel not in inf_leaf_nodes:
is_internal = any(n.taxon.label == nlabel for n in inferred_tree.preorder_node_iter())
n_skipped['inferred-internal' if is_internal else 'missing-leaf'].append(nlabel)
continue
aln_res = align_lineages(node_t, true_tree, inferred_tree, known_root=known_root, allow_double_gap=allow_double_gap, debug=debug)
aln_res = align_lineages(node_t, true_tree, inferred_tree, inf_nodes, known_root=known_root, allow_double_gap=allow_double_gap, debug=debug)
if aln_res is None: # skip lineages less than three members long
n_skipped['too-short'].append(nlabel)
continue
Expand All @@ -247,7 +243,7 @@ def COAR(true_tree, inferred_tree, known_root=True, allow_double_gap=False, debu
print(' max penalty not less than zero: %.3f' % max_penalty)

if any(len(v) > 0 for v in n_skipped.values()): # this really shouldn't say 'warning' in yello if it's only 'too-short', but oh well it's hard ish to change
dstrs = {'missing-leaf' : '%d missing from inferred tree: %s', 'inferred-internal' : '%d were internal in inferred tree: %s', 'too-short' : '%d lineages had only two members: %s'}
dstrs = {'missing-leaf' : '%d missing from inferred tree: %s', 'too-short' : '%d lineages had only two members: %s'}
print(' %s skipped %d / %d true leaf nodes in coar calculation (%s)' % (utils.wrnstr(), sum(len(v) for v in n_skipped.values()), len(list(true_tree.leaf_node_iter())), ', '.join((dstrs[k]%(len(n_skipped[k]), ' '.join(sorted(n_skipped[k])))) for k in n_skipped if len(n_skipped[k])>0)))

if len(lineage_dists) == 0: # max_penalty is 0 when all lineages have less than three members
Expand Down
2 changes: 1 addition & 1 deletion python/treeutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ def upr(u): return utils.wfmt(u, max_len, jfmt='-')
if debug > 1:
print(' %s %s %s %s %s ' % (utils.color('blue' if hdist==0 else None, str(hdist), width=3), upr(l1), upr(l2), upr(mnode_t.taxon.label), upr(mnode_i.taxon.label)))
if debug:
print(' mrca dist totals over %d leaf pairs' % n_pairs)
print(' mrca dist totals over %d leaf pairs' % totals['n_pairs'])
for dtype in ['mut', 'len']:
print(' hdist / %s: %d / %d = %.3f' % (dtype, totals['hdist'], totals[dtype], totals['hdist'] / float(totals[dtype])))

Expand Down

0 comments on commit 4cb0cb3

Please sign in to comment.