From bd9be3806ef966bf47afbd780be8d74173fa8d72 Mon Sep 17 00:00:00 2001 From: "Eric T. Dawson" Date: Mon, 14 Sep 2020 09:49:52 -0400 Subject: [PATCH] [pygenomeworks] Updates the evaluate_paf script. Provides a nearly complete rewrite of the evaluate_paf script, which now uses an interval tree for queries on either the query or target. Records are deemed identical either by having matching starts/ends or by reciprocal overlap. This also seems to fix an issue where identical PAF files passed as both test and truth could return precision/recall values less than one. --- pygenomeworks/bin/evaluate_paf | 182 ++++++++++++++++++++++++--------- pygenomeworks/requirements.txt | 1 + 2 files changed, 136 insertions(+), 47 deletions(-) diff --git a/pygenomeworks/bin/evaluate_paf b/pygenomeworks/bin/evaluate_paf index 639e27b9e..b29aacbd1 100755 --- a/pygenomeworks/bin/evaluate_paf +++ b/pygenomeworks/bin/evaluate_paf @@ -22,10 +22,61 @@ import argparse from collections import defaultdict +import intervaltree + from genomeworks.io import pafio -def match_overlaps(query_0, query_1, target_0, target_1, pos_tolerance): +def points_equal(fixed, point, slop): + return max(0, int(fixed) - slop) <= int(point) <= (int(fixed) + slop) + +def within(val, target, tolerance=0.05): + return abs(float(val) - float(target)) <= tolerance + +def calculate_reciprocal_overlap(record, other): + q_overlap = min(record.query_end, other.query_end) - max(record.query_start, other.query_start) + q_total_len = max(record.query_start, other.query_start) - min(record.query_start, other.query_start) + + t_overlap = min(record.target_end, other.target_end) - max(record.target_start, other.target_start) + t_total_len = max(record.target_end, other.target_end) - min(record.target_start, other.target_start) + + return float(q_overlap + t_overlap) / float(q_total_len + t_total_len) + + +def _gen_interval(start, end, value, tolerance): + return intervaltree.Interval(max(0, int(start) - tolerance), end + tolerance, value) + +def construct_interval_dictionaries(paf_record_list, tolerance): + """ + Constructs a dictionary[string->IntervalTree], where the keys of the dictionary are + query or target names and the values are interval trees. Each interval tree contains the full interval + (i.e., start to end + tolerance BP on either side) pointing to the corresponding PAF record(s). + """ + + query_paf_dict = defaultdict(intervaltree.IntervalTree) + target_paf_dict = defaultdict(intervaltree.IntervalTree) + + for record in paf_record_list: + q_interval = _gen_interval(record.query_start, record.query_end, record, tolerance) + t_interval = _gen_interval(record.target_start, record.target_end, record, tolerance) + query_paf_dict[record.query_sequence_name].add(q_interval) + target_paf_dict[record.target_sequence_name].add(t_interval) + + return query_paf_dict, target_paf_dict + +def records_equal(record, other, pos_tolerance): + query_start_valid = points_equal(record.query_start, other.query_start, pos_tolerance) + query_end_valid = points_equal(record.query_end, other.query_end, pos_tolerance) + target_start_valid = points_equal(record.target_start, other.target_start, pos_tolerance) + target_end_valid = points_equal(record.target_end, other.target_end, pos_tolerance) + + strands_equal = record.relative_strand == other.relative_strand + + equal = query_start_valid and target_start_valid and query_end_valid and target_end_valid and strands_equal + + return equal, query_start_valid, query_end_valid, target_start_valid, target_end_valid, strands_equal + +def match_overlaps(record, other, pos_tolerance, min_reciprocal_overlap): """Given two sets of query and target ranges, check if the query and target ranges fall within a specified tolerance of each other. @@ -38,18 +89,19 @@ def match_overlaps(query_0, query_1, target_0, target_1, pos_tolerance): Returns: Boolean indicating query and target match. """ - query_start_valid = abs(query_0[0] - query_1[0]) < pos_tolerance - query_end_valid = abs(query_0[1] - query_1[1]) < pos_tolerance - target_start_valid = abs(target_0[0] - target_1[0]) < pos_tolerance - target_end_valid = abs(target_0[1] - target_1[1]) < pos_tolerance + equal, query_start_valid, query_end_valid, target_start_valid, target_end_valid, strands_equal = records_equal(record, other, pos_tolerance) + + reciprocal = calculate_reciprocal_overlap(record, other) > min_reciprocal_overlap - match = query_start_valid and query_end_valid and target_start_valid \ - and target_end_valid + match = equal or reciprocal return {"query_start_valid": query_start_valid, "query_end_valid": query_end_valid, "target_start_valid": target_start_valid, "target_end_valid": target_end_valid, + "reciprocal_overlaps": reciprocal, + "strands_equal" : strands_equal, + "equal" : equal, "match": match} return match @@ -65,8 +117,18 @@ def generate_key(name_1, name_2): """ return "{}_{}".format(name_1, name_2) +def _swap(val, dest): + return dest, val + +def reverse_record(record): + query_sequence_name, target_sequence_name = _swap(record.query_sequence_name, record.target_sequence_name) + query_sequence_length, target_sequence_length = _swap(record.query_sequence_length, record.target_sequence_length) + query_start, target_start = _swap(record.query_start, record.target_start) + query_end, target_end = _swap(record.query_end, record.target_end) + + return pafio.Overlap(query_sequence_name, query_sequence_length, query_start, query_end, record.relative_strand, target_sequence_name, target_sequence_length, target_start, target_end, record.num_residue_matches, record.alignment_block_length, record.mapping_quality, record.tags) -def evaluate_paf(truth_paf_filepath, test_paf_filepath, pos_tolerance=400, skip_self_mappings=True): +def evaluate_paf(truth_paf_filepath, test_paf_filepath, pos_tolerance, min_reciprocal, skip_self_mappings=True): """Given a truth and test set PAF file, count number of in/incorrectly detected, and non-detected overlaps Args: truth_paf_filepath (str): Path to truth set PAF file @@ -78,7 +140,10 @@ def evaluate_paf(truth_paf_filepath, test_paf_filepath, pos_tolerance=400, skip_ """ # Put the truth paf into a dictionary: - truth_overlaps = defaultdict(list) + truth_query_intervals = None + truth_target_intervals = None + truth_keys = defaultdict(int) + truth_records = [] num_true_overlaps = 0 for truth_overlap in pafio.read_paf(truth_paf_filepath): @@ -87,14 +152,15 @@ def evaluate_paf(truth_paf_filepath, test_paf_filepath, pos_tolerance=400, skip_ continue key = generate_key(truth_overlap.query_sequence_name, truth_overlap.target_sequence_name) - - truth_overlaps[key].append(truth_overlap) + truth_keys[key] += 1 + truth_records.append(truth_overlap) num_true_overlaps += 1 + truth_query_intervals, truth_target_intervals = construct_interval_dictionaries(truth_records, pos_tolerance) true_positive_count = 0 false_positive_count = 0 false_negative_count = 0 - + test_overlap_count = 0 print("{} true overlaps in truth set".format(num_true_overlaps)) seen_test_overlap_keys = set() @@ -108,53 +174,69 @@ def evaluate_paf(truth_paf_filepath, test_paf_filepath, pos_tolerance=400, skip_ if skip_self_mappings and \ (test_overlap.query_sequence_name == test_overlap.target_sequence_name): continue - - query_0 = (test_overlap.query_start, test_overlap.query_end) - target_0 = (test_overlap.target_start, test_overlap.target_end) + test_overlap_count += 1 + # query_0 = (test_overlap.query_start, test_overlap.query_end) + # target_0 = (test_overlap.target_start, test_overlap.target_end) key = generate_key(test_overlap.query_sequence_name, test_overlap.target_sequence_name) key_reversed = generate_key(test_overlap.target_sequence_name, test_overlap.query_sequence_name) - if (key in seen_test_overlap_keys) or (key_reversed in seen_test_overlap_keys): - continue + # if (key in seen_test_overlap_keys) or (key_reversed in seen_test_overlap_keys): + # continue - seen_test_overlap_keys.add(key) - seen_test_overlap_keys.add(key_reversed) + # seen_test_overlap_keys.add(key) + # seen_test_overlap_keys.add(key_reversed) found_match = False - if key in truth_overlaps: - for truth_overlap in truth_overlaps[key]: - query_1 = (truth_overlap.query_start, truth_overlap.query_end) - target_1 = (truth_overlap.target_start, truth_overlap.target_end) - - match_statistics = match_overlaps(query_0, query_1, target_0, target_1, pos_tolerance) - + if key in truth_keys: + for truth_interval in truth_query_intervals[test_overlap.query_sequence_name]: + truth_overlap = truth_interval.data + match_statistics = match_overlaps(truth_overlap, test_overlap, pos_tolerance, min_reciprocal) incorrect_query_start += not match_statistics["query_start_valid"] incorrect_query_end += not match_statistics["query_end_valid"] incorrect_target_start += not match_statistics["target_start_valid"] incorrect_target_end += not match_statistics["target_end_valid"] - if match_statistics["match"]: true_positive_count += 1 found_match = True break - - elif key_reversed in truth_overlaps: - for truth_overlap in truth_overlaps[key_reversed]: - query_1 = (truth_overlap.target_start, truth_overlap.target_end) - target_1 = (truth_overlap.query_start, truth_overlap.query_end) - - match_statistics = match_overlaps(query_0, query_1, target_0, target_1, pos_tolerance) - + if not found_match: + for truth_interval in truth_target_intervals[test_overlap.target_sequence_name]: + truth_overlap = truth_interval.data + match_statistics = match_overlaps(truth_overlap, test_overlap, pos_tolerance, min_reciprocal) + incorrect_query_start += not match_statistics["query_start_valid"] + incorrect_query_end += not match_statistics["query_end_valid"] + incorrect_target_start += not match_statistics["target_start_valid"] + incorrect_target_end += not match_statistics["target_end_valid"] + if match_statistics["match"]: + true_positive_count += 1 + found_match = True + break + if not found_match and key_reversed in truth_keys: + test_overlap = reverse_record(test_overlap) + for truth_interval in truth_query_intervals[key_reversed]: + truth_overlap = truth_interval.data + match_statistics = match_overlaps(truth_overlap, test_overlap) incorrect_query_start += not match_statistics["query_start_valid"] incorrect_query_end += not match_statistics["query_end_valid"] incorrect_target_start += not match_statistics["target_start_valid"] incorrect_target_end += not match_statistics["target_end_valid"] - if match_statistics["match"]: true_positive_count += 1 found_match = True break + if not found_match: + for truth_interval in truth_target_intervals[key_reversed]: + truth_overlap = truth_interval.data + match_statistics = match_overlaps(truth_overlap, test_overlap) + incorrect_query_start += not match_statistics["query_start_valid"] + incorrect_query_end += not match_statistics["query_end_valid"] + incorrect_target_start += not match_statistics["target_start_valid"] + incorrect_target_end += not match_statistics["target_end_valid"] + if match_statistics["match"]: + true_positive_count += 1 + found_match = True + break if not found_match: false_positive_count += 1 @@ -169,22 +251,25 @@ def evaluate_paf(truth_paf_filepath, test_paf_filepath, pos_tolerance=400, skip_ incorrect_query_start, incorrect_query_end, incorrect_target_start, - incorrect_target_end) + incorrect_target_end, + test_overlap_count) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Given a truth (reference) and test set of overlaps in PAF format,\ calculate precision and recall") - parser.add_argument('--truth_paf', + parser.add_argument("-T", "--truth", dest="truth", type=str, - default='truth.paf') - parser.add_argument('--test_paf', + default=None, required=True) + parser.add_argument("-i", "--test", dest="test", type=str, - default='test.paf') - parser.add_argument('--pos_tolerance', + default=None, required=True) + parser.add_argument("-s", '--slop', dest="pos_tolerance", type=int, - default=400, - help="Position tolerance around truth set interval to count as successful match.") + default=200, + help="Number of basepairs to tolerate on either side of an interval (or record) to consider the two records equal [200].") + parser.add_argument("-r", "--reciprocal-cutoff", dest="min_reciprocal", type=float, + required=False, default=0.9, help="Amount of reciprocal overlap required to consider two overlaps the same [0.9]") parser.add_argument('--skip_self_mapping', action="store_true", help="Skip checking overlaps where query/target name are same") @@ -192,9 +277,12 @@ if __name__ == "__main__": args = parser.parse_args() true_positives, false_positives, false_negatives, incorrect_query_start, incorrect_query_end, incorrect_target_start, \ - incorrect_target_end = evaluate_paf(args.truth_paf, args.test_paf, - args.pos_tolerance, args.skip_self_mapping) - + incorrect_target_end, total_test_records = evaluate_paf(args.truth, + args.test, + args.pos_tolerance, + args.min_reciprocal, + args.skip_self_mapping) + print("Total test records:", total_test_records) print("True positives: ", true_positives) print("False positives: ", false_positives) print("False negatives: ", false_negatives) diff --git a/pygenomeworks/requirements.txt b/pygenomeworks/requirements.txt index 900868e1b..61e0674b5 100644 --- a/pygenomeworks/requirements.txt +++ b/pygenomeworks/requirements.txt @@ -17,6 +17,7 @@ Cython==0.29.12 +intervaltree==3.1.0 networkx==2.4 numpy==1.16.3 pytest==4.4.1