Skip to content

Commit

Permalink
Merge pull request #557 from clara-parabricks/paf-compare
Browse files Browse the repository at this point in the history
[pygenomeworks] Updates the evaluate_paf script.
  • Loading branch information
rilango authored Sep 14, 2020
2 parents 0e9a6f3 + cc1a835 commit b89a642
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 47 deletions.
182 changes: 135 additions & 47 deletions pygenomeworks/bin/evaluate_paf
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,61 @@
import argparse
from collections import defaultdict

import intervaltree

from genomeworks.io import pafio


def match_overlaps(query_0, query_1, target_0, target_1, pos_tolerance):
def points_equal(fixed, point, slop):
    """Return True when ``point`` falls inside the tolerance window around ``fixed``.

    Both coordinates are converted with ``int``. The window spans
    ``fixed - slop`` to ``fixed + slop`` (inclusive), with the lower edge
    clamped so that it never drops below zero.
    """
    anchor = int(fixed)
    window_low = max(0, anchor - slop)
    window_high = anchor + slop
    return window_low <= int(point) <= window_high

def within(val, target, tolerance=0.05):
    """Return True when ``val`` and ``target`` differ by at most ``tolerance``.

    Both values are converted with ``float`` before the absolute difference
    is compared against ``tolerance`` (inclusive).
    """
    gap = abs(float(val) - float(target))
    return gap <= tolerance

def calculate_reciprocal_overlap(record, other):
    """Compute the combined query/target reciprocal overlap of two records.

    For the query coordinates and the target coordinates separately, two
    lengths are derived:

    * the overlap: smaller end minus larger start (negative when the two
      ranges are disjoint, which lowers the final ratio), and
    * the total span: larger end minus smaller start.

    The result is the sum of both overlaps divided by the sum of both spans.

    Args:
        record: Object exposing ``query_start``, ``query_end``,
            ``target_start`` and ``target_end``.
        other: Second object exposing the same four attributes.

    Returns:
        float: The overlap ratio; 1.0 when both records cover identical
        ranges, and 0.0 when the combined span is empty.
    """
    q_overlap = min(record.query_end, other.query_end) - max(record.query_start, other.query_start)
    # The span must run from the smallest start to the largest *end*; using the
    # starts on both sides (as before) yields the distance between starts and
    # inflates the ratio (identical records scored 2.0 instead of 1.0).
    q_total_len = max(record.query_end, other.query_end) - min(record.query_start, other.query_start)

    t_overlap = min(record.target_end, other.target_end) - max(record.target_start, other.target_start)
    t_total_len = max(record.target_end, other.target_end) - min(record.target_start, other.target_start)

    total_len = q_total_len + t_total_len
    if total_len == 0:
        # Zero-length ranges on both axes: nothing overlaps, avoid dividing by zero.
        return 0.0

    return float(q_overlap + t_overlap) / float(total_len)


def _gen_interval(start, end, value, tolerance):
    """Build an ``intervaltree.Interval`` padded by ``tolerance`` on both sides.

    The interval begins at ``int(start) - tolerance`` (clamped to zero) and
    ends at ``end + tolerance``; ``value`` is stored as the interval's data.
    """
    padded_begin = max(0, int(start) - tolerance)
    padded_end = end + tolerance
    return intervaltree.Interval(padded_begin, padded_end, value)

