Skip to content

Commit

Permalink
BROKEN; for #280: trying to implement the slow functions using NumPy.…
Browse files Browse the repository at this point in the history
… Some improvements but currently broken, and it could be better
  • Loading branch information
lucventurini committed Mar 4, 2020
1 parent 8ca4b86 commit 7e03df6
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 107 deletions.
16 changes: 4 additions & 12 deletions Mikado/parsers/blast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,12 @@
import os
import subprocess
import gzip
import multiprocessing
import io
import collections
import time
import threading
import queue
import logging
from . import HeaderError
from ..utilities.log_utils import create_null_logger
# from Bio.SearchIO.BlastIO.blast_xml import BlastXmlParser as xparser
from Bio.Blast.NCBIXML import parse as xparser
# import Bio.SearchIO
# import functools
# xparser = functools.partial(Bio.SearchIO.parse, format="blast-xml")
# from Bio.Blast.NCBIXML import parse as xparser
from Bio.SearchIO import parse as bio_parser
import functools
xparser = functools.partial(bio_parser, format="blast-xml")
from ..utilities import overlap
import xml.etree.ElementTree
import numpy as np
Expand Down
127 changes: 54 additions & 73 deletions Mikado/serializers/blast_serializer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@

from ...parsers.blast_utils import merge
import numpy as np
import functools

valid_letters = 'ACDEFGHIKLMNPQRSTVWYBXZJUO'
letters = np.array(list(valid_letters))


__author__ = 'Luca Venturini'

valid_matches = set([chr(x) for x in range(65, 91)] + [chr(x) for x in range(97, 123)] +
["|", "*"])
# valid_matches = set([chr(x) for x in range(65, 91)] + [chr(x) for x in range(97, 123)] +
# ["|", "*"])


def prepare_hsp(hsp, counter, qmultiplier=1, tmultiplier=1):
Expand All @@ -25,8 +27,7 @@ def prepare_hsp(hsp, counter, qmultiplier=1, tmultiplier=1):
- If the position is a gap *for both*, insert a \ (backslash)
:param hsp: An HSP object from Bio.Blast.NCBIXML
# :type hsp: Bio.Blast.Record.HSP
:type hsp: Bio.Blast.Record.HSP
:type hsp: Bio.SearchIO.HSP
:param counter: a digit that indicates the priority of the HSP in the hit
:return: hsp_dict, identical_positions, positives
:rtype: (dict, set, set)
Expand All @@ -38,78 +39,58 @@ def prepare_hsp(hsp, counter, qmultiplier=1, tmultiplier=1):
hsp_dict["counter"] = counter + 1
hsp_dict["query_hsp_start"] = hsp.query_start
hsp_dict["query_hsp_end"] = hsp.query_end
hsp_dict["query_frame"] = hsp.frame[0]
hsp_dict["target_hsp_start"] = hsp.sbjct_start
hsp_dict["target_hsp_end"] = hsp.sbjct_end
hsp_dict["target_frame"] = hsp.frame[1]
hsp_dict["hsp_identity"] = hsp.identities / hsp.align_length * 100
hsp_dict["hsp_positives"] = hsp.positives / hsp.align_length * 100
hsp_dict["query_frame"] = hsp.query_frame
hsp_dict["target_hsp_start"] = hsp.hit_start
hsp_dict["target_hsp_end"] = hsp.hit_end
hsp_dict["target_frame"] = hsp.hit_frame
hsp_dict["hsp_identity"] = hsp.ident_num / hsp.aln_span * 100
hsp_dict["hsp_positives"] = hsp.pos_num / hsp.aln_span * 100
hsp_dict["match"] = match
hsp_dict["hsp_length"] = hsp.align_length
hsp_dict["hsp_bits"] = hsp.score
hsp_dict["hsp_evalue"] = hsp.expect
hsp_dict["hsp_length"] = hsp.aln_span
hsp_dict["hsp_bits"] = hsp.bitscore
hsp_dict["hsp_evalue"] = hsp.evalue
return hsp_dict, identical_positions, positives


def _np_grouper(data):
return np.array(np.split(data, np.where(np.diff(data) != 1)[0] + 1))


def _prepare_aln_strings(hsp, qmultiplier=1, tmultiplier=1):

"""This private method calculates the identical positions, the positives, and a re-factored match line
starting from the HSP."""

# for query_aa, middle_aa, target_aa in zip(hsp.query, hsp.match, hsp.sbjct):
query_pos, target_pos = hsp.query_start - 1, hsp.sbjct_start - 1

def categoriser(middle_aa, query_aa, target_aa, qmultiplier, tmultiplier):
qpos = 0
tpos = 0
identical = set()
positives = set()
matched = ""
if query_aa == target_aa == "-":
matched = "\\"
elif query_aa == "-":
tpos = tmultiplier
if target_aa == "*":
matched = "*"
else:
matched = "-"
elif target_aa == "-":
qpos = qmultiplier
if query_aa == "*":
matched = "*"
else:
matched = "_"
elif middle_aa == " ":
matched = " "
qpos = qmultiplier
tpos += tmultiplier
elif middle_aa == "+" or middle_aa in valid_matches:
if middle_aa != "+":
identical = set(range(query_pos, query_pos + qmultiplier))
positives = set(range(query_pos, query_pos + + qmultiplier))
qpos = qmultiplier
tpos += tmultiplier
matched = middle_aa
return qpos, tpos, matched, identical, positives

