Skip to content

Commit

Permalink
fix(call): properly use process worker, log more info about worker ID
Browse files Browse the repository at this point in the history
  • Loading branch information
davidlougheed committed Jun 10, 2024
1 parent 7631c37 commit a962931
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 17 deletions.
1 change: 0 additions & 1 deletion strkit/call/call_locus.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

from .align_matrix import match_score
from .cigar import decode_cigar_np
# from .consensus import best_representative
from .consensus import consensus_seq
from .params import CallParams
from .realign import realign_read
Expand Down
32 changes: 22 additions & 10 deletions strkit/call/call_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from multiprocessing.synchronize import Event as EventClass # For type hinting
from typing import Literal, Optional

from strkit.logger import logger
from .allele import get_n_alleles
from .call_locus import call_locus
from .non_daemonic_pool import NonDaemonicPool
Expand Down Expand Up @@ -57,6 +56,7 @@ def get_vcf_contig_format(snv_vcf_contigs: list[str]) -> Literal["chr", "num", "


def locus_worker(
worker_id: int,
params: CallParams,
locus_queue: mp.Queue,
locus_counter_lock: threading.Lock,
Expand All @@ -81,7 +81,8 @@ def locus_worker(

lg: logging.Logger
if is_single_processed:
lg = logger
from strkit.logger import get_main_logger
lg = get_main_logger()
else:
from strkit.logger import create_process_logger
lg = create_process_logger(os.getpid(), params.log_level)
Expand All @@ -103,24 +104,30 @@ def locus_worker(

snv_vcf_reader = STRkitVCFReader(str(params.snv_vcf)) if params.snv_vcf else None

current_contig: Optional[str] = None
results: list[dict] = []

while True:
try:
td = locus_queue.get_nowait()
if td is None: # Kill signal
logger.debug("worker finished current contig")
lg.debug(f"worker {worker_id} finished current contig: {current_contig}")
break
except queue.Empty:
logger.debug("encountered queue.Empty")
lg.debug(f"worker {worker_id} encountered queue.Empty")
break

t_idx, t, n_alleles, locus_seed = td

if current_contig is None:
current_contig = t[0]

# String representation of locus for logging purposes
locus_log_str: str = f"{sample_id or ''}{' ' if sample_id else ''}locus {t_idx}: {t[0]}:{t[1]}-{t[2]}"
locus_log_str: str = (
f"[w{worker_id}] {sample_id or ''}{' ' if sample_id else ''}locus {t_idx}: {t[0]}:{t[1]}-{t[2]}"
)

logger.debug(f"{locus_log_str} - working on locus")
lg.debug(f"{locus_log_str} - working on locus")

try:
res = call_locus(
Expand All @@ -143,10 +150,10 @@ def locus_worker(

except Exception as e:
res = None
logger.error(
lg.error(
f"{locus_log_str} - encountered exception while genotyping ({t_idx=}, {t[:3]=}, {n_alleles=}): "
f"{repr(e)}")
logger.error(f"{locus_log_str} - {traceback.format_exc()}")
lg.error(f"{locus_log_str} - {traceback.format_exc()}")

locus_counter_lock.acquire(timeout=30)
locus_counter.set(locus_counter.get() + 1)
Expand All @@ -163,6 +170,8 @@ def locus_worker(
pr.disable()
pr.print_stats("tottime")

lg.debug(f"worker {worker_id} - returning batch of {len(results)} locus results")

# Sort worker results; we will merge them after
return results if is_single_processed else sorted(results, key=get_locus_index)

Expand Down Expand Up @@ -236,6 +245,9 @@ def call_sample(
indent_json: bool = False,
output_tsv: bool = True,
) -> None:
from strkit.logger import get_main_logger
logger = get_main_logger()

# Start the call timer
start_time = datetime.now()

Expand All @@ -254,7 +266,7 @@ def call_sample(
vf = pysam.VariantFile(vcf_path if vcf_path != "stdout" else "-", "w", header=vh)

manager: mmg.SyncManager = mp.Manager()
locus_queue = manager.Queue()
locus_queue = manager.Queue() # TODO: one queue per contig?

# Add all loci from the BED file to the queue, allowing each job
# to pull from the queue as it becomes freed up to do so.
Expand Down Expand Up @@ -357,7 +369,7 @@ def call_sample(

qsize: int = locus_queue.qsize()
while qsize > 0:
jobs = [p.apply_async(locus_worker, job_args) for _ in range(params.processes)]
jobs = [p.apply_async(locus_worker, (i + 1, *job_args)) for i in range(params.processes)]

# Write results
# - gather the process-specific results for combining
Expand Down
7 changes: 5 additions & 2 deletions strkit/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import strkit.constants as c
from strkit import __version__
from strkit.exceptions import ParamError, InputError
from strkit.logger import logger, attach_stream_handler, log_levels
from strkit.logger import get_main_logger, attach_stream_handler, log_levels


def add_call_parser_args(call_parser):
Expand Down Expand Up @@ -354,6 +354,7 @@ def add_vs_parser_args(vs_parser):

def _exec_call(p_args) -> None:
from strkit.call import call_sample, CallParams
logger = get_main_logger()
call_sample(
CallParams.from_args(logger, p_args),
json_path=p_args.json,
Expand Down Expand Up @@ -571,8 +572,10 @@ def _make_subparser(arg: str, help_text: str, exec_func: Callable, arg_func: Cal
args = args or sys.argv[1:]
p_args = parser.parse_args(args)

logger = get_main_logger()

if hasattr(p_args, "log_level"):
attach_stream_handler(log_levels[p_args.log_level])
attach_stream_handler(log_levels[p_args.log_level], logger)

if not getattr(p_args, "func", None):
p_args = parser.parse_args(("--help",))
Expand Down
11 changes: 7 additions & 4 deletions strkit/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,22 @@
import sys

__all__ = [
"logger",
"get_main_logger",
"attach_stream_handler",
"create_process_logger",
"log_levels",
]

fmt = logging.Formatter(fmt="%(name)s:\t[%(levelname)s]\t%(message)s")

logger = logging.getLogger("strkit-main")
logger.setLevel(logging.DEBUG)

def get_main_logger():
    """Return the shared top-level "strkit-main" logger, set to DEBUG level.

    ``logging.getLogger`` always hands back the same singleton for a given
    name, so repeated calls are cheap and re-applying the level is a no-op.
    Handlers (and their own level filtering) are attached separately via
    ``attach_stream_handler``.
    """
    main_logger = logging.getLogger("strkit-main")
    main_logger.setLevel(logging.DEBUG)
    return main_logger

def attach_stream_handler(level: int, logger_=logger):

def attach_stream_handler(level: int, logger_=None):
ch = logging.StreamHandler(sys.stderr)
ch.setLevel(level)
ch.setFormatter(fmt)
Expand Down

0 comments on commit a962931

Please sign in to comment.