Skip to content

Commit

Permalink
redirect tqdm to logger object
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewThe committed Dec 15, 2023
1 parent e249720 commit 4121adc
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 7 deletions.
20 changes: 13 additions & 7 deletions job_pool/job_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from tqdm import tqdm

from job_pool.tqdm_logger import TqdmToLogger

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -87,7 +89,7 @@ def __init__(
queue = multiprocessing.Queue()
queue_listener = QueueListener(queue, logger)
queue_listener.start()

self.pool = NestablePool(
processes,
worker_init,
Expand All @@ -104,13 +106,15 @@ def applyAsync(self, f, fargs, *args, **kwargs):
def checkPool(self, printProgressEvery: int = -1):
try:
outputs = list()
for res in tqdm(self.results):
tqdm_out = TqdmToLogger(logger, level=logging.INFO)
for res in tqdm(
self.results,
file=tqdm_out,
miniters=printProgressEvery,
maxinterval=float("inf"),
):
self.checkForTerminatedProcess(res)
outputs.append(res.get())
if printProgressEvery > 0 and len(outputs) % printProgressEvery == 0:
logger.info(
f' {len(outputs)} / {len(self.results)} {"%.2f" % (float(len(outputs)) / len(self.results) * 100)}%'
)
self.pool.close()
self.pool.join()
return outputs
Expand All @@ -132,7 +136,9 @@ def checkForTerminatedProcess(self, res):
start_time = time.time()
while not res.ready():
if proc := self.pool.check_for_terminated_processes():
raise AbnormalWorkerTerminationError(f"Caught abnormal exit of one of the workers: {proc}")
raise AbnormalWorkerTerminationError(
f"Caught abnormal exit of one of the workers: {proc}"
)

# wait for one second before checking exit codes again
res.wait(timeout=1)
Expand Down
26 changes: 26 additions & 0 deletions job_pool/tqdm_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import io
import logging


class TqdmToLogger(io.StringIO):
"""
Output stream for TQDM which will output to logger module instead of
the StdOut.
Copied from: https://github.com/tqdm/tqdm/issues/313
"""

logger = None
level = None
buf = ""

def __init__(self, logger, level=None):
super(TqdmToLogger, self).__init__()
self.logger = logger
self.level = level or logging.INFO

def write(self, buf):
self.buf = buf.strip("\r\n\t ")

def flush(self):
self.logger.log(self.level, self.buf)

0 comments on commit 4121adc

Please sign in to comment.