partial_categorizer = functools.partial(categoriser,
qmultiplier=qmultiplier, tmultiplier=tmultiplier)

results = [partial_categorizer(middle_aa, query_aa, target_aa) for
middle_aa, query_aa, target_aa in zip(hsp.match, hsp.query, hsp.sbjct)]

qposes, tposes, matches, identicals, posis = list(zip(*results))
query_pos += sum(qposes)
target_pos += sum(tposes)
match = "".join(matches)
identical_positions = set.union(*identicals)
positives = set.union(*posis)

assert query_pos <= hsp.query_end and target_pos <= hsp.sbjct_end, ((query_pos, hsp.query_end),
(target_pos, hsp.sbjct_end),
hsp.match, hsp.query, hsp.sbjct)

return match, identical_positions, positives
starting from the HSP.
:type hsp: Bio.SearchIO.HSP
"""

lett_array = np.array([
list(str(hsp.query.seq)),
list(hsp.aln_annotation["similarity"]),
list(str(hsp.hit.seq))])

match = lett_array[1]
match[np.where(~((lett_array[1] == "+") | (np.isin(lett_array[1], letters))))] = " "
match[np.where(
(np.isin(lett_array[0], letters)) &
(np.isin(lett_array[2], letters)) &
(lett_array[0] != lett_array[2]) &
(lett_array[1] != "+"))] = "X"
match[np.where((lett_array[0] == "-") & (lett_array[2] == "*"))] = "*"
match[np.where((lett_array[0] == "-") & ~(lett_array[2] == "*"))] = "-"
match[np.where((lett_array[2] == "-") & (lett_array[0] == "*"))] = "*"
match[np.where((lett_array[2] == "-") & ~(lett_array[0] == "*"))] = "_"

summer = np.array([[_] for _ in range(qmultiplier)])
v = np.array([[1]] * qmultiplier)
identical_positions = np.where(np.isin(match, valid_letters)) * v
identical_positions = set(np.array(
[identical_positions[_] * 3 + summer for _ in range(identical_positions.shape[0])]).flatten())
positives = np.where(~np.isin(match, np.array(["*", "-", "_", " "]))) * v
positives = set(np.array(
[positives[_] * 3 + summer for _ in range(positives.shape[0])]).flatten())
str_match = "".join(match)

return str_match, identical_positions, positives


def prepare_hit(hit, query_id, target_id, **kwargs):
Expand All @@ -118,7 +99,7 @@ def prepare_hit(hit, query_id, target_id, **kwargs):
global_identity: the identity rate for the global hit *using the query perspective*
:param hit: the hit to parse.
:type hit: Bio.Blast.Record.Alignment
:type hit: Bio.SearchIO.Hit
:param query_id: the numeric ID of the query in the database. Necessary for serialisation.
:type query_id: int
Expand Down Expand Up @@ -172,7 +153,7 @@ def hsp_sorter(val):
hsp_dict_list.append(hsp_dict)
q_intervals.append((hsp.query_start, hsp.query_end))
# t_intervals.append((hsp.sbjct_start, hsp.sbjct_end))
t_intervals.append((hsp.sbjct_start, hsp.sbjct_end))
t_intervals.append((hsp.hit_start, hsp.hit_end))

q_merged_intervals, q_aligned = merge(q_intervals)
assert isinstance(q_aligned, np.int), (q_merged_intervals, q_aligned, type(q_aligned))
Expand All @@ -192,7 +173,7 @@ def hsp_sorter(val):
len(positives), q_aligned))

t_merged_intervals, t_aligned = merge(t_intervals)
hit_dict["target_aligned_length"] = min(t_aligned, hit.length)
hit_dict["target_aligned_length"] = min(t_aligned, hit.seq_len)
hit_dict["target_start"] = t_merged_intervals[0][0]
hit_dict["target_end"] = t_merged_intervals[-1][1]
hit_dict["global_identity"] = len(identical_positions) * 100 / q_aligned
Expand Down
49 changes: 28 additions & 21 deletions Mikado/serializers/blast_serializer/xml_serialiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,14 @@ def xml_pickler(json_conf, filename, default_header,
try:
with BlastOpener(filename) as opened:
try:
qmult, tmult = None, None
for query_counter, record in enumerate(opened, start=1):
if qmult is None:
qmult, tmult = XmlSerializer._get_multipliers(record)

hits, hsps, cache = objectify_record(
session, record, [], [], cache, max_target_seqs=max_target_seqs)
session, record, [], [], cache, max_target_seqs=max_target_seqs,
qmult=qmult, tmult=tmult)

try:
jhits = json.dumps(hits, number_mode=json.NM_NATIVE)
Expand Down Expand Up @@ -610,18 +615,18 @@ def _get_query_for_blast(session: sqlalchemy.orm.session.Session, record, cache)
""" This private method formats the name of the query
recovered from the BLAST hit. It will cause an exception if the target is not
present in the dictionary.
:param record:
:param record: Bio.SearchIO.Record
:return: current_query (ID in the database), name
"""

if record.query in cache:
return cache[record.query], record.query, cache
elif record.query.split()[0] in cache:
return cache[record.query.split()[0]], record.query.split()[0], cache
if record.id in cache:
return cache[record.id], record.id, cache
elif record.id.split()[0] in cache:
return cache[record.id.split()[0]], record.id.split()[0], cache
else:
got = session.query(Query).filter(sqlalchemy.or_(
Query.query_name == record.query,
Query.query_name == record.query.split()[0],
Query.query_name == record.id,
Query.query_name == record.id.split()[0],
)).one()
cache[got.query_name] = got.query_id
return got.query_id, got.query_name, cache
Expand All @@ -634,28 +639,30 @@ def _get_target_for_blast(session, alignment, cache):
The method returns the index of the current target and
and an updated target dictionary.
:param alignment: an alignment child of a BLAST record object
:type alignment: Bio.SearchIO.Hit
:return: current_target (ID in the database), targets
"""

