Changed return type of multithread_exec to iterator #1019

Merged · 3 commits · Aug 25, 2022
2 changes: 1 addition & 1 deletion doctr/datasets/loader.py
@@ -90,7 +90,7 @@ def __next__(self):
         idx = self._num_yielded * self.batch_size
         indices = self.indices[idx: min(len(self.dataset), idx + self.batch_size)]
 
-        samples = multithread_exec(self.dataset.__getitem__, indices, threads=self.num_workers)
+        samples = list(multithread_exec(self.dataset.__getitem__, indices, threads=self.num_workers))
 
         batch_data = self.collate_fn(samples)
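Since `multithread_exec` is now lazy, call sites that need a concrete sequence must materialize it. A minimal sketch of the failure mode this avoids (illustrative only, not doctr code): collate functions typically need `len()` or multiple passes over the samples, which a one-shot iterator cannot provide.

```python
samples = map(str.upper, ["a", "b"])  # lazy iterator, like multithread_exec's new return value

# len(samples) would raise TypeError: object of type 'map' has no len()

samples = list(samples)  # materialize once, before collating
assert len(samples) == 2
assert samples == ["A", "B"]
```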
6 changes: 3 additions & 3 deletions doctr/models/preprocessor/pytorch.py
@@ -117,13 +117,13 @@ def __call__(
 
         elif isinstance(x, list) and all(isinstance(sample, (np.ndarray, torch.Tensor)) for sample in x):
             # Sample transform (to tensor, resize)
-            samples = multithread_exec(self.sample_transforms, x)
+            samples = list(multithread_exec(self.sample_transforms, x))
             # Batching
-            batches = self.batch_inputs(samples)  # type: ignore[arg-type]
+            batches = self.batch_inputs(samples)
         else:
             raise TypeError(f"invalid input type: {type(x)}")
 
         # Batch transforms (normalize)
-        batches = multithread_exec(self.normalize, batches)  # type: ignore[assignment]
+        batches = list(multithread_exec(self.normalize, batches))
 
         return batches
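Materializing with `list(...)` is also what lets the two `# type: ignore` escapes be dropped: mypy now sees a concrete `List` rather than an `Iterable`/`Iterator`. A minimal illustration with hypothetical names (not doctr's API):

```python
from typing import Iterator, List

def exec_lazy(xs: List[int]) -> Iterator[int]:  # stands in for multithread_exec
    return iter(xs)

def batch(samples: List[int]) -> List[List[int]]:  # stands in for batch_inputs
    return [samples]

# batch(exec_lazy([1, 2]))                 # mypy: incompatible type "Iterator[int]"
batches = batch(list(exec_lazy([1, 2])))   # OK: list() produces a concrete List[int]
```

The same change is applied to the TensorFlow preprocessor below.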
6 changes: 3 additions & 3 deletions doctr/models/preprocessor/tensorflow.py
@@ -115,13 +115,13 @@ def __call__(
 
         elif isinstance(x, list) and all(isinstance(sample, (np.ndarray, tf.Tensor)) for sample in x):
             # Sample transform (to tensor, resize)
-            samples = multithread_exec(self.sample_transforms, x)
+            samples = list(multithread_exec(self.sample_transforms, x))
             # Batching
-            batches = self.batch_inputs(samples)  # type: ignore[arg-type]
+            batches = self.batch_inputs(samples)
         else:
             raise TypeError(f"invalid input type: {type(x)}")
 
         # Batch transforms (normalize)
-        batches = multithread_exec(self.normalize, batches)  # type: ignore[assignment]
+        batches = list(multithread_exec(self.normalize, batches))
 
         return batches
10 changes: 6 additions & 4 deletions doctr/utils/multithreading.py
@@ -6,12 +6,12 @@
 
 import multiprocessing as mp
 from multiprocessing.pool import ThreadPool
-from typing import Any, Callable, Iterable, Optional
+from typing import Any, Callable, Iterable, Iterator, Optional
 
 __all__ = ['multithread_exec']
 
 
-def multithread_exec(func: Callable[[Any], Any], seq: Iterable[Any], threads: Optional[int] = None) -> Iterable[Any]:
+def multithread_exec(func: Callable[[Any], Any], seq: Iterable[Any], threads: Optional[int] = None) -> Iterator[Any]:
     """Execute a given function in parallel for each element of a given sequence
 
     >>> from doctr.utils.multithreading import multithread_exec
@@ -24,7 +24,7 @@ def multithread_exec(func: Callable[[Any], Any], seq: Iterable[Any], threads: Op
         threads: number of workers to be used for multiprocessing
 
     Returns:
-        iterable of the function's results using the iterable as inputs
+        iterator of the function's results using the iterable as inputs
     """
 
     threads = threads if isinstance(threads, int) else min(16, mp.cpu_count())
@@ -34,5 +34,7 @@ def multithread_exec(func: Callable[[Any], Any], seq: Iterable[Any], threads: Op
     # Multi-threading
     else:
         with ThreadPool(threads) as tp:
-            results = tp.map(func, seq)  # type: ignore[assignment]
+            # ThreadPool's map returns a list, but seq could be of a different type,
+            # so the result is wrapped in map() to return an iterator
+            results = map(lambda x: x, tp.map(func, seq))
     return results
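To make the new contract concrete, here is a short usage sketch (assuming a doctr build that includes this PR): the function now returns a single-pass iterator in both the sequential and thread-pool branches, so callers must materialize the results if they need them more than once.

```python
from doctr.utils.multithreading import multithread_exec

results = multithread_exec(lambda x: x ** 2, [1, 2, 3])

assert not isinstance(results, list)  # lazy map object, no longer a list
assert list(results) == [1, 4, 9]     # first pass consumes the iterator
assert list(results) == []            # second pass is empty: single-pass semantics
```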
2 changes: 1 addition & 1 deletion tests/common/test_utils_multithreading.py
@@ -16,5 +16,5 @@
     ],
 )
 def test_multithread_exec(input_seq, func, output_seq):
-    assert multithread_exec(func, input_seq) == output_seq
+    assert list(multithread_exec(func, input_seq)) == output_seq
     assert list(multithread_exec(func, input_seq, 0)) == output_seq