From 6d357781b9f714c31ba3692648b5a5249513e703 Mon Sep 17 00:00:00 2001 From: Luca Venturini Date: Thu, 18 Oct 2018 19:50:17 +0100 Subject: [PATCH] Fixes for #137 and #138 --- CHANGELOG.md | 1 + Mikado/preparation/checking.py | 124 ++++++++- Mikado/preparation/prepare.py | 1 - Mikado/tests/prepare_misc_test.py | 258 ++++++++++++++++++ Mikado/tests/test_transcript_methods.py | 77 ++++++ Mikado/transcripts/transcript.py | 9 +- .../transcript_methods/finalizing.py | 145 +++++++--- .../transcript_methods/retrieval.py | 6 + 8 files changed, 562 insertions(+), 59 deletions(-) create mode 100644 Mikado/tests/prepare_misc_test.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 5217298b7..be6b038a5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ Bugfixes and improvements: - [#132](https://github.com/lucventurini/mikado/issues/132),[#133](https://github.com/lucventurini/mikado/issues/133): Mikado will now evaluate the CDS of transcripts during Mikado prepare. - [#134](https://github.com/lucventurini/mikado/issues/134): when checking for potential Alternative Splicing Events (ASEs), now Mikado will check whether the CDS phases are in frame with each other. Moreover **Mikado will now calculate the CDS overlap percentage based on the primary transcript CDS length**, not the minimum CDS length between primary and candidate. Please note that the change **regarding the frame** also affects the monosublocus stage. Mikado still considers only the primary ORFs for the overlap. + - Solved a bug which led Mikado to recalculate the phases for each model during picking, potentially creating mistakes for models truncated at the 5' end. # Version 1.2.4 diff --git a/Mikado/preparation/checking.py b/Mikado/preparation/checking.py index e08a71960..71a4ca37e 100644 --- a/Mikado/preparation/checking.py +++ b/Mikado/preparation/checking.py @@ -1,13 +1,15 @@ import functools import multiprocessing +import multiprocessing.queues import os - import pyfaidx - from Mikado.transcripts.transcriptchecker import TranscriptChecker from .. import exceptions from ..loci import Transcript from ..utilities.log_utils import create_null_logger, create_queue_logger +import logging +import queue +import time __author__ = 'Luca Venturini' @@ -51,11 +53,14 @@ def create_transcript(lines, """ if logger is None: - logger = create_null_logger("checker") + logger = create_null_logger() - logger.debug("Starting with %s", lines["tid"]) + if "tid" not in lines: + logger.error("Lines datastore lacks the transcript ID. Exiting.") + return None try: + logger.debug("Starting with %s", lines["tid"]) transcript_line = Transcript() transcript_line.chrom = lines["chrom"] if "source" in lines: @@ -70,8 +75,15 @@ def create_transcript(lines, transcript_line.parent = lines["parent"] for feature in lines["features"]: - coords = [(_[0], _[1]) for _ in lines["features"][feature]] - phases = [_[2] for _ in lines["features"][feature]] + coords, phases = [], [] + for feat in lines["features"][feature]: + assert isinstance(feat, (list, tuple)) and 2 <= len(feat) <= 3, feat + coords.append((feat[0], feat[1])) + if len(feat) == 3 and feat[2] in (0, 1, 2, None): + phases.append(feat[2]) + else: + phases.append(None) + assert len(phases) == len(coords) transcript_line.add_exons(coords, features=feature, phases=phases) transcript_object = TranscriptChecker(transcript_line, @@ -100,8 +112,6 @@ def create_transcript(lines, logger.exception(exc) transcript_object = None - logger.debug("Finished with %s", lines["tid"]) - return transcript_object @@ -124,17 +134,26 @@ def __init__(self, ): super().__init__() - self.__identifier = identifier + self.__identifier = "" + self.__set_identifier(identifier) # self.strand_specific = strand_specific - self.canonical = canonical_splices + self.__canonical = [] + self.__set_canonical(canonical_splices) + self.__log_level = "DEBUG" self.log_level = log_level - self.logger = None - self.logging_queue = logging_queue + self.logger = None # This gets populated by the create_queue_logger function below + self.__logging_queue = None + self.__set_logging_queue(logging_queue) self.name = "Checker-{0}".format(self.identifier) - create_queue_logger(self) + try: + create_queue_logger(self) + except AttributeError as exc: + raise AttributeError(exc) + self.__lenient = False self.lenient = lenient self.__fasta = fasta - self.submission_queue = submission_queue + self.__submission_queue = None + self.__set_submission_queue(submission_queue) self.fasta = pyfaidx.Fasta(self.__fasta) self.fasta_out = os.path.join(tmpdir, "{0}-{1}".format( fasta_out, self.identifier @@ -142,7 +161,6 @@ def __init__(self, self.gtf_out = os.path.join(tmpdir, "{0}-{1}".format( gtf_out, self.identifier )) - self.logger.debug(self.canonical) def run(self): @@ -154,7 +172,12 @@ def run(self): fasta_out = open(self.fasta_out, "w") gtf_out = open(self.gtf_out, "w") + self.logger.debug("Starting %s", self.name) + self.logger.debug("Created output FASTA {self.fasta_out} and GTF {self.gtf_out}".format(**locals())) + time.sleep(0.1) + self.logger.debug(self.canonical) + __printed = 0 while True: lines, start, end, counter = self.submission_queue.get() if lines == "EXIT": @@ -176,13 +199,23 @@ def run(self): continue else: self.logger.debug("Printing %s", lines["tid"]) + __printed += 1 print("\n".join(["{0}/{1}".format(counter, line) for line in transcript.format("gtf").split("\n")]), file=gtf_out) print("\n".join(["{0}/{1}".format(counter, line) for line in transcript.fasta.split("\n")]), file=fasta_out) + time.sleep(0.1) + fasta_out.flush() fasta_out.close() + gtf_out.flush() gtf_out.close() + if __printed > 0: + self.logger.debug("Size of FASTA out and GTF out: %s, %s", + os.stat(fasta_out.name).st_size, os.stat(gtf_out.name).st_size) + assert os.stat(gtf_out.name).st_size > 0 + assert os.stat(fasta_out.name).st_size > 0 + time.sleep(0.1) return def __getstate__(self): @@ -200,3 +233,64 @@ def __setstate__(self, state): @property def identifier(self): return self.__identifier + + def __set_identifier(self, identifier): + + if identifier is None: + raise ValueError("The identifier must be defined!") + self.__identifier = str(identifier) + + @property + def log_level(self): + return self.__log_level + + @log_level.setter + def log_level(self, log_level): + _ = logging._checkLevel(log_level) + self.__log_level = log_level + + @property + def lenient(self): + return self.__lenient + + @lenient.setter + def lenient(self, lenient): + if lenient not in (False, True): + raise ValueError("Invalid lenient value: {}".format(lenient)) + self.__lenient = lenient + + @property + def submission_queue(self): + return self.__submission_queue + + def __set_submission_queue(self, submission): + if not isinstance(submission, (multiprocessing.queues.Queue, queue.Queue)): + raise ValueError("Invalid queue object: {}".format(type(submission))) + self.__submission_queue = submission + + @property + def logging_queue(self): + return self.__logging_queue + + def __set_logging_queue(self, logging_queue): + if not isinstance(logging_queue, (multiprocessing.queues.Queue, queue.Queue)): + raise ValueError("Invalid queue object: {}".format(type(logging_queue))) + self.__logging_queue = logging_queue + + @property + def canonical(self): + return self.__canonical + + def __set_canonical(self, canonical): + if not isinstance(canonical, (tuple, list)): + raise TypeError("Canonical splices should be lists or tuples") + + if len(canonical) == 0: + raise ValueError("The list of canonical splices cannot be empty!") + + for el in canonical: + if (len(el) != 2 or (not (isinstance(el[0], str) and len(el[0]) == 2) or + not (isinstance(el[1], str) and len(el[1]) == 2 ))): + raise ValueError("Invalid splicing pattern!") + + self.__canonical = canonical diff --git a/Mikado/preparation/prepare.py b/Mikado/preparation/prepare.py index 1beba5203..ca74716f2 100644 --- a/Mikado/preparation/prepare.py +++ b/Mikado/preparation/prepare.py @@ -419,7 +419,6 @@ def prepare(args, logger): args.json_conf["prepare"]["files"]["output_dir"], args.json_conf["prepare"]["files"]["out"]), 'w') - logger.info("Loading reference file") args.json_conf["reference"]["genome"] = pyfaidx.Fasta(args.json_conf["reference"]["genome"]) diff --git a/Mikado/tests/prepare_misc_test.py b/Mikado/tests/prepare_misc_test.py new file mode 100644 index 000000000..b953d23fc --- /dev/null +++ b/Mikado/tests/prepare_misc_test.py @@ -0,0 +1,258 @@ +import unittest +import Mikado.preparation.checking +import Mikado +import multiprocessing as mp +import pkg_resources +import tempfile +import logging +import logging.handlers +import gzip +import pickle +import os +import time +import threading +import pyfaidx +import re + + +class ProcRunner(threading.Thread): + + def __init__(self, function: [mp.Process], *args, **kwargs): + + self.__function = function + self.args = args + self.kwargs = kwargs + self._func = self.__function(*self.args, **self.kwargs) + super().__init__() + + def run(self): + self._func.run() + + @property + def func(self): + return self._func + + def join(self, *args, **kwargs): + if self.func._popen is not None: + self.func.join() + self.func.terminate() + super().join(timeout=0.1) + + +class MiscTest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.fasta = pkg_resources.resource_filename("Mikado.tests", "chr5.fas.gz") + cls.fasta_temp = tempfile.NamedTemporaryFile(suffix=".fa") + with gzip.open(cls.fasta) as ffile: + cls.fasta_temp.write(ffile.read()) + cls.fasta_temp.flush() + + @staticmethod + def create_logger(name): + logging_queue = mp.JoinableQueue(-1) + log_queue_handler = logging.handlers.QueueHandler(logging_queue) + log_queue_handler.setLevel(logging.DEBUG) + + logger = Mikado.utilities.log_utils.create_default_logger(name, level="WARNING") + logger.propagate = False + listener = logging.handlers.QueueListener(logging_queue, logger) + listener.propagate = False + listener.start() + return logger, listener, logging_queue + + def setUp(self): + # Create the queues for logging and submission + self.submission_queue = mp.JoinableQueue() + self.fasta_out = "temporary.fasta" + self.gtf_out = "temporary.gtf" + + @unittest.skip + def test_normal(self): + # TODO: this test is creating problems due to threading errors. + logger, listener, logging_queue = self.create_logger("test_normal") + + with self.assertLogs(logger=logger, level="DEBUG") as cmo: + # FASTA out and GTF out are just the file names, without the temporary directory + # Moreover they will be complemented by the identifier! + + proc = ProcRunner(Mikado.preparation.checking.CheckingProcess, + self.submission_queue, + logging_queue, + fasta=self.fasta_temp.name, + identifier=0, + fasta_out=self.fasta_out, + gtf_out=self.gtf_out, + tmpdir=tempfile.gettempdir(), + log_level="DEBUG") + proc.start() + time.sleep(0.1) # Necessary otherwise the check might be too fast for the FileSystem + self.assertEqual(proc.func.fasta_out, os.path.join(tempfile.gettempdir(), self.fasta_out + "-0")) + self.assertTrue(os.path.exists(proc.func.fasta_out), proc.func.fasta_out) + self.assertEqual(proc.func.gtf_out, os.path.join(tempfile.gettempdir(), self.gtf_out + "-0")) + self.assertTrue(os.path.exists(proc.func.gtf_out), proc.func.gtf_out) + self.submission_queue.put(("EXIT", None, None, None)) + time.sleep(0.1) + proc.join() + + os.remove(proc.func.fasta_out) + os.remove(proc.func.gtf_out) + + self.maxDiff = 10000 + self.assertEqual(cmo.output, [ + "DEBUG:Checker-0:Starting Checker-0", + "DEBUG:Checker-0:Created output FASTA {} and GTF {}".format(proc.func.fasta_out, proc.func.gtf_out), + "DEBUG:Checker-0:(('GT', 'AG'), ('GC', 'AG'), ('AT', 'AC'))", + "DEBUG:Checker-0:Finished for Checker-0"]) + + self.assertIsInstance(proc.func, mp.Process) + + with self.assertRaises(TypeError): + picked = pickle.dumps(proc.func) + + # def test_logger_creater(self): + + def test_wrong_initialisation(self): + + logger, listener, logging_queue = self.create_logger("test_wrong_initialisation") + + kwds = {"submission_queue": self.submission_queue, + "logging_queue": logging_queue, + "identifier": 0, + "fasta_out": self.fasta_out, + "gtf_out": self.gtf_out, + "tmpdir": tempfile.gettempdir(), + "fasta": self.fasta_temp.name, + "log_level": "WARNING" + } + + for key in ["submission_queue", "logging_queue", "identifier", "lenient"]: + with self.subTest(key=key): + _kwds = kwds.copy() + _kwds[key] = None + with self.assertRaises(ValueError): + Mikado.preparation.checking.CheckingProcess(**_kwds) + + with self.subTest(key="fasta"): + _kwds = kwds.copy() + _kwds["fasta"] = None + with self.assertRaises(AttributeError): + Mikado.preparation.checking.CheckingProcess(**_kwds) + + for tentative in [None, [], [("A", "G")], [("AG", bytes("GT", encoding="ascii"))]]: + with self.subTest(tentative=tentative): + _kwds = kwds.copy() + _kwds["canonical_splices"] = tentative + if tentative is None: + with self.assertRaises(TypeError): + Mikado.preparation.checking.CheckingProcess(**_kwds) + else: + with self.assertRaises(ValueError): + Mikado.preparation.checking.CheckingProcess(**_kwds) + + _kwds = kwds.copy() + _kwds["canonical_splices"] = [("AG", "GT")] + # just test it does not raise + _ = Mikado.preparation.checking.CheckingProcess(**_kwds) + + def test_example_model(self): + + fasta = pyfaidx.Fasta(self.fasta_temp.name) + lines = dict() + lines["chrom"] = "Chr5" + lines["strand"] = "+" + lines["start"] = 208937 + lines["end"] = 210445 + lines["attributes"] = dict() + lines["tid"], lines["parent"] = "AT5G01530.0", "AT5G01530" + lines["features"] = dict() + lines["features"]["exon"] = [(208937, 209593), (209881, 210445)] + seq = str(fasta[lines["chrom"]][lines["start"] - 1:lines["end"]]) + + logger, listener, logging_queue = self.create_logger("test_example_model") + + res = Mikado.preparation.checking.create_transcript(lines, seq, lines["start"], lines["end"], + logger=logger) + listener.stop() + self.assertIsInstance(res, Mikado.transcripts.TranscriptChecker) + + for kwd in lines.keys(): + if kwd in ["end", "start"]: + continue + with self.subTest(kwd=kwd, msg="Testing key {}".format(kwd)): + _lines = lines.copy() + del _lines[kwd] + with self.assertLogs("null", level="DEBUG"): + res = Mikado.preparation.checking.create_transcript(_lines, seq, lines["start"], lines["end"]) + self.assertIs(res, None) + + _lines = lines.copy() + _lines["strand"] = "-" + with self.subTest(msg="Testing an invalid strand"): + with self.assertLogs("null", level="INFO") as cm: + res = Mikado.preparation.checking.create_transcript(_lines, seq, lines["start"], lines["end"], + strand_specific=True) + self.assertIsInstance(res, Mikado.transcripts.TranscriptChecker) + self.assertIn("WARNING:null:Transcript AT5G01530.0 has been assigned to the wrong strand, reversing it.", + cm.output) + + @unittest.skip + def test_example_model_through_process(self): + + logger, listener, logging_queue = self.create_logger("test_example_model_through_process") + logger.setLevel("DEBUG") + with self.assertLogs(logger=logger, level="DEBUG") as cmo: + # FASTA out and GTF out are just the file names, without the temporary directory + # Moreover they will be complemented by the identifier! + + proc = ProcRunner(Mikado.preparation.checking.CheckingProcess, + self.submission_queue, + logging_queue, + fasta=self.fasta_temp.name, + identifier=logger.name, + fasta_out=self.fasta_out, + gtf_out=self.gtf_out, + tmpdir=tempfile.gettempdir(), + log_level="DEBUG") + lines = dict() + lines["chrom"] = "Chr5" + lines["strand"] = "+" + lines["start"] = 208937 + lines["end"] = 210445 + lines["attributes"] = dict() + lines["tid"], lines["parent"] = "AT5G01530.0", "AT5G01530" + lines["features"] = dict() + lines["features"]["exon"] = [(208937, 209593), (209881, 210445)] + lines["strand_specific"] = True + self.submission_queue.put((lines, lines["start"], lines["end"], 0)) + self.submission_queue.put(("EXIT", None, None, None)) + proc.start() + proc.join() + time.sleep(0.5) + self.assertTrue(os.stat(proc.func.fasta_out).st_size > 0, proc.func.fasta_out) + fasta_lines = [] + with open(proc.func.fasta_out) as f_out: + for line in f_out: + line = line.rstrip() + line = re.sub("0/", "", line) + fasta_lines.append(line) + + fasta = pyfaidx.Fasta(self.fasta_temp.name) + seq = str(fasta[lines["chrom"]][lines["start"] - 1:lines["end"]]) + res = Mikado.preparation.checking.create_transcript(lines, seq, lines["start"], lines["end"]) + self.assertTrue(len(res.cdna), (209593 - 208937 + 1) + (210445 - 209881 + 1)) + + with tempfile.NamedTemporaryFile(suffix="fa", delete=True, mode="wt") as faix: + assert len(fasta_lines) > 0 + print(*fasta_lines, file=faix, end="") + faix.flush() + fa = pyfaidx.Fasta(faix.name) + self.assertEqual(list(fa.keys()), ["AT5G01530.0"]) + self.assertEqual(str(fa["AT5G01530.0"]), str(res.cdna)) + self.assertEqual(len(str(fa["AT5G01530.0"])), res.cdna_length) + + os.remove(faix.name + ".fai") + + print(cmo.output) + listener.stop() diff --git a/Mikado/tests/test_transcript_methods.py b/Mikado/tests/test_transcript_methods.py index eda7f86c5..32d3c00de 100644 --- a/Mikado/tests/test_transcript_methods.py +++ b/Mikado/tests/test_transcript_methods.py @@ -9,6 +9,7 @@ from Mikado.parsers.GTF import GtfLine from Mikado.parsers.GFF import GffLine from Mikado.transcripts.transcript_methods import retrieval +from Mikado.utilities.log_utils import create_default_logger class WrongLoadedOrf(unittest.TestCase): @@ -512,6 +513,82 @@ def test_comparisons(self): self.assertEqual(t1, t1) self.assertEqual(self.t1, self.t1) + def test_to_and_from_dict(self): + + d = self.t1.as_dict() + self.assertIsInstance(d, dict) + self.assertIn("chrom", d) + t1 = Transcript() + t1.load_dict(d) + self.assertIs(t1.is_coding, True) + self.assertEqual(t1, self.t1) + + def test_to_and_from_dict_truncated(self): + + gtf = GtfLine("""Chr5\tfoo\tCDS\t100\t200\t.\t+\t2\ttranscript_id "test.1"; gene_id "test.2";""") + logger = create_default_logger("test_to_and_from_dict_truncated", level="WARNING") + t = Transcript(gtf) + t.finalize() + self.assertIs(t.is_coding, True) + self.assertEqual(t.selected_internal_orf, [("CDS", (100, 200), 2), ("exon", (100, 200))]) + + d = t.as_dict() + self.assertEqual(d["orfs"]["0"], [["CDS", [100, 200], 2], ["exon", [100, 200]]]) + + new_t = Transcript(logger=logger) + new_t.load_dict(d) + self.assertEqual(t.selected_internal_orf, new_t.selected_internal_orf) + + def test_to_and_from_dict_multiple_truncated(self): + + gtf = GtfLine("""Chr5\texon\texon\t1\t1200\t.\t+\t.\ttranscript_id "test.1"; gene_id "test.2";""") + logger = create_default_logger("tdmt", level="WARNING") + t = Transcript(gtf, logger=logger) + t.finalize() + orf = "test.1\t0\t1200\tID=test.1.orf1;coding=True;phase=2\t0\t+\t0\t320\t.\t1\t1200\t0" + self.assertEqual(len(orf.split("\t")), 12) + b = BED12(orf, transcriptomic=True, coding=True) + self.assertFalse(b.header) + self.assertFalse(b.invalid, b.invalid_reason) + orf2 = "test.1\t0\t1200\tID=test.1.orf1;coding=True;phase=0\t0\t+\t400\t1000\t.\t1\t1200\t0" + b2 = BED12(orf2, transcriptomic=True, coding=True) + self.assertFalse(b2.header) + self.assertFalse(b2.invalid, b2.invalid_reason) + t.load_orfs([b, b2]) + self.assertTrue(t.is_coding) + self.assertEqual(t.number_internal_orfs, 2) + with self.subTest(): + new_t = Transcript(logger=logger) + d = t.as_dict() + self.assertEqual(len(d["orfs"]), 2) + new_t.load_dict(d) + self.assertTrue(new_t.is_coding) + self.assertEqual(new_t.number_internal_orfs, 2) + self.assertEqual(new_t.combined_utr, [(321, 400), (1001, 1200)]) + + # Now let us test out what happens with three ORFs + + with self.subTest(): + new_t = Transcript(logger=logger) + d = t.as_dict() + d["orfs"]["2"] = [("exon", (1, 1200)), ("UTR", (1, 1099)), ("CDS", [1100, 1200], 0)] + new_t.load_dict(d) + self.assertTrue(new_t.is_coding) + self.assertEqual(new_t.number_internal_orfs, 3) + self.assertEqual(new_t.combined_utr, [(321, 400), (1001, 1099)]) + + # Now let us check what happens with the addition of an incompatible ORF + new_t = Transcript() + d = t.as_dict() + d["orfs"]["2"] = [("CDS", (900, 1200), 0)] + with self.assertLogs(logger=new_t.logger, level="DEBUG") as cmo: + new_t.load_dict(d) + self.assertFalse(new_t.is_coding) + self.assertIn( + "WARNING:{}:Error while inferring the UTR for a transcript with multiple ORFs: overlapping CDS found.".format( + new_t.logger.name + ), + cmo.output) if __name__ == '__main__': unittest.main() diff --git a/Mikado/transcripts/transcript.py b/Mikado/transcripts/transcript.py index 90a7da0d8..bead08d9b 100644 --- a/Mikado/transcripts/transcript.py +++ b/Mikado/transcripts/transcript.py @@ -1213,21 +1213,28 @@ def load_dict(self, state): self.exons.append(tuple(exon)) self.internal_orfs = [] + self.logger.debug("Starting to load the ORFs for %s", self.id) try: for orf in iter(state["orfs"][_] for _ in sorted(state["orfs"])): neworf = [] for segment in orf: - if segment[0] == "CDS": + + # if segment[0] == "CDS": + if len(segment) == 3: + assert segment[0] == "CDS" + new_segment = (segment[0], tuple(segment[1]), int(segment[2])) self.combined_cds.append(tuple(segment[1])) else: + assert segment[0] != "CDS" new_segment = (segment[0], tuple(segment[1])) neworf.append(new_segment) self.internal_orfs.append(neworf) + self.logger.debug("ORFs: %s", " ".join(str(_) for _ in self.internal_orfs)) self.selected_internal_orf_index = state["selected_orf"] except (ValueError, IndexError): raise CorruptIndex("Invalid values for ORFs of {}".format(self.id)) diff --git a/Mikado/transcripts/transcript_methods/finalizing.py b/Mikado/transcripts/transcript_methods/finalizing.py index 39b81380a..8c2268ad8 100644 --- a/Mikado/transcripts/transcript_methods/finalizing.py +++ b/Mikado/transcripts/transcript_methods/finalizing.py @@ -3,7 +3,7 @@ e.g. reliability of the CDS/UTR, sanity of borders, etc. """ -from Mikado.utilities.intervaltree import Interval +from Mikado.utilities.intervaltree import Interval, IntervalTree import operator from sys import intern from Mikado.exceptions import InvalidCDS, InvalidTranscript @@ -108,6 +108,7 @@ def __basic_final_checks(transcript): transcript.logger.exception(exc) raise exc + def _check_cdna_vs_utr(transcript): """ @@ -122,46 +123,76 @@ def _check_cdna_vs_utr(transcript): transcript.logger.debug("%s is non coding, returning", transcript.id) return assert transcript.combined_cds != [] - transcript.logger.debug("Recalculating the UTR for %s", transcript.id) + transcript.logger.debug("Recalculating the UTR for %s. %s", transcript.id, + transcript.combined_cds) transcript.combined_utr = [] # Reset transcript.combined_cds = sorted(transcript.combined_cds, key=operator.itemgetter(0, 1)) + + combined_cds = IntervalTree.from_tuples(transcript.combined_cds) + orfs = [IntervalTree.from_tuples([_[1] for _ in orf if _[0] == "CDS"]) for orf in transcript.internal_orfs] + assert isinstance(combined_cds, IntervalTree) + for exon in transcript.exons: - assert isinstance(exon, tuple) - if exon in transcript.combined_cds: - continue - # The end of the exon is before the first ORF start - # or the start is after the last ORF segment: UTR segment - elif (exon[1] < transcript.combined_cds[0][0] or - exon[0] > transcript.combined_cds[-1][1]): + assert isinstance(exon, tuple), type(exon) + found = combined_cds.find(exon[0], exon[1]) + if len(found) == 0: + # Exon completely noncoding transcript.combined_utr.append(exon) - - # The last base of the exon is the first ORF base - elif (exon[0] < transcript.combined_cds[0][0] and - exon[1] == transcript.combined_cds[0][1]): - transcript.combined_utr.append(tuple([ - exon[0], transcript.combined_cds[0][0] - 1])) - # The first base of the exon is the first base of the last ORF segment: - # UTR after - elif (exon[1] > transcript.combined_cds[-1][1] and - exon[0] == transcript.combined_cds[-1][0]): - transcript.combined_utr.append(tuple([ - transcript.combined_cds[-1][1] + 1, exon[1]])) - else: - # If the ORF is contained inside a single exon, with UTR - # at both sites, then we create the two UTR segments - if len(transcript.combined_cds) == 1: - transcript.combined_utr.append(tuple([ - exon[0], transcript.combined_cds[0][0] - 1])) - transcript.combined_utr.append(tuple([ - transcript.combined_cds[-1][1] + 1, exon[1]])) + elif len(found) == 1: + found = found[0] + if found.start == exon[0] and found.end == exon[1]: + # The exon is completely coding + continue else: - # This means there is an INTERNAL UTR region between - # two CDS segments: something is clearly wrong! + # I have to find all the regions of the exon which are not coding + before = None + after = None + if found.start > exon[0]: + before = (exon[0], max(found.start - 1, exon[0])) + transcript.combined_utr.append(before) + if found.end < exon[1]: + after = (min(found.end + 1, exon[1]), exon[1]) + transcript.combined_utr.append(after) + + assert before or after, (exon, found) + else: + # The exon is overlapping *two* different CDS segments! This is valid *only* if there are multiple ORFs + if len(found) > len(transcript.internal_orfs): raise InvalidCDS( - "Error while inferring the UTR", - exon, transcript.id, - transcript.exons, transcript.combined_cds) + "Found in {} an exon ({}) which is overlapping with more CDS segments than there are ORFs.".format( + transcript.id, exon + )) + # Now we have to check for each internal ORF that things are OK + for orf in orfs: + orf_found = orf.find(exon[0], exon[1]) + if len(orf_found) > 1: + raise InvalidCDS( + "Found in {} an exon ({}) which is overlapping with more CDS segments in a single ORF.".format( + transcript.id, exon + )) + # If we are here, it means that the internal UTR is legit. We should now add the untranslated regions + # to the store. + transcript.logger.debug("Starting to find the UTRs for %s", exon) + found = sorted(found) + utrs = [] + for pos, interval in enumerate(found): + if pos == len(found) - 1: + if exon[1] > interval.end: + utrs.append((min(exon[1], interval.end + 1), exon[1])) + continue + if pos == 0 and exon[0] < interval.start: + utrs.append((exon[0], max(exon[0], interval.start - 1))) + next_interval = found[pos + 1] + if not (interval.end + 1 <= next_interval.start - 1): + raise InvalidCDS( + "Error while inferring the UTR for a transcript with multiple ORFs: overlapping CDS found.") + utrs.append((interval.end + 1, next_interval.start - 1)) + assert utrs, found + utr_sum = sum([_[1] - _[0] + 1 for _ in utrs]) + cds_sum = sum(_.end - _.start + 1 for _ in found) + assert utr_sum + cds_sum == exon[1] - exon[0] + 1, (utr_sum, cds_sum, exon[1] - exon[0] + 1, utrs, found) + transcript.combined_utr.extend(utrs) # If no CDS and no UTR are present, all good equality_one = (transcript.combined_cds_length == transcript.combined_utr_length == 0) @@ -351,8 +382,8 @@ def __check_internal_orf(transcript, index): exons = sorted(transcript.exons, reverse=(transcript.strand == "-")) - coding = sorted([(_[0], _[1]) for _ in orf if _[0] == "CDS"], - key=operator.itemgetter(1)) + coding = sorted([_ for _ in orf if _[0] == "CDS"], key=operator.itemgetter(1)) + transcript.logger.debug("ORF for %s: %s", transcript.id, coding) if not coding: raise InvalidCDS("No ORF for {}, index {}!".format(transcript.id, index)) @@ -404,16 +435,33 @@ def __check_internal_orf(transcript, index): del before, after - if index == 0 and transcript.phases: + phase_orf = [] + for _ in coding: + if len(_) == 3: + if _[2] not in (None, 0, 1, 2): + raise ValueError("Invalid phase value for {}".format(transcript.id)) + phase_orf.append(_[2]) + elif len(_) == 2: + continue + else: + raise ValueError("Invalid CDS fragment: {}".format(_)) + + if len(phase_orf) != 0 and len(phase_orf) != len(coding): + transcript.logger.warning("Invalid phases for %s. Resetting.", transcript.id) + phase_orf = [] + + if not phase_orf and transcript.phases: phases_keys = sorted(transcript.phases.keys(), reverse=(transcript.strand == "-")) phase_orf = [transcript.phases[_] for _ in phases_keys] # Calculating the complement of the phase so that # previous = (3 - phase_orf[0]) % 3 previous = phase_orf[0] # transcript.logger.warning(previous) - elif index == 0 and transcript._first_phase is not None: + elif not phase_orf and transcript._first_phase is not None: previous = transcript._first_phase phase_orf = [] + elif phase_orf: + previous = phase_orf[0] else: phase_orf = [] for segment in sorted(orf, key=operator.itemgetter(1), reverse=(transcript.strand == "-")): @@ -501,7 +549,17 @@ def __check_phase_correctness(transcript): transcript.segments = [("exon", tuple([e[0], e[1]])) for e in transcript.exons] # Define CDS - transcript.segments.extend([("CDS", tuple([c[0], c[1]])) + if len(transcript.internal_orfs) > 0: + for orf in transcript.internal_orfs: + for segment in orf: + if segment[0] == "exon": + continue + elif segment[0] == "UTR": + transcript.segments.append(("UTR", (segment[1][0], segment[1][1]))) + elif segment[0] == "CDS": + transcript.segments.append(("CDS", (segment[1][0], segment[1][1]))) + else: + transcript.segments.extend([("CDS", tuple([c[0], c[1]])) for c in transcript.combined_cds]) # Define UTR segments transcript.segments.extend([("UTR", tuple([u[0], u[1]])) @@ -509,7 +567,7 @@ def __check_phase_correctness(transcript): # Mix and sort transcript.segments = sorted(transcript.segments, key=operator.itemgetter(1, 0)) # Add to the store as a single entity - if any(_[0] == "CDS" for _ in transcript.segments): + if not transcript.internal_orfs and any(_[0] == "CDS" for _ in transcript.segments): transcript.internal_orfs = [transcript.segments] else: transcript.selected_internal_orf_index = None @@ -517,13 +575,15 @@ def __check_phase_correctness(transcript): exception = AssertionError("No internal ORF for {}".format(transcript.id)) transcript.logger.exception(exception) raise exception + else: + transcript.logger.debug("Segments and ORFs defined for %s", transcript.id) transcript.logger.debug("{} has {} internal ORF{}".format( transcript.id, len(transcript.internal_orfs), "s" if len(transcript.internal_orfs) > 1 else "")) for orf_index in range(len(transcript.internal_orfs)): transcript.logger.debug("ORF #%d for %s: %s", - orf_index, transcript.id, transcript.phases) + orf_index, transcript.id, transcript.internal_orfs[orf_index]) try: transcript = __check_internal_orf(transcript, orf_index) @@ -601,11 +661,12 @@ def finalize(transcript): try: _check_cdna_vs_utr(transcript) - except InvalidCDS: + except InvalidCDS as exc: if transcript.combined_cds: transcript.logger.warning( "Possible faulty UTR annotation for %s, trying to recalculate it.", transcript.id) + transcript.logger.warning(exc) transcript.combined_utr = [] try: _check_cdna_vs_utr(transcript) diff --git a/Mikado/transcripts/transcript_methods/retrieval.py b/Mikado/transcripts/transcript_methods/retrieval.py index c9b271893..b0f5ed420 100644 --- a/Mikado/transcripts/transcript_methods/retrieval.py +++ b/Mikado/transcripts/transcript_methods/retrieval.py @@ -444,6 +444,12 @@ def find_overlapping_cds(transcript, candidates: list) -> list: candidates = list(corf for corf in candidates if ( corf.invalid is False and corf.transcriptomic is True)) + ids = set(_.name for _ in candidates) + if len(ids) < len(candidates): + transcript.logger.debug("Colliding IDs found for the ORFs. Redefining them.") + for pos in range(1, len(candidates) + 1): + candidates[pos - 1].name = "{transcript.id}.orf{pos}".format(**locals()) + transcript.logger.debug("{0} filtered ORFs for {1}".format(len(candidates), transcript.id)) if len(candidates) == 0: return []