if alignment.accession in cache:
return cache[alignment.accession], cache
elif alignment.hit_id in cache:
return cache[alignment.hit_id], cache
elif alignment.id in cache:
return cache[alignment.id], cache
else:
got = session.query(Target).filter(sqlalchemy.or_(
Target.target_name == alignment.accession,
Target.target_name == alignment.hit_id)).one()
Target.target_name == alignment.id)).one()
cache[got.target_name] = got.target_id
return got.target_id, cache


def objectify_record(session, record, hits, hsps, cache,
max_target_seqs=10000, logger=create_null_logger()):
max_target_seqs=10000, logger=create_null_logger(),
qmult=1, tmult=1):
"""
Private method to serialise a single record into the DB.
:param record: The BLAST record to load into the DB.
:type record: Bio.Blast.Record.Blast
:type record: Bio.SearchIO.QueryResult
:param hits: Cache of hits to load into the DB.
:type hits: list
Expand All @@ -666,7 +673,7 @@ def objectify_record(session, record, hits, hsps, cache,
:rtype: (list, list, dict)
"""

if len(record.alignments) == 0:
if len(record.hits) == 0:
return hits, hsps, cache

current_query, name, cache["query"] = _get_query_for_blast(session, record, cache["query"])
Expand All @@ -675,18 +682,18 @@ def objectify_record(session, record, hits, hsps, cache,
current_counter = 0

# for ccc, alignment in enumerate(record.alignments):
for ccc, alignment in enumerate(record.alignments):
for ccc, alignment in enumerate(record.hits):
if ccc + 1 > max_target_seqs:
break

logger.debug("Started the hit %s vs. %s", name, record.alignments[ccc].hit_id)
logger.debug("Started the hit %s vs. %s", name, record.hits[ccc].id)
current_target, cache["target"] = _get_target_for_blast(session, alignment, cache["target"])

hit_dict_params = dict()
(hit_dict_params["query_multiplier"],
hit_dict_params["target_multiplier"]) = XmlSerializer.get_multipliers(record)
hit_evalue = min(_.expect for _ in record.alignments[ccc].hsps)
hit_bs = max(_.score for _ in record.alignments[ccc].hsps)
hit_dict_params["target_multiplier"]) = (qmult, tmult)
hit_evalue = min(_.evalue for _ in record.hits[ccc].hsps)
hit_bs = max(_.bitscore for _ in record.hits[ccc].hsps)
if current_evalue < hit_evalue:
current_counter += 1
current_evalue = hit_evalue
Expand All @@ -699,12 +706,12 @@ def objectify_record(session, record, hits, hsps, cache,
try:
hit, hit_hsps = prepare_hit(alignment, current_query,
current_target,
query_length=record.query_length,
query_length=record.seq_len,
**hit_dict_params)
except InvalidHit as exc:
logger.error(exc)
continue
hit["query_aligned_length"] = min(record.query_length, hit["query_aligned_length"])
hit["query_aligned_length"] = min(record.seq_len, hit["query_aligned_length"])


hits.append(hit)
Expand Down
2 changes: 1 addition & 1 deletion Mikado/subprograms/serialise.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def serialise_parser():
blast = parser.add_argument_group()
blast.add_argument("--max_target_seqs", type=int, default=None,
help="Maximum number of target sequences.")
blast.add_argument("--blast_targets", default=[], type=comma_split,
blast.add_argument("-bt", "--blast-targets", "--blast_targets", default=[], type=comma_split,
help="Target sequences")
blast.add_argument("--xml", type=str, help="""XML file(s) to parse.
They can be provided in three ways:
Expand Down

0 comments on commit 7e03df6

Please sign in to comment.