def construct_interval_dictionaries(paf_record_list, tolerance):
    """Index PAF records by sequence name using interval trees.

    Two ``defaultdict`` objects mapping a sequence name to an
    ``intervaltree.IntervalTree`` are built: one keyed by each record's
    query sequence name, the other by its target sequence name. Every record
    contributes one interval to each dictionary, covering its query (or
    target) start/end padded by ``tolerance`` via ``_gen_interval``, with the
    record itself stored as the interval's data.

    Args:
        paf_record_list: Iterable of PAF records.
        tolerance: Padding applied to both sides of every interval.

    Returns:
        Tuple ``(query_trees, target_trees)`` of the two dictionaries.
    """
    query_trees = defaultdict(intervaltree.IntervalTree)
    target_trees = defaultdict(intervaltree.IntervalTree)

    for paf_record in paf_record_list:
        # Build both padded intervals before touching either tree.
        query_span = _gen_interval(paf_record.query_start, paf_record.query_end,
                                   paf_record, tolerance)
        target_span = _gen_interval(paf_record.target_start, paf_record.target_end,
                                    paf_record, tolerance)

        query_trees[paf_record.query_sequence_name].add(query_span)
        target_trees[paf_record.target_sequence_name].add(target_span)

    return query_trees, target_trees

def records_equal(record, other, pos_tolerance):
    """Compare two overlap records coordinate by coordinate.

    Each of the four coordinates (query start/end, target start/end) is
    checked with ``points_equal`` using ``pos_tolerance``, and the relative
    strands are compared for equality.

    Returns:
        Tuple ``(equal, query_start_valid, query_end_valid,
        target_start_valid, target_end_valid, strands_equal)`` where
        ``equal`` holds only when every individual check holds.
    """
    coordinate_fields = ("query_start", "query_end", "target_start", "target_end")
    q_start_ok, q_end_ok, t_start_ok, t_end_ok = (
        points_equal(getattr(record, field), getattr(other, field), pos_tolerance)
        for field in coordinate_fields
    )

    same_strand = record.relative_strand == other.relative_strand

    everything_matches = q_start_ok and t_start_ok and q_end_ok and t_end_ok and same_strand

    return everything_matches, q_start_ok, q_end_ok, t_start_ok, t_end_ok, same_strand

def match_overlaps(record, other, pos_tolerance, min_reciprocal_overlap):
"""Given two sets of query and target ranges, check if the query and target ranges
fall within a specified tolerance of each other.
Expand All @@ -38,18 +89,19 @@ def match_overlaps(query_0, query_1, target_0, target_1, pos_tolerance):
Returns: Boolean indicating query and target match.
"""

query_start_valid = abs(query_0[0] - query_1[0]) < pos_tolerance
query_end_valid = abs(query_0[1] - query_1[1]) < pos_tolerance
target_start_valid = abs(target_0[0] - target_1[0]) < pos_tolerance
target_end_valid = abs(target_0[1] - target_1[1]) < pos_tolerance
equal, query_start_valid, query_end_valid, target_start_valid, target_end_valid, strands_equal = records_equal(record, other, pos_tolerance)

reciprocal = calculate_reciprocal_overlap(record, other) > min_reciprocal_overlap

match = query_start_valid and query_end_valid and target_start_valid \
and target_end_valid
match = equal or reciprocal

return {"query_start_valid": query_start_valid,
"query_end_valid": query_end_valid,
"target_start_valid": target_start_valid,
"target_end_valid": target_end_valid,
"reciprocal_overlaps": reciprocal,
"strands_equal" : strands_equal,
"equal" : equal,
"match": match}
return match

Expand All @@ -65,8 +117,18 @@ def generate_key(name_1, name_2):
"""
return "{}_{}".format(name_1, name_2)

def _swap(val, dest):
return dest, val

def reverse_record(record):
    """Return a new ``pafio.Overlap`` with query and target roles exchanged.

    Sequence name, sequence length, start and end are swapped between the
    query and target sides. The relative strand, residue-match count,
    alignment block length, mapping quality and tags are carried over
    unchanged.
    """
    return pafio.Overlap(
        # Query side of the new record comes from the old target side.
        record.target_sequence_name,
        record.target_sequence_length,
        record.target_start,
        record.target_end,
        record.relative_strand,
        # Target side of the new record comes from the old query side.
        record.query_sequence_name,
        record.query_sequence_length,
        record.query_start,
        record.query_end,
        record.num_residue_matches,
        record.alignment_block_length,
        record.mapping_quality,
        record.tags,
    )

def evaluate_paf(truth_paf_filepath, test_paf_filepath, pos_tolerance=400, skip_self_mappings=True):
def evaluate_paf(truth_paf_filepath, test_paf_filepath, pos_tolerance, min_reciprocal, skip_self_mappings=True):
"""Given a truth and test set PAF file, count number of in/incorrectly detected, and non-detected overlaps
Args:
truth_paf_filepath (str): Path to truth set PAF file
Expand All @@ -78,7 +140,10 @@ def evaluate_paf(truth_paf_filepath, test_paf_filepath, pos_tolerance=400, skip_
"""

# Put the truth paf into a dictionary:
truth_overlaps = defaultdict(list)
truth_query_intervals = None
truth_target_intervals = None
truth_keys = defaultdict(int)
truth_records = []

num_true_overlaps = 0
for truth_overlap in pafio.read_paf(truth_paf_filepath):
Expand All @@ -87,14 +152,15 @@ def evaluate_paf(truth_paf_filepath, test_paf_filepath, pos_tolerance=400, skip_
continue

key = generate_key(truth_overlap.query_sequence_name, truth_overlap.target_sequence_name)

truth_overlaps[key].append(truth_overlap)
truth_keys[key] += 1
truth_records.append(truth_overlap)
num_true_overlaps += 1
truth_query_intervals, truth_target_intervals = construct_interval_dictionaries(truth_records, pos_tolerance)

true_positive_count = 0
false_positive_count = 0
false_negative_count = 0

test_overlap_count = 0
print("{} true overlaps in truth set".format(num_true_overlaps))

seen_test_overlap_keys = set()
Expand All @@ -108,53 +174,69 @@ def evaluate_paf(truth_paf_filepath, test_paf_filepath, pos_tolerance=400, skip_
if skip_self_mappings and \
(test_overlap.query_sequence_name == test_overlap.target_sequence_name):
continue

query_0 = (test_overlap.query_start, test_overlap.query_end)
target_0 = (test_overlap.target_start, test_overlap.target_end)
test_overlap_count += 1
# query_0 = (test_overlap.query_start, test_overlap.query_end)
# target_0 = (test_overlap.target_start, test_overlap.target_end)

key = generate_key(test_overlap.query_sequence_name, test_overlap.target_sequence_name)
key_reversed = generate_key(test_overlap.target_sequence_name, test_overlap.query_sequence_name)

if (key in seen_test_overlap_keys) or (key_reversed in seen_test_overlap_keys):
continue
# if (key in seen_test_overlap_keys) or (key_reversed in seen_test_overlap_keys):
# continue

seen_test_overlap_keys.add(key)
seen_test_overlap_keys.add(key_reversed)
# seen_test_overlap_keys.add(key)
# seen_test_overlap_keys.add(key_reversed)

found_match = False
if key in truth_overlaps:
for truth_overlap in truth_overlaps[key]:
query_1 = (truth_overlap.query_start, truth_overlap.query_end)
target_1 = (truth_overlap.target_start, truth_overlap.target_end)

match_statistics = match_overlaps(query_0, query_1, target_0, target_1, pos_tolerance)

if key in truth_keys:
for truth_interval in truth_query_intervals[test_overlap.query_sequence_name]:
truth_overlap = truth_interval.data
match_statistics = match_overlaps(truth_overlap, test_overlap, pos_tolerance, min_reciprocal)
incorrect_query_start += not match_statistics["query_start_valid"]
incorrect_query_end += not match_statistics["query_end_valid"]
incorrect_target_start += not match_statistics["target_start_valid"]
incorrect_target_end += not match_statistics["target_end_valid"]

if match_statistics["match"]:
true_positive_count += 1
found_match = True
break

elif key_reversed in truth_overlaps:
for truth_overlap in truth_overlaps[key_reversed]:
query_1 = (truth_overlap.target_start, truth_overlap.target_end)
target_1 = (truth_overlap.query_start, truth_overlap.query_end)

match_statistics = match_overlaps(query_0, query_1, target_0, target_1, pos_tolerance)

if not found_match:
for truth_interval in truth_target_intervals[test_overlap.target_sequence_name]:
truth_overlap = truth_interval.data
match_statistics = match_overlaps(truth_overlap, test_overlap, pos_tolerance, min_reciprocal)
incorrect_query_start += not match_statistics["query_start_valid"]
incorrect_query_end += not match_statistics["query_end_valid"]
incorrect_target_start += not match_statistics["target_start_valid"]
incorrect_target_end += not match_statistics["target_end_valid"]
if match_statistics["match"]:
true_positive_count += 1
found_match = True
break
if not found_match and key_reversed in truth_keys:
test_overlap = reverse_record(test_overlap)
for truth_interval in truth_query_intervals[key_reversed]:
truth_overlap = truth_interval.data
match_statistics = match_overlaps(truth_overlap, test_overlap)
incorrect_query_start += not match_statistics["query_start_valid"]
incorrect_query_end += not match_statistics["query_end_valid"]
incorrect_target_start += not match_statistics["target_start_valid"]
incorrect_target_end += not match_statistics["target_end_valid"]

if match_statistics["match"]:
true_positive_count += 1
found_match = True
break
if not found_match:
for truth_interval in truth_target_intervals[key_reversed]:
truth_overlap = truth_interval.data
match_statistics = match_overlaps(truth_overlap, test_overlap)
incorrect_query_start += not match_statistics["query_start_valid"]
incorrect_query_end += not match_statistics["query_end_valid"]
incorrect_target_start += not match_statistics["target_start_valid"]
incorrect_target_end += not match_statistics["target_end_valid"]
if match_statistics["match"]:
true_positive_count += 1
found_match = True
break

if not found_match:
false_positive_count += 1
Expand All @@ -169,32 +251,38 @@ def evaluate_paf(truth_paf_filepath, test_paf_filepath, pos_tolerance=400, skip_
incorrect_query_start,
incorrect_query_end,
incorrect_target_start,
incorrect_target_end)
incorrect_target_end,
test_overlap_count)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Given a truth (reference) and test set of overlaps in PAF format,\
calculate precision and recall")
parser.add_argument('--truth_paf',
parser.add_argument("-T", "--truth", dest="truth",
type=str,
default='truth.paf')
parser.add_argument('--test_paf',
default=None, required=True)
parser.add_argument("-i", "--test", dest="test",
type=str,
default='test.paf')
parser.add_argument('--pos_tolerance',
default=None, required=True)
parser.add_argument("-s", '--slop', dest="pos_tolerance",
type=int,
default=400,
help="Position tolerance around truth set interval to count as successful match.")
default=200,
help="Number of basepairs to tolerate on either side of an interval (or record) to consider the two records equal [200].")
parser.add_argument("-r", "--reciprocal-cutoff", dest="min_reciprocal", type=float,
required=False, default=0.9, help="Amount of reciprocal overlap required to consider two overlaps the same [0.9]")
parser.add_argument('--skip_self_mapping',
action="store_true",
help="Skip checking overlaps where query/target name are same")

args = parser.parse_args()

true_positives, false_positives, false_negatives, incorrect_query_start, incorrect_query_end, incorrect_target_start, \
incorrect_target_end = evaluate_paf(args.truth_paf, args.test_paf,
args.pos_tolerance, args.skip_self_mapping)

incorrect_target_end, total_test_records = evaluate_paf(args.truth,
args.test,
args.pos_tolerance,
args.min_reciprocal,
args.skip_self_mapping)
print("Total test records:", total_test_records)
print("True positives: ", true_positives)
print("False positives: ", false_positives)
print("False negatives: ", false_negatives)
Expand Down
1 change: 1 addition & 0 deletions pygenomeworks/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@


Cython==0.29.12
intervaltree==3.1.0
networkx==2.4
numpy==1.16.3
pytest==4.4.1
Expand Down

0 comments on commit b89a642

Please sign in to comment.