Skip to content

Commit

Permalink
Multithread batch sampler for PatchInferer (#6139)
Browse files Browse the repository at this point in the history
Fixes #6138 

### Description

This PR makes the batch sampling of `PatchInferer` multi-threaded so
that WSI inference (using WSI splitters) can leverage it for better
performance.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.

---------

Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
  • Loading branch information
drbeh authored Jul 15, 2023
1 parent 3b6b11a commit dc1bc77
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
16 changes: 15 additions & 1 deletion monai/inferers/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from monai.apps.utils import get_logger
from monai.data.meta_tensor import MetaTensor
from monai.data.thread_buffer import ThreadBuffer
from monai.inferers.merger import AvgMerger, Merger
from monai.inferers.splitter import Splitter
from monai.inferers.utils import compute_importance_map, sliding_window_inference
Expand Down Expand Up @@ -103,6 +104,7 @@ class PatchInferer(Inferer):
the output dictionary to be used for merging.
Defaults to None, where all the keys are used.
match_spatial_shape: whether to crop the output to match the input shape. Defaults to True.
buffer_size: number of patches to be held in the buffer with a separate thread for batch sampling. Defaults to 0.
merger_kwargs: arguments to be passed to `merger_cls` for instantiation.
`merged_shape` is calculated automatically based on the input shape and
the output patch shape unless it is passed here.
Expand All @@ -117,6 +119,7 @@ def __init__(
postprocessing: Callable | None = None,
output_keys: Sequence | None = None,
match_spatial_shape: bool = True,
buffer_size: int = 0,
**merger_kwargs: Any,
) -> None:
Inferer.__init__(self)
Expand Down Expand Up @@ -157,6 +160,8 @@ def __init__(
self.postprocessing = postprocessing

# batch size for patches
if batch_size < 1:
raise ValueError(f"`batch_size` must be a positive number, {batch_size} is given.")
self.batch_size = batch_size

# model output keys
Expand All @@ -165,6 +170,9 @@ def __init__(
# whether to crop the output to match the input shape
self.match_spatial_shape = match_spatial_shape

# buffer size for multithreaded batch sampling
self.buffer_size = buffer_size

def _batch_sampler(
self, patches: Iterable[tuple[torch.Tensor, Sequence[int]]] | MetaTensor
) -> Iterator[tuple[torch.Tensor, Sequence, int]]:
Expand All @@ -182,10 +190,16 @@ def _batch_sampler(
batch_size = min(self.batch_size, total_size - i)
yield patches[i : i + batch_size], patches[i : i + batch_size].meta[PatchKeys.LOCATION], batch_size # type: ignore
else:
buffer: Iterable | ThreadBuffer
if self.buffer_size > 0:
# Use multi-threading to sample patches with a buffer
buffer = ThreadBuffer(patches, buffer_size=self.buffer_size, timeout=0.1)
else:
buffer = patches
patch_batch: list[Any] = [None] * self.batch_size
location_batch: list[Any] = [None] * self.batch_size
idx_in_batch = 0
for sample in patches:
for sample in buffer:
patch_batch[idx_in_batch] = sample[0]
location_batch[idx_in_batch] = sample[1]
idx_in_batch += 1
Expand Down
20 changes: 20 additions & 0 deletions tests/test_patch_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@
TENSOR_4x4,
]


# non-divisible patch_size leading to larger image (without matching spatial shape)
TEST_CASE_11_PADDING = [
TENSOR_4x4,
Expand Down Expand Up @@ -155,6 +156,23 @@
TENSOR_4x4,
]

# multi-threaded patch sampling: buffer_size > 0 enables the ThreadBuffer path
TEST_CASE_14_MULTITHREAD_BUFFER = [
    TENSOR_4x4,  # input image
    {
        "splitter": SlidingWindowSplitter(patch_size=(2, 2)),
        "merger_cls": AvgMerger,
        "buffer_size": 2,
    },
    lambda x: x,  # identity network
    TENSOR_4x4,  # expected merged output
]

# multi-threaded patch sampling combined with batching (buffer_size == batch_size == 4)
# NOTE(review): "MULTITHREADD" in the identifier is a typo; it is preserved here
# because the parameterized test list refers to this exact name.
TEST_CASE_15_MULTITHREADD_BUFFER = [
    TENSOR_4x4,  # input image
    {
        "splitter": SlidingWindowSplitter(patch_size=(2, 2)),
        "merger_cls": AvgMerger,
        "buffer_size": 4,
        "batch_size": 4,
    },
    lambda x: x,  # identity network
    TENSOR_4x4,  # expected merged output
]


# list of tensor output
TEST_CASE_0_LIST_TENSOR = [
TENSOR_4x4,
Expand Down Expand Up @@ -245,6 +263,8 @@ class PatchInfererTests(unittest.TestCase):
TEST_CASE_11_PADDING,
TEST_CASE_12_MATCHING,
TEST_CASE_13_PADDING_MATCHING,
TEST_CASE_14_MULTITHREAD_BUFFER,
TEST_CASE_15_MULTITHREADD_BUFFER,
]
)
def test_patch_inferer_tensor(self, inputs, arguments, network, expected):
Expand Down

0 comments on commit dc1bc77

Please sign in to comment.