Using msgpack with disk dumping to ferry in a single go to children processes (#280)
lucventurini committed Mar 28, 2020
1 parent 7d5e843 commit 0f37167
Showing 1 changed file with 8 additions and 20 deletions.
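
The change drops the shared numpy memmap and instead ferries each worker's slice of the BLAST table through a single msgpack-encoded temporary file: the parent dumps a self-contained list of (key, rows) payloads per child process, and each child reads its file back in one go. A minimal sketch of the pattern, with hypothetical names rather than the actual Mikado API:

    import msgpack
    import multiprocessing as mp
    import os
    import tempfile

    def worker(payload_file):
        # Child process: load the whole work unit in a single read, then clean up.
        with open(payload_file, "rb") as handle:
            pairs = msgpack.loads(handle.read(), raw=False, strict_map_key=False)
        os.remove(payload_file)
        for key, rows in pairs:
            print(key, rows)  # stand-in for the real per-hit serialisation

    if __name__ == "__main__":
        # Parent process: dump the chunk to disk once, pass only the file name.
        chunk = [(["q1", "t1"], [[100.0, 1, 50]]), (["q2", "t1"], [[98.5, 2, 40]])]
        name = tempfile.mktemp(suffix=".mgp")
        with open(name, "wb") as handle:
            handle.write(msgpack.dumps(chunk))
        proc = mp.Process(target=worker, args=(name,))
        proc.start()
        proc.join()
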
28 changes: 8 additions & 20 deletions Mikado/serializers/blast_serializer/tabular_utils.py
@@ -271,7 +271,6 @@ def __init__(self,
                  index_file: str,
                  identifier: int,
                  params_file: str,
-                 mapname: str,
                  lock: mp.RLock,
                  conf: dict,
                  maxobjects: int,
@@ -288,7 +287,7 @@
         self.lock = lock
         self.maxobjects = maxobjects
         self.conf = conf
-        self.params_file, self.mapname, self.index_file = params_file, mapname, index_file
+        self.params_file, self.index_file = params_file, index_file
         if logging_queue is None:
             self.logger = create_null_logger("preparer-{}".format(self.identifier))  # create_null_logger
             self.logging_queue = None
@@ -298,12 +297,10 @@ def __init__(self):
     def run(self):
         with open(self.params_file, "rb") as pfile:
             params = msgpack.loads(pfile.read(), raw=False, strict_map_key=False)
-        dtype = np.dtype(params["dtype"])
-        shape = tuple(params["shape"])
         self.columns = params["columns"]
-        self.values = np.memmap(self.mapname, dtype=dtype, mode="r", shape=shape)
         with open(self.index_file, "rb") as index_handle:
             self.indexes = msgpack.loads(index_handle.read(), raw=False, strict_map_key=False)
+        os.remove(self.index_file)  # Clean it up
         prep_hit = partial(prepare_tab_hit,
                            columns=self.columns, qmult=self.qmult, tmult=self.tmult,
                            matrix_name=self.matrix_name)
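
For reference on the two msgpack.loads flags used above: raw=False decodes msgpack strings to Python str rather than bytes, and strict_map_key=False lifts msgpack-python's restriction (the default since its 1.0 release) that map keys must be str or bytes. A round trip independent of Mikado:

    import msgpack

    packed = msgpack.dumps({1: "one", 2: "two"})
    # Integer map keys are rejected under the default strict_map_key=True.
    data = msgpack.loads(packed, raw=False, strict_map_key=False)
    print(data)  # {1: 'one', 2: 'two'}
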
@@ -319,12 +316,7 @@ def run(self):
         session = Session(bind=self.engine)
         self.session = session
         hits, hsps = [], []
-        for index in self.indexes:
-            try:
-                key, group = index
-            except ValueError:
-                raise ValueError(index)
-            rows = self.values[group, :]
+        for key, rows in self.indexes:
             curr_hit, curr_hsps = prep_hit(key, rows)
             hits.append(curr_hit)
             hsps += curr_hsps
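
The simplified loop works because msgpack has no tuple type: the (key, rows) tuples written by the parent come back as two-element lists, which unpack identically. A quick illustration:

    import msgpack

    pairs = [(("q1", "t1"), [[1.0, 2.0]]), (("q2", "t1"), [[3.0, 4.0]])]
    decoded = msgpack.loads(msgpack.dumps(pairs), raw=False)
    for key, rows in decoded:
        print(key, rows)  # key round-trips as a list, e.g. ['q1', 't1']
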
@@ -352,13 +344,12 @@ def parse_tab_blast(self,
     lock = mp.RLock()
     conf = dict()
     conf["db_settings"] = self.json_conf["db_settings"].copy()
-    mapped_name = tempfile.mktemp(suffix=".mmap")
     params_file = tempfile.mktemp(suffix=".mgp")
 
     index_files = dict((idx, tempfile.mktemp(suffix=".csv")) for idx in
                        range(procs))
     kwargs = {"conf": conf,
               "maxobjects": int(self.maxobjects),
-              "mapname": mapped_name,
               "lock": lock,
               "matrix_name": matrix_name,
               "qmult": qmult,
@@ -392,24 +383,21 @@ def parse_tab_blast(self,
             assert len(hsps) >= len(hits), (len(hits), len(hsps), hits[0], hsps[0])
         else:
             self.logger.info("Finished reading %s data, starting serialisation with %d processors", bname, procs)
-            mapped = np.memmap(mapped_name, shape=values.shape, dtype=values.dtype, mode="w+")
-            mapped[:] = values[:]  # Copy the data in the memory view
             # Now we have to write down everything inside the temporary files.
-            # Params_name must contain: shape of the array, dtype of the array, columns
-            params = {"shape": values.shape, "dtype": str(values.dtype), "columns": columns}
+            params = {"columns": columns}
             with open(params_file, "wb") as pfile:
                 pfile.write(msgpack.dumps(params))
             assert os.path.exists(params_file)
             # Split the indices
             for idx, split in enumerate(np.array_split(np.array(list(groups.items())), procs)):
                 with open(index_files[idx], "wb") as index:
-                    index.write(msgpack.dumps(split.tolist()))
+                    vals = [(tuple(item[0]), values[item[1], :].tolist()) for item in split]
+                    index.write(msgpack.dumps(vals))
                 assert os.path.exists(index_files[idx])
                 processes[idx].start()
 
             [proc.start() for proc in processes]
             [proc.join() for proc in processes]
-            os.remove(mapped_name)
             os.remove(params_file)
             [os.remove(index_files[idx]) for idx in index_files]
 
             return
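
On the work split itself: np.array_split divides the (key, row-indices) pairs into procs chunks even when the division is uneven, where np.split would raise. A standalone illustration with made-up groups and values:

    import numpy as np

    # Hypothetical stand-in for `groups`: hit key -> row indices into `values`.
    groups = {("q1", "t1"): [0, 1], ("q2", "t1"): [2], ("q3", "t2"): [3, 4]}
    values = np.arange(15.0).reshape(5, 3)
    procs = 2

    items = np.array(list(groups.items()), dtype=object)  # shape (3, 2), object dtype
    for idx, split in enumerate(np.array_split(items, procs)):
        # Mirror the commit: bundle each key with its rows, ready for msgpack.
        vals = [(tuple(item[0]), values[item[1], :].tolist()) for item in split]
        print(idx, vals)
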
