diff --git a/Mikado/loci/superlocus.py b/Mikado/loci/superlocus.py index 39d765dea..0e80f10d7 100644 --- a/Mikado/loci/superlocus.py +++ b/Mikado/loci/superlocus.py @@ -81,7 +81,7 @@ def __init__(self, """ :param transcript_instance: an instance of the Transcript class - :type transcript_instance: Transcript + :type transcript_instance: [Transcript|None] :param stranded: boolean flag that indicates whether the Locus should use or ignore strand information :type stranded: bool @@ -541,8 +541,6 @@ def _create_data_dict(self, engine, tid_keys): score = int(ext.score) elif ext.rtype == "float": score = float(ext.score) - elif ext.rtype == "complex": - score = complex(ext.score) elif ext.rtype == "bool": score = bool(int(ext.score)) else: @@ -1348,7 +1346,7 @@ def is_intersecting(cls, transcript, other, cds_only=False): transcript.finalize() other.finalize() - if transcript.id == other.id: + if transcript == other: return False # We do not want intersection with oneself if transcript.monoexonic is False and other.monoexonic is False: diff --git a/Mikado/tests/locus_tester.py b/Mikado/tests/locus_tester.py index 3ee7d57ab..2716fa37e 100644 --- a/Mikado/tests/locus_tester.py +++ b/Mikado/tests/locus_tester.py @@ -1299,6 +1299,93 @@ def test_remove_AS_overlapping(self): self.assertEqual(len(superlocus.loci[locus_two_id].transcripts), 1) +class EmptySuperlocus(unittest.TestCase): + + def test_empty(self): + + logger = create_null_logger() + logger.setLevel("WARNING") + with self.assertLogs(logger, level="WARNING"): + _ = Superlocus(transcript_instance=None) + + +class WrongSplitting(unittest.TestCase): + + def test_split(self): + + t1 = Transcript(BED12("Chr1\t100\t1000\tID=t1;coding=False\t0\t+\t100\t1000\t0\t1\t900\t0")) + t2 = Transcript(BED12("Chr1\t100\t1000\tID=t2;coding=False\t0\t-\t100\t1000\t0\t1\t900\t0")) + sl = Superlocus(t1, stranded=False) + sl.add_transcript_to_locus(t2) + splitted = list(sl.split_strands()) + self.assertEqual(len(splitted), 2) + self.assertIsInstance(splitted[0], Superlocus) + self.assertIsInstance(splitted[1], Superlocus) + self.assertTrue(splitted[0].stranded) + self.assertTrue(splitted[1].stranded) + + def test_invalid_split(self): + t1 = Transcript(BED12("Chr1\t100\t1000\tID=t1;coding=False\t0\t+\t100\t1000\t0\t1\t900\t0")) + t2 = Transcript(BED12("Chr1\t100\t1000\tID=t2;coding=False\t0\t+\t100\t1000\t0\t1\t900\t0")) + + logger = create_default_logger("test_invalid_split", level="WARNING") + with self.assertLogs(logger=logger, level="WARNING") as cm: + sl = Superlocus(t1, stranded=True, logger=logger) + sl.add_transcript_to_locus(t2) + splitted = list(sl.split_strands()) + + self.assertEqual(splitted[0], sl) + self.assertEqual(len(splitted), 1) + self.assertIn("WARNING:test_invalid_split:Trying to split by strand a stranded Locus, {}!".format(sl.id), + cm.output, cm.output) + + +class WrongLoadingAndIntersecting(unittest.TestCase): + + def test_wrong_loading(self): + t1 = Transcript(BED12("Chr1\t100\t1000\tID=t1;coding=False\t0\t+\t100\t1000\t0\t1\t900\t0")) + sl = Superlocus(t1, stranded=True) + with self.assertRaises(ValueError): + sl.load_all_transcript_data(engine=None, data_dict=None) + + @unittest.skip + def test_already_loaded(self): + t1 = Transcript(BED12("Chr1\t100\t1000\tID=t1;coding=False\t0\t+\t100\t1000\t0\t1\t900\t0")) + sl = Superlocus(t1, stranded=True) + + def test_wrong_intersecting(self): + t1 = Transcript(BED12("Chr1\t100\t1000\tID=t1;coding=False\t0\t+\t100\t1000\t0\t1\t900\t0")) + sl = Superlocus(t1, stranded=True) + + with self.subTest(): + self.assertFalse(sl.is_intersecting(t1, t1)) + t2 = Transcript(BED12("Chr1\t100\t1000\tID=t1;coding=False\t0\t-\t100\t1000\t0\t1\t900\t0")) + + with self.subTest(): + self.assertTrue(sl.is_intersecting(t1, t2)) + + def test_coding_intersecting(self): + t1 = Transcript(BED12("Chr1\t100\t1000\tID=t1;coding=True\t0\t+\t200\t500\t0\t1\t900\t0")) + sl = Superlocus(t1, stranded=True) + t2 = Transcript(BED12("Chr1\t100\t1000\tID=t2;coding=True\t0\t+\t600\t900\t0\t1\t900\t0")) + t3 = Transcript(BED12("Chr1\t100\t1000\tID=t3;coding=True\t0\t+\t300\t600\t0\t1\t900\t0")) + t1.finalize() + t2.finalize() + t3.finalize() + self.assertTrue(t1.is_coding) + self.assertTrue(t2.is_coding) + self.assertTrue(t3.is_coding) + self.assertNotEqual(t1, t2) + self.assertNotEqual(t1, t3) + + with self.subTest(): + self.assertTrue(sl.is_intersecting(t1, t2, cds_only=False)) + self.assertFalse(sl.is_intersecting(t1, t2, cds_only=True)) + with self.subTest(): + self.assertTrue(sl.is_intersecting(t1, t3, cds_only=False)) + self.assertTrue(sl.is_intersecting(t1, t3, cds_only=True)) + + class RetainedIntronTester(unittest.TestCase): def setUp(self): diff --git a/Mikado/transcripts/transcript.py b/Mikado/transcripts/transcript.py index 6b8b28a7f..2d2534808 100644 --- a/Mikado/transcripts/transcript.py +++ b/Mikado/transcripts/transcript.py @@ -489,16 +489,16 @@ def __eq__(self, other) -> bool: if not isinstance(self, type(other)): return False - # self.finalize() - # other.finalize() - if self.strand == other.strand and self.chrom == other.chrom: - if other.start == self.start: - if self.end == other.end: - if self.exons == other.exons: - return True - - return False + return all([ + self.strand == other.strand, + self.chrom == other.chrom, + self.start == other.start, + self.end == other.end, + self.exons == other.exons, + self.combined_cds == other.combined_cds, + self.internal_orfs == other.internal_orfs + ]) def __hash__(self): """Returns the hash of the object (call to super().__hash__()).