Skip to content

Commit

Permalink
perf(call): lower mem with specific importing + early deleting
Browse files Browse the repository at this point in the history
  • Loading branch information
davidlougheed committed Sep 28, 2024
1 parent 942737e commit ff43b90
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 26 deletions.
18 changes: 14 additions & 4 deletions strkit/call/call_locus.py
Original file line number Diff line number Diff line change
Expand Up @@ -1357,6 +1357,10 @@ def get_read_length_partition_mean(p_idx: int) -> float:
read_q_coords[rn] = q_coords
read_r_coords[rn] = r_coords

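# Drop the local names for these large numpy arrays now that they are stashed in read_q_coords/read_r_coords; the
# arrays themselves stay alive through those dicts, so this only stops the local names from pinning them as well.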
del q_coords
del r_coords

# End of first read loop -------------------------------------------------------------------------------------------

n_reads_in_dict: int = len(read_dict)
Expand Down Expand Up @@ -1439,14 +1443,13 @@ def get_read_length_partition_mean(p_idx: int) -> float:
else:
# LIMITATION: Currently can only use SNVs for haplotyping with haploid/diploid

# Second read loop occurs in this function
ref_cache: str = ref.fetch(contig, left_most_coord, right_most_coord + 1).upper()

useful_snvs: list[tuple[int, int]] = process_read_snvs_for_locus_and_calculate_useful_snvs(
left_coord_adj,
right_coord_adj,
left_most_coord,
ref_cache,
# Reference sequence - passed inline rather than bound to a local variable, so that this large string is not
# kept alive by a lingering name after the call returns.
ref.fetch(contig, left_most_coord, right_most_coord + 1).upper(),
read_dict_extra,
read_q_coords,
read_r_coords,
Expand Down Expand Up @@ -1489,6 +1492,10 @@ def get_read_length_partition_mean(p_idx: int) -> float:
call_data = call_res[0] # Call data dictionary
locus_result["snvs"] = call_res[1] # Called useful SNVs

# We're done with read_q_coords and read_r_coords - free them early
del read_q_coords
del read_r_coords

single_or_dist_assign: bool = assign_method in ("single", "dist")

if single_or_dist_assign: # Didn't use SNVs, so call the 'old-fashioned' way - using only copy number
Expand Down Expand Up @@ -1611,6 +1618,9 @@ def _consensi_for_key(k: Literal["_tr_seq", "_start_anchor"]):
call_seqs.extend(_consensi_for_key("_tr_seq"))
call_anchor_seqs.extend(_consensi_for_key("_start_anchor"))

# We're done with read_dict_extra - free it early
del read_dict_extra

peak_data = {
"means": call_peaks,
"weights": call_weights,
Expand Down
44 changes: 22 additions & 22 deletions strkit/call/call_sample.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
from __future__ import annotations

import heapq
import logging
import multiprocessing as mp
import multiprocessing.dummy as mpd
import multiprocessing.managers as mmg
import numpy as np
import os
import pysam
import queue
import re
import threading
import time
import traceback

from operator import itemgetter
from heapq import merge as heapq_merge
from multiprocessing.synchronize import Event as EventClass # For type hinting
from numpy.random import Generator as NPRandomGenerator, default_rng as np_default_rng
from operator import itemgetter
from pysam import VariantFile as PySamVariantFile
from queue import Empty as QueueEmpty
from threading import Lock
from typing import Iterable, Literal, Optional

from .allele import get_n_alleles
Expand Down Expand Up @@ -63,13 +61,13 @@ def locus_worker(
worker_id: int,
params: CallParams,
locus_queue: mp.Queue,
locus_counter_lock: threading.Lock,
locus_counter_lock: Lock,
locus_counter: mmg.ValueProxy,
phase_set_lock: threading.Lock,
phase_set_lock: Lock,
phase_set_counter: mmg.ValueProxy,
phase_set_remap: mmg.DictProxy,
phase_set_synonymous: mmg.DictProxy,
snv_genotype_update_lock: threading.Lock,
snv_genotype_update_lock: Lock,
snv_genotype_cache: mmg.DictProxy,
is_single_processed: bool,
) -> list[LocusResult]:
Expand All @@ -80,7 +78,8 @@ def locus_worker(
else:
pr = None

import pysam as p
from os import getpid
from pysam import FastaFile, VariantFile
from strkit_rust_ext import STRkitBAMReader, STRkitVCFReader

lg: logging.Logger
Expand All @@ -89,16 +88,16 @@ def locus_worker(
lg = get_main_logger()
else:
from strkit.logger import create_process_logger
lg = create_process_logger(os.getpid(), params.log_level)
lg = create_process_logger(getpid(), params.log_level)

sample_id = params.sample_id

ref = p.FastaFile(params.reference_file)
ref = FastaFile(params.reference_file)
bf = STRkitBAMReader(params.read_file, params.reference_file)

snv_vcf_contigs: list[str] = []
if params.snv_vcf:
with p.VariantFile(params.snv_vcf) as snv_vcf_file:
with VariantFile(params.snv_vcf) as snv_vcf_file:
snv_vcf_contigs.extend(map(lambda c: c.name, snv_vcf_file.header.contigs.values()))

vcf_file_format: Literal["chr", "num", "acc", ""] = get_vcf_contig_format(snv_vcf_contigs)
Expand All @@ -117,7 +116,7 @@ def locus_worker(
if td is None: # Kill signal
lg.debug(f"worker %d finished current contig: %s", worker_id, current_contig)
break
except queue.Empty:
except QueueEmpty:
lg.debug(f"worker %d encountered queue.Empty", worker_id)
break

Expand Down Expand Up @@ -165,6 +164,7 @@ def locus_worker(
)

except Exception as e:
import traceback
res = None
lg.error(f"{locus_log_str} - encountered exception while genotyping ({t_idx=}, {n_alleles=}): {repr(e)}")
lg.error(f"{locus_log_str} - {traceback.format_exc()}")
Expand Down Expand Up @@ -199,14 +199,14 @@ def progress_worker(
num_loci: int,
event: EventClass,
):
import os
from os import nice as os_nice, getpid
try:
os.nice(20)
os_nice(20)
except (AttributeError, OSError):
pass

from strkit.logger import create_process_logger
lg = create_process_logger(os.getpid(), log_level)
lg = create_process_logger(getpid(), log_level)

def _log():
try:
Expand Down Expand Up @@ -275,7 +275,7 @@ def call_sample(
f"HP={params.use_hp}, SNVs={params.snv_vcf is not None}; seed={params.seed}")

# Seed the random number generator if a seed is provided, for replicability
rng: np.random.Generator = np.random.default_rng(seed=params.seed)
rng: NPRandomGenerator = np_default_rng(seed=params.seed)

manager: mmg.SyncManager = mp.Manager()
locus_queue = manager.Queue() # TODO: one queue per contig?
Expand Down Expand Up @@ -328,10 +328,10 @@ def call_sample(

# If we're outputting a VCF, open the file and write the header
sample_id_str = params.sample_id or "sample"
vf: Optional[pysam.VariantFile] = None
vf: Optional[PySamVariantFile] = None
if vcf_path is not None:
vh = build_vcf_header(sample_id_str, params.reference_file)
vf = pysam.VariantFile(vcf_path if vcf_path != "stdout" else "-", "w", header=vh)
vf = PySamVariantFile(vcf_path if vcf_path != "stdout" else "-", "w", header=vh)

# ---

Expand Down

0 comments on commit ff43b90

Please sign in to comment.