From 73218fc34a4f36dc7ec67ae285aa147a771f9b68 Mon Sep 17 00:00:00 2001 From: Luca Venturini Date: Tue, 3 Mar 2020 19:43:20 +0000 Subject: [PATCH] BROKEN; for #280: trying to implement the slow functions using NumPy. Some improvements but currently broken, and it could be better --- Mikado/parsers/blast_utils.py | 16 +-- Mikado/serializers/blast_serializer/utils.py | 127 ++++++++---------- .../blast_serializer/xml_serialiser.py | 49 ++++--- Mikado/subprograms/serialise.py | 2 +- 4 files changed, 87 insertions(+), 107 deletions(-) diff --git a/Mikado/parsers/blast_utils.py b/Mikado/parsers/blast_utils.py index 4d4337c9c..9b726cbd5 100644 --- a/Mikado/parsers/blast_utils.py +++ b/Mikado/parsers/blast_utils.py @@ -7,20 +7,12 @@ import os import subprocess import gzip -import multiprocessing import io -import collections -import time -import threading -import queue -import logging from . import HeaderError -from ..utilities.log_utils import create_null_logger -# from Bio.SearchIO.BlastIO.blast_xml import BlastXmlParser as xparser -from Bio.Blast.NCBIXML import parse as xparser -# import Bio.SearchIO -# import functools -# xparser = functools.partial(Bio.SearchIO.parse, format="blast-xml") +# from Bio.Blast.NCBIXML import parse as xparser +from Bio.SearchIO import parse as bio_parser +import functools +xparser = functools.partial(bio_parser, format="blast-xml") from ..utilities import overlap import xml.etree.ElementTree import numpy as np diff --git a/Mikado/serializers/blast_serializer/utils.py b/Mikado/serializers/blast_serializer/utils.py index fbb4aca15..e63d8af00 100644 --- a/Mikado/serializers/blast_serializer/utils.py +++ b/Mikado/serializers/blast_serializer/utils.py @@ -4,13 +4,15 @@ from ...parsers.blast_utils import merge import numpy as np -import functools + +valid_letters = 'ACDEFGHIKLMNPQRSTVWYBXZJUO' +letters = np.array(list(valid_letters)) __author__ = 'Luca Venturini' -valid_matches = set([chr(x) for x in range(65, 91)] + [chr(x) for x in range(97, 123)] + - ["|", "*"]) +# valid_matches = set([chr(x) for x in range(65, 91)] + [chr(x) for x in range(97, 123)] + +# ["|", "*"]) def prepare_hsp(hsp, counter, qmultiplier=1, tmultiplier=1): @@ -25,8 +27,7 @@ def prepare_hsp(hsp, counter, qmultiplier=1, tmultiplier=1): - If the position is a gap *for both*, insert a \ (backslash) :param hsp: An HSP object from Bio.Blast.NCBIXML - # :type hsp: Bio.Blast.Record.HSP - :type hsp: Bio.Blast.Record.HSP + :type hsp: Bio.SearchIO.HSP :param counter: a digit that indicates the priority of the HSP in the hit :return: hsp_dict, identical_positions, positives :rtype: (dict, set, set) @@ -38,78 +39,58 @@ def prepare_hsp(hsp, counter, qmultiplier=1, tmultiplier=1): hsp_dict["counter"] = counter + 1 hsp_dict["query_hsp_start"] = hsp.query_start hsp_dict["query_hsp_end"] = hsp.query_end - hsp_dict["query_frame"] = hsp.frame[0] - hsp_dict["target_hsp_start"] = hsp.sbjct_start - hsp_dict["target_hsp_end"] = hsp.sbjct_end - hsp_dict["target_frame"] = hsp.frame[1] - hsp_dict["hsp_identity"] = hsp.identities / hsp.align_length * 100 - hsp_dict["hsp_positives"] = hsp.positives / hsp.align_length * 100 + hsp_dict["query_frame"] = hsp.query_frame + hsp_dict["target_hsp_start"] = hsp.hit_start + hsp_dict["target_hsp_end"] = hsp.hit_end + hsp_dict["target_frame"] = hsp.hit_frame + hsp_dict["hsp_identity"] = hsp.ident_num / hsp.aln_span * 100 + hsp_dict["hsp_positives"] = hsp.pos_num / hsp.aln_span * 100 hsp_dict["match"] = match - hsp_dict["hsp_length"] = hsp.align_length - hsp_dict["hsp_bits"] = hsp.score - hsp_dict["hsp_evalue"] = hsp.expect + hsp_dict["hsp_length"] = hsp.aln_span + hsp_dict["hsp_bits"] = hsp.bitscore + hsp_dict["hsp_evalue"] = hsp.evalue return hsp_dict, identical_positions, positives +def _np_grouper(data): + return np.array(np.split(data, np.where(np.diff(data) != 1)[0] + 1)) + + def _prepare_aln_strings(hsp, qmultiplier=1, tmultiplier=1): """This private method calculates the identical positions, the positives, and a re-factored match line - starting from the HSP.""" - - # for query_aa, middle_aa, target_aa in zip(hsp.query, hsp.match, hsp.sbjct): - query_pos, target_pos = hsp.query_start - 1, hsp.sbjct_start - 1 - - def categoriser(middle_aa, query_aa, target_aa, qmultiplier, tmultiplier): - qpos = 0 - tpos = 0 - identical = set() - positives = set() - matched = "" - if query_aa == target_aa == "-": - matched = "\\" - elif query_aa == "-": - tpos = tmultiplier - if target_aa == "*": - matched = "*" - else: - matched = "-" - elif target_aa == "-": - qpos = qmultiplier - if query_aa == "*": - matched = "*" - else: - matched = "_" - elif middle_aa == " ": - matched = " " - qpos = qmultiplier - tpos += tmultiplier - elif middle_aa == "+" or middle_aa in valid_matches: - if middle_aa != "+": - identical = set(range(query_pos, query_pos + qmultiplier)) - positives = set(range(query_pos, query_pos + + qmultiplier)) - qpos = qmultiplier - tpos += tmultiplier - matched = middle_aa - return qpos, tpos, matched, identical, positives - - partial_categorizer = functools.partial(categoriser, - qmultiplier=qmultiplier, tmultiplier=tmultiplier) - - results = [partial_categorizer(middle_aa, query_aa, target_aa) for - middle_aa, query_aa, target_aa in zip(hsp.match, hsp.query, hsp.sbjct)] - - qposes, tposes, matches, identicals, posis = list(zip(*results)) - query_pos += sum(qposes) - target_pos += sum(tposes) - match = "".join(matches) - identical_positions = set.union(*identicals) - positives = set.union(*posis) - - assert query_pos <= hsp.query_end and target_pos <= hsp.sbjct_end, ((query_pos, hsp.query_end), - (target_pos, hsp.sbjct_end), - hsp.match, hsp.query, hsp.sbjct) - - return match, identical_positions, positives + starting from the HSP. + :type hsp: Bio.SearchIO.HSP + """ + + lett_array = np.array([ + list(str(hsp.query.seq)), + list(hsp.aln_annotation["similarity"]), + list(str(hsp.hit.seq))]) + + match = lett_array[1] + match[np.where(~((lett_array[1] == "+") | (np.isin(lett_array[1], letters))))] = " " + match[np.where( + (np.isin(lett_array[0], letters)) & + (np.isin(lett_array[2], letters)) & + (lett_array[0] != lett_array[2]) & + (lett_array[1] != "+"))] = "X" + match[np.where((lett_array[0] == "-") & (lett_array[2] == "*"))] = "*" + match[np.where((lett_array[0] == "-") & ~(lett_array[2] == "*"))] = "-" + match[np.where((lett_array[2] == "-") & (lett_array[0] == "*"))] = "*" + match[np.where((lett_array[2] == "-") & ~(lett_array[0] == "*"))] = "_" + + summer = np.array([[_] for _ in range(qmultiplier)]) + v = np.array([[1]] * qmultiplier) + identical_positions = np.where(np.isin(match, valid_letters)) * v + identical_positions = set(np.array( + [identical_positions[_] * 3 + summer for _ in range(identical_positions.shape[0])]).flatten()) + positives = np.where(~np.isin(match, np.array(["*", "-", "_", " "]))) * v + positives = set(np.array( + [positives[_] * 3 + summer for _ in range(positives.shape[0])]).flatten()) + str_match = "".join(match) + + return str_match, identical_positions, positives def prepare_hit(hit, query_id, target_id, **kwargs): @@ -118,7 +99,7 @@ def prepare_hit(hit, query_id, target_id, **kwargs): global_identity: the identity rate for the global hit *using the query perspective* :param hit: the hit to parse. - :type hit: Bio.Blast.Record.Alignment + :type hit: Bio.SearchIO.Hit :param query_id: the numeric ID of the query in the database. Necessary for serialisation. :type query_id: int @@ -172,7 +153,7 @@ def hsp_sorter(val): hsp_dict_list.append(hsp_dict) q_intervals.append((hsp.query_start, hsp.query_end)) # t_intervals.append((hsp.sbjct_start, hsp.sbjct_end)) - t_intervals.append((hsp.sbjct_start, hsp.sbjct_end)) + t_intervals.append((hsp.hit_start, hsp.hit_end)) q_merged_intervals, q_aligned = merge(q_intervals) assert isinstance(q_aligned, np.int), (q_merged_intervals, q_aligned, type(q_aligned)) @@ -192,7 +173,7 @@ def hsp_sorter(val): len(positives), q_aligned)) t_merged_intervals, t_aligned = merge(t_intervals) - hit_dict["target_aligned_length"] = min(t_aligned, hit.length) + hit_dict["target_aligned_length"] = min(t_aligned, hit.seq_len) hit_dict["target_start"] = t_merged_intervals[0][0] hit_dict["target_end"] = t_merged_intervals[-1][1] hit_dict["global_identity"] = len(identical_positions) * 100 / q_aligned diff --git a/Mikado/serializers/blast_serializer/xml_serialiser.py b/Mikado/serializers/blast_serializer/xml_serialiser.py index 006ed8eaa..817d63d94 100644 --- a/Mikado/serializers/blast_serializer/xml_serialiser.py +++ b/Mikado/serializers/blast_serializer/xml_serialiser.py @@ -85,9 +85,14 @@ def xml_pickler(json_conf, filename, default_header, try: with BlastOpener(filename) as opened: try: + qmult, tmult = None, None for query_counter, record in enumerate(opened, start=1): + if qmult is None: + qmult, tmult = XmlSerializer._get_multipliers(record) + hits, hsps, cache = objectify_record( - session, record, [], [], cache, max_target_seqs=max_target_seqs) + session, record, [], [], cache, max_target_seqs=max_target_seqs, + qmult=qmult, tmult=tmult) try: jhits = json.dumps(hits, number_mode=json.NM_NATIVE) @@ -610,18 +615,18 @@ def _get_query_for_blast(session: sqlalchemy.orm.session.Session, record, cache) """ This private method formats the name of the query recovered from the BLAST hit. It will cause an exception if the target is not present in the dictionary. - :param record: + :param record: Bio.SearchIO.Record :return: current_query (ID in the database), name """ - if record.query in cache: - return cache[record.query], record.query, cache - elif record.query.split()[0] in cache: - return cache[record.query.split()[0]], record.query.split()[0], cache + if record.id in cache: + return cache[record.id], record.id, cache + elif record.id.split()[0] in cache: + return cache[record.id.split()[0]], record.id.split()[0], cache else: got = session.query(Query).filter(sqlalchemy.or_( - Query.query_name == record.query, - Query.query_name == record.query.split()[0], + Query.query_name == record.id, + Query.query_name == record.id.split()[0], )).one() cache[got.query_name] = got.query_id return got.query_id, got.query_name, cache @@ -634,28 +639,30 @@ def _get_target_for_blast(session, alignment, cache): The method returns the index of the current target and and an updated target dictionary. :param alignment: an alignment child of a BLAST record object + :type alignment: Bio.SearchIO.Hit :return: current_target (ID in the database), targets """ if alignment.accession in cache: return cache[alignment.accession], cache - elif alignment.hit_id in cache: - return cache[alignment.hit_id], cache + elif alignment.id in cache: + return cache[alignment.id], cache else: got = session.query(Target).filter(sqlalchemy.or_( Target.target_name == alignment.accession, - Target.target_name == alignment.hit_id)).one() + Target.target_name == alignment.id)).one() cache[got.target_name] = got.target_id return got.target_id, cache def objectify_record(session, record, hits, hsps, cache, - max_target_seqs=10000, logger=create_null_logger()): + max_target_seqs=10000, logger=create_null_logger(), + qmult=1, tmult=1): """ Private method to serialise a single record into the DB. :param record: The BLAST record to load into the DB. - :type record: Bio.Blast.Record.Blast + :type record: Bio.SearchIO.QueryResult :param hits: Cache of hits to load into the DB. :type hits: list @@ -666,7 +673,7 @@ def objectify_record(session, record, hits, hsps, cache, :rtype: (list, list, dict) """ - if len(record.alignments) == 0: + if len(record.hits) == 0: return hits, hsps, cache current_query, name, cache["query"] = _get_query_for_blast(session, record, cache["query"]) @@ -675,18 +682,18 @@ def objectify_record(session, record, hits, hsps, cache, current_counter = 0 # for ccc, alignment in enumerate(record.alignments): - for ccc, alignment in enumerate(record.alignments): + for ccc, alignment in enumerate(record.hits): if ccc + 1 > max_target_seqs: break - logger.debug("Started the hit %s vs. %s", name, record.alignments[ccc].hit_id) + logger.debug("Started the hit %s vs. %s", name, record.hits[ccc].id) current_target, cache["target"] = _get_target_for_blast(session, alignment, cache["target"]) hit_dict_params = dict() (hit_dict_params["query_multiplier"], - hit_dict_params["target_multiplier"]) = XmlSerializer.get_multipliers(record) - hit_evalue = min(_.expect for _ in record.alignments[ccc].hsps) - hit_bs = max(_.score for _ in record.alignments[ccc].hsps) + hit_dict_params["target_multiplier"]) = (qmult, tmult) + hit_evalue = min(_.evalue for _ in record.hits[ccc].hsps) + hit_bs = max(_.bitscore for _ in record.hits[ccc].hsps) if current_evalue < hit_evalue: current_counter += 1 current_evalue = hit_evalue @@ -699,12 +706,12 @@ def objectify_record(session, record, hits, hsps, cache, try: hit, hit_hsps = prepare_hit(alignment, current_query, current_target, - query_length=record.query_length, + query_length=record.seq_len, **hit_dict_params) except InvalidHit as exc: logger.error(exc) continue - hit["query_aligned_length"] = min(record.query_length, hit["query_aligned_length"]) + hit["query_aligned_length"] = min(record.seq_len, hit["query_aligned_length"]) hits.append(hit) diff --git a/Mikado/subprograms/serialise.py b/Mikado/subprograms/serialise.py index 7efd7817e..4a99e0d32 100644 --- a/Mikado/subprograms/serialise.py +++ b/Mikado/subprograms/serialise.py @@ -438,7 +438,7 @@ def serialise_parser(): blast = parser.add_argument_group() blast.add_argument("--max_target_seqs", type=int, default=None, help="Maximum number of target sequences.") - blast.add_argument("--blast_targets", default=[], type=comma_split, + blast.add_argument("-bt", "--blast-targets", "--blast_targets", default=[], type=comma_split, help="Target sequences") blast.add_argument("--xml", type=str, help="""XML file(s) to parse. They can be provided in three ways: