Skip to content

Commit

Permalink
For EI-CoreBioinformatics#280: making the merging algorithm much faster, thanks to Stack Overflow
Browse files Browse the repository at this point in the history
  • Loading branch information
lucventurini committed Mar 4, 2020
1 parent 12df69a commit a1dbfaa
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 57 deletions.
61 changes: 13 additions & 48 deletions Mikado/parsers/blast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@
import gzip
import io
from . import HeaderError
# from Bio.Blast.NCBIXML import parse as xparser
from operator import itemgetter
from Bio.SearchIO import parse as bio_parser
import functools
xparser = functools.partial(bio_parser, format="blast-xml")
from ..utilities import overlap
import xml.etree.ElementTree
import numpy as np
xparser = functools.partial(bio_parser, format="blast-xml")


__author__ = 'Luca Venturini'
Expand Down Expand Up @@ -198,42 +197,6 @@ def sniff(self, default_header=None):
return valid, default_header, exc


def __calculate_merges(intervals: np.array):
    """
    Internal helper for merge(): fuse together rows whose ranges overlap.

    The caller is expected to pass a numpy array of shape (N, 2) of
    (start, end) rows, presumably already sorted by start — TODO confirm
    against the caller. Whether two rows overlap is decided by the
    project-level ``overlap`` helper.
    :param intervals: numpy array of (start, end) rows.
    :return: numpy array of merged rows, lexsorted by (start, end).
    """

    # A single interval cannot be merged with anything.
    if intervals.shape[0] == 1:
        return intervals

    merged = []
    current = None

    for row in intervals:
        if current is None:
            # First row seen: it becomes the open interval under construction.
            current = row
        elif overlap(current, row, positive=False) >= 0:
            # Rows touch/overlap: widen the open interval to cover both.
            current = (min(current[0], row[0]), max(current[1], row[1]))
        else:
            # Gap found: close the open interval and start a new one.
            merged.append(tuple(current))
            current = row

    # Close the last open interval.
    merged.append(tuple(current))

    result = np.array(merged, dtype=intervals.dtype)
    return result[np.lexsort((result[:, 1], result[:, 0]))]


def merge(intervals: [(int, int)], query_length=None, offset=1):
"""
This function is used to merge together intervals, which have to be supplied as a list
Expand All @@ -249,6 +212,8 @@ def merge(intervals: [(int, int)], query_length=None, offset=1):
:returns: merged intervals, length covered
"""

# Assume tuple of the form (start,end)
Expand All @@ -258,21 +223,21 @@ def merge(intervals: [(int, int)], query_length=None, offset=1):
raise ValueError("Invalid offset - only 0 and 1 allowed: {}".format(offset))

try:
intervals = np.array([sorted(_) for _ in intervals], dtype=np.int)
intervals = np.array(sorted([sorted(_) for _ in intervals], key=itemgetter(0)), dtype=np.int)
if intervals.shape[1] != 2:
raise ValueError("Invalid shape for intervals: {}".format(intervals.shape))
except (TypeError, ValueError):
raise TypeError("Invalid array for intervals: {}".format(intervals))

intervals = intervals[np.lexsort((intervals[:,1], intervals[:,0]))]
intervals = __calculate_merges(intervals)
total_length_covered = int(abs(intervals[:,1] - intervals[:,0] + offset).sum())

if not query_length:
query_length = int(abs(intervals[:,1].max() - intervals[:,0].min() + offset))
intervals.sort()
starts = intervals[:, 0]
ends = np.maximum.accumulate(intervals[:, 1])
valid = np.zeros(len(intervals) + 1, dtype=np.bool)
valid[0], valid[1:-1], valid[-1] = True, starts[1:] >= ends[:-1], True
intervals = np.vstack((starts[:][valid[:-1]], ends[:][valid[1:]])).T
total_length_covered = int(abs(intervals[:, 1] - intervals[:, 0] + offset).sum())

if query_length and total_length_covered > query_length:
raise AssertionError("Something went wrong, original length {}, total length {}".format(
query_length, total_length_covered))

return [(int(_[0]), int(_[1])) for _ in intervals], total_length_covered
return [tuple([int(_[0]), int(_[1])]) for _ in intervals], total_length_covered
11 changes: 6 additions & 5 deletions Mikado/serializers/blast_serializer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,17 +161,18 @@ def hsp_sorter(val):
t_intervals.append((hsp.hit_start, hsp.hit_end))

