Skip to content

Commit

Permalink
Now setting 'spawn' as the default multiprocessing method *for the wh…
Browse files Browse the repository at this point in the history
…ole of Mikado*. Lots of amendments to allow it to function properly. For \#280.
  • Loading branch information
lucventurini committed Mar 30, 2020
1 parent 5d94e5f commit 8bf584b
Show file tree
Hide file tree
Showing 13 changed files with 316 additions and 172 deletions.
4 changes: 2 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ install:
- conda install -c bioconda -c conda-forge -y stringtie scallop gmap star hisat2 prodigal blast diamond transdecoder gnuplot kallisto samtools gffread
- python setup.py develop;
script:
- python setup.py test --addopts "-m slow Mikado/tests/test_system_calls.py::SerialiseChecker::test_subprocess_multi_empty_orfs"
- python setup.py test --addopts " --cov Mikado --cov-config=.coveragerc -m '(slow or not slow) and not triage'";
- pytest -m slow Mikado/tests/test_system_calls.py::SerialiseChecker::test_subprocess_multi_empty_orfs
- pytest --cov Mikado --cov-config=.coveragerc -m '(slow or not slow) and not triage';
- cd sample_data; snakemake --jobs 5 --cores 5
- cd ..;
- python -c "import Mikado; Mikado.test(label='fast')";
Expand Down
2 changes: 1 addition & 1 deletion Mikado/configuration/configuration_blueprint.json
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
},
"multiprocessing_method": {
"type": "string",
"default": "",
"default": "spawn",
"enum": ["fork", "spawn", "forkserver", ""]
},
"log_settings": {
Expand Down
38 changes: 32 additions & 6 deletions Mikado/parsers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ class Parser(metaclass=abc.ABCMeta):

def __init__(self, handle):
self.__closed = False
self._handle = self.__get_handle(handle)
self.closed = False

def __iter__(self):
return self

def __get_handle(self, handle, position=None):
if not isinstance(handle, io.IOBase):
if handle.endswith(".gz") or filetype(handle) == b"application/gzip":
opener = gzip.open
Expand All @@ -36,12 +43,9 @@ def __init__(self, handle):
handle = opener(handle, "rt")
except FileNotFoundError:
raise FileNotFoundError("File not found: {0}".format(handle))

self._handle = handle
self.closed = False

def __iter__(self):
return self
if position is not None:
handle.seek(position)
return handle

def __next__(self):
line = self._handle.readline()
Expand Down Expand Up @@ -92,6 +96,28 @@ def closed(self, *args):

self.__closed = args[0]

def __getstate__(self):
try:
position = self._handle.tell()
except:
position = None
state = dict()
state.update(self.__dict__)
state["position"] = position
if hasattr(self._handle, "filename"):
state["_handle"] = self._handle.filename
elif hasattr(self._handle, "name"):
state["_handle"] = self._handle.name
else:
raise TypeError("Unknown handle: {}".format(self._handle))
return state

def __setstate__(self, state):
position = state.get("position")
del state["position"]
self.__dict__.update(state)
self._handle = self.__get_handle(state["_handle"], position=position)


# noinspection PyPep8
from . import GFF
Expand Down
24 changes: 16 additions & 8 deletions Mikado/parsers/bed12.py
Original file line number Diff line number Diff line change
Expand Up @@ -1614,17 +1614,11 @@ def __init__(self,
self.rec_queue = rec_queue
self.return_queue = return_queue
self.logging_queue = log_queue
self.handler = logging_handlers.QueueHandler(self.logging_queue)
self.logger = logging.getLogger(self.name)
self.logger.addHandler(self.handler)
self.logger.setLevel(level)
self.logger.propagate = False
self.transcriptomic = transcriptomic
self.__max_regression = 0
self._max_regression = max_regression
self.coding = coding
self.start_adjustment = start_adjustment
# self.cache = cache

if isinstance(fasta_index, dict):
# check that this is a bona fide dictionary ...
Expand All @@ -1640,7 +1634,7 @@ def __init__(self,
fasta_index = pyfaidx.Fasta(fasta_index)
else:
assert isinstance(fasta_index, pysam.FastaFile), type(fasta_index)

self._level = level
self.fasta_index = fasta_index
self.__closed = False
self.header = False
Expand Down Expand Up @@ -1685,7 +1679,21 @@ def gff_next(self, line, sequence):
return bed12

def run(self, *args, **kwargs):
while True:
print("Started", self.__identifier)
self.handler = logging_handlers.QueueHandler(self.logging_queue)
self.logger = logging.getLogger(self.name)
self.logger.addHandler(self.handler)
self.logger.setLevel(self._level)
self.logger.propagate = False

self.logger.info("Started %s", self.__identifier)
if self.rec_queue is None:
self.return_queue.put(b"FINISHED")
raise ValueError
if self.rec_queue.empty():
self.return_queue.put(b"FINISHED")
raise ValueError
while True: # not self.rec_queue.empty():
line = self.rec_queue.get()
if line in ("EXIT", b"EXIT"):
self.rec_queue.put(b"EXIT")
Expand Down
85 changes: 58 additions & 27 deletions Mikado/preparation/checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import multiprocessing
import multiprocessing.queues
import os
import sqlalchemy as sqla
import pysam
import msgpack
from Mikado.transcripts.transcriptchecker import TranscriptChecker
from .. import exceptions
from ..loci import Transcript
Expand All @@ -12,6 +14,8 @@
import time
import sys
import numpy
import rapidjson as json
import operator


__author__ = 'Luca Venturini'
Expand Down Expand Up @@ -136,10 +140,11 @@ def create_transcript(lines,
class CheckingProcess(multiprocessing.Process):

def __init__(self,
submission_queue,
batch_file,
logging_queue,
fasta,
identifier,
shelve_stacks,
fasta_out,
gtf_out,
tmpdir,
Expand Down Expand Up @@ -172,11 +177,12 @@ def __init__(self,
create_queue_logger(self)
except AttributeError as exc:
raise AttributeError(exc)
if batch_file is None:
raise ValueError("Invalid Batch file!")
self.__lenient = False
self.lenient = lenient
self.__fasta = fasta
self.__submission_queue = None
self.__set_submission_queue(submission_queue)
self.fasta = pysam.FastaFile(self.__fasta)
self.fasta_out = os.path.join(tmpdir, "{0}-{1}".format(
fasta_out, self.identifier
Expand All @@ -185,9 +191,28 @@ def __init__(self,
gtf_out, self.identifier
))
self.force_keep_cds = force_keep_cds
self.shelve_stacks = shelve_stacks
self.batch_file = batch_file

def _get_stacks(self):
shelve_stacks = dict()
for shelf in self.shelve_stacks:
shelve_stacks[shelf] = dict()
conn_string = "file:{shelf}?mode=ro&nolock=1".format(shelf=shelf)
engine = sqla.engine.create_engine(
"sqlite:///{shelf}".format(shelf=shelf),
connect_args={"check_same_thread": False, "uri": True}
)
shelve_stacks[shelf] = {"conn": engine, "conn_string": conn_string}

return shelve_stacks

def _get_keys(self):
keys = msgpack.loads(open(self.batch_file, "rb").read(), raw=False, strict_map_key=False)
keys = sorted(keys, key=operator.itemgetter(0))
return keys

def run(self):

checker = functools.partial(create_transcript,
# lenient=self.lenient,
force_keep_cds=self.force_keep_cds,
Expand All @@ -202,16 +227,23 @@ def run(self):
self.logger.debug(self.canonical)

__printed = 0
shelve_stacks = self._get_stacks()
file_keys = self._get_keys()

try:
while True:
lines, start, end, counter = self.submission_queue.get()
if lines == "EXIT":
self.logger.debug("Finished for %s", self.name)
self.submission_queue.put((lines,
start,
end,
counter))
break
for key in file_keys:
counter, keys = key
# lines, start, end, counter = self.submission_queue.get()
tid, chrom, (pos) = keys
tid, shelf_name = tid
start, end = pos
item = shelve_stacks[shelf_name]["conn"].execute(
"SELECT features FROM dump WHERE tid = ?", (tid,)).fetchone()

try:
lines = json.loads(item[0])
except TypeError:
raise TypeError(item)
self.logger.debug("Checking %s", lines["tid"])
if "is_reference" not in lines:
raise KeyError(lines)
Expand All @@ -238,7 +270,6 @@ def run(self):
raise KeyboardInterrupt
except Exception as exc:
self.logger.error(exc)
self.submission_queue.close()
self.logging_queue.close()
raise

Expand Down Expand Up @@ -296,20 +327,20 @@ def lenient(self, lenient):
raise ValueError("Invalid lenient value: {}".format(lenient))
self.__lenient = lenient

@property
def submission_queue(self):
return self.__submission_queue

def __set_submission_queue(self, submission):
if isinstance(submission, multiprocessing.queues.SimpleQueue):
if sys.version_info.minor < 6:
raise TypeError("Invalid queue object for Python 3.5 and earlier!")
submission.put_nowait = submission.put
elif not isinstance(submission, (multiprocessing.queues.Queue,
queue.Queue)):
raise TypeError("Invalid queue object: {}".format(type(submission)))
self.__submission_queue = submission

# @property
# def submission_queue(self):
# return self.__submission_queue
#
# def __set_submission_queue(self, submission):
# if isinstance(submission, multiprocessing.queues.SimpleQueue):
# if sys.version_info.minor < 6:
# raise TypeError("Invalid queue object for Python 3.5 and earlier!")
# submission.put_nowait = submission.put
# elif not isinstance(submission, (multiprocessing.queues.Queue,
# queue.Queue)):
# raise TypeError("Invalid queue object: {}".format(type(submission)))
# self.__submission_queue = submission
#
@property
def logging_queue(self):
return self.__logging_queue
Expand Down
64 changes: 37 additions & 27 deletions Mikado/preparation/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import logging.handlers
import numpy
import functools
import msgpack
import multiprocessing
import multiprocessing.connection
import multiprocessing.sharedctypes
Expand Down Expand Up @@ -200,32 +201,37 @@ def perform_check(keys, shelve_stacks, args, logger):
else:
# pylint: disable=no-member

submission_queue = multiprocessing.JoinableQueue(-1)

working_processes = [CheckingProcess(
submission_queue,
args.logging_queue,
args.json_conf["reference"]["genome"].filename,
_ + 1,
os.path.basename(args.json_conf["prepare"]["files"]["out_fasta"].name),
os.path.basename(args.json_conf["prepare"]["files"]["out"].name),
args.tempdir.name,
seed=args.json_conf["seed"],
lenient=args.json_conf["prepare"]["lenient"],
canonical_splices=args.json_conf["prepare"]["canonical"],
force_keep_cds=not args.json_conf["prepare"]["strip_cds"],
log_level=args.level) for _ in range(args.json_conf["threads"])]

[_.start() for _ in working_processes]

for counter, keys in enumerate(keys):
tid, chrom, (pos) = keys
tid, shelf_name = tid
tobj = json.loads(next(shelve_stacks[shelf_name]["cursor"].execute(
"SELECT features FROM dump WHERE tid = ?", (tid,)))[0])
submission_queue.put((tobj, pos[0], pos[1], counter + 1))

submission_queue.put(tuple(["EXIT"]*4))
# submission_queue = multiprocessing.JoinableQueue(-1)

batches = list(enumerate(keys, 1))
np.random.shuffle(batches)
kwargs = {
"fasta_out": os.path.basename(args.json_conf["prepare"]["files"]["out_fasta"].name),
"gtf_out": os.path.basename(args.json_conf["prepare"]["files"]["out"].name),
"tmpdir": args.tempdir.name,
"seed": args.json_conf["seed"],
"lenient": args.json_conf["prepare"]["lenient"],
"canonical_splices": args.json_conf["prepare"]["canonical"],
"force_keep_cds": not args.json_conf["prepare"]["strip_cds"],
"log_level": args.level
}

working_processes = []
for idx, batch in enumerate(np.array_split(batches, args.json_conf["threads"]), 1):
batch_file = tempfile.NamedTemporaryFile(delete=False, mode="wb")
msgpack.dump(batch.tolist(), batch_file)
batch_file.flush()
batch_file.close()

proc = CheckingProcess(
batch_file.name,
args.logging_queue,
args.json_conf["reference"]["genome"].filename,
idx,
shelve_stacks.keys(),
**kwargs)
proc.start()
working_processes.append(proc)

[_.join() for _ in working_processes]

Expand Down Expand Up @@ -507,8 +513,12 @@ def prepare(args, logger):
for shelf, score, is_reference in zip(shelve_names, shelve_source_scores,
args.json_conf["prepare"]["files"]["reference"]):
assert isinstance(is_reference, bool)
conn = sqlite3.connect(shelf)
conn_string = "file:{shelf}?mode=ro&immutable=1".format(shelf=shelf)
conn = sqlite3.connect(conn_string, uri=True,
isolation_level="DEFERRED",
check_same_thread=False)
shelf_stacks[shelf] = {"conn": conn, "cursor": conn.cursor(), "score": score,
"conn_string": conn_string,
"is_reference": is_reference}
# shelf_stacks = dict((_, shelve.open(_, flag="r")) for _ in shelve_names)
except Exception as exc:
Expand Down
1 change: 1 addition & 0 deletions Mikado/scales/assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def __init__(self,
load_ref=self.printout_tmap)
self.self_analysis = self.stat_calculator.self_analysis
self.__merged = False

if results:
self.load_from_results(results)
self.__merged = True
Expand Down
Loading

0 comments on commit 8bf584b

Please sign in to comment.