q_merged_intervals, q_aligned = merge(q_intervals)
assert isinstance(q_aligned, np.int), (q_merged_intervals, q_aligned, type(q_aligned))
# assert isinstance(q_aligned, np.int), (q_merged_intervals, q_aligned, type(q_aligned))
hit_dict["query_aligned_length"] = min(qlength, q_aligned)
qstart, qend = q_merged_intervals[0][0], q_merged_intervals[-1][1]
assert isinstance(qstart, np.int), (q_merged_intervals, type(qstart))
assert isinstance(qend, np.int), (q_merged_intervals, type(qend))
# assert isinstance(qstart, np.int), (q_merged_intervals, type(qstart))
# assert isinstance(qend, np.int), (q_merged_intervals, type(qend))

hit_dict["query_start"], hit_dict["query_end"] = qstart, qend

if len(identical_positions) > q_aligned:
raise ValueError("Number of identical positions ({}) greater than number of aligned positions ({})!".format(
len(identical_positions), q_aligned))
raise ValueError(
"Number of identical positions ({}) greater than number of aligned positions ({})!\n{}\n{}".format(
len(identical_positions), q_aligned, q_intervals, q_merged_intervals))

if len(positives) > q_aligned:
raise ValueError("Number of identical positions ({}) greater than number of aligned positions ({})!".format(
Expand Down
1 change: 1 addition & 0 deletions Mikado/serializers/blast_serializer/xml_serialiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ def serialize(self):
if self.single_thread is True or self.procs == 1:
cache = {"query": self.queries, "target": self.targets}
for filename in self.xml:
_ = xml_pickler(self.json_conf, filename, self.header, max_target_seqs=self.__max_target_seqs)
valid, _, exc = BlastOpener(filename).sniff(default_header=self.header)
if not valid:
self.logger.error(exc)
Expand Down
24 changes: 20 additions & 4 deletions Mikado/tests/test_blast_related.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def test_merging_1(self):
tot_length = 51
corr_merged = [(-10, 10), (20, 40)]
merged, tot_length = seri_blast_utils.merge(l, query_length=tot_length, offset=1)
self.assertEqual(merged, corr_merged)
self.assertTrue((merged == corr_merged))
self.assertEqual(tot_length, 10 - -10 +1 + 40 - 20 + 1)

def test_merging_2(self):
Expand All @@ -150,7 +150,7 @@ def test_merging_2(self):
else:
merged, length = seri_blast_utils.merge(l, offset=offset)
self.assertEqual(length, tot_length)
self.assertEqual(merged, l)
self.assertTrue((merged == l), (merged, l))

def test_various_merging(self):

Expand Down Expand Up @@ -178,7 +178,7 @@ def test_various_merging(self):
inp, out = valid[val]
with self.subTest(val=val, msg=valid[val]):
_ = seri_blast_utils.merge(inp)
self.assertEqual(out, _[0])
self.assertTrue((out == _[0]), (out, _[0]))

def test_included(self):

Expand All @@ -190,7 +190,23 @@ def test_included(self):
for val, out in cases.items():
with self.subTest(val=val, msg=cases[val]):
_ = seri_blast_utils.merge(list(val))
self.assertEqual(out, _[0])
self.assertTrue((out == _[0]), (out, _[0]))

def test_unordered(self):
    """Verify that merge() gives the same result regardless of input order."""

    from random import shuffle

    # Keys: a tuple of input intervals; values: the expected merged output,
    # which must be produced for every permutation of the input.
    cases = {
        tuple([(10, 60), (40, 100), (200, 400)]): [(10, 100), (200, 400)],
        tuple([(54, 1194), (110, 790), (950, 1052)]): [(54, 1194)],
        tuple([(54, 1194), (110, 790), (950, 1052), (1200, 1400)]): [(54, 1194), (1200, 1400)]
    }

    # Repeat with fresh shuffles to exercise many different orderings.
    for trial in range(30):
        for intervals, expected in cases.items():
            permuted = list(intervals[:])
            shuffle(permuted)
            with self.subTest(val=intervals, cval=permuted, msg=cases[intervals]):
                result = seri_blast_utils.merge(list(permuted))
                self.assertTrue((expected == result[0]), (expected, result[0]))


if __name__ == '__main__':
Expand Down

0 comments on commit a1dbfaa

Please sign in to comment.