Queue refactor #462

Merged (4 commits, Nov 14, 2024)
Changes from all commits
libs/infinity_emb/infinity_emb/engine.py (2 changes: 1 addition & 1 deletion)
@@ -88,7 +88,7 @@ async def astart(self):
         self._batch_handler = BatchHandler(
             max_batch_size=self._engine_args.batch_size,
             model_replicas=self._model_replicas,
-            batch_delay=self._min_inference_t / 2,
+            # batch_delay=self._min_inference_t / 2,
             vector_disk_cache_path=self._engine_args.vector_disk_cache_path,
Review comment (Contributor), on lines 88 to 92:
logic: Removing batch_delay could lead to aggressive batching and potential resource exhaustion. Consider adding a configurable minimum delay or documenting why this was removed.
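One possible shape for the suggested mitigation, sketched only: keep a configurable floor on the delay rather than dropping it. `min_batch_delay` and `resolve_batch_delay` are hypothetical names, not part of this PR:

# Hypothetical sketch of a configurable minimum batch delay (not in this PR).
def resolve_batch_delay(min_batch_delay: float, min_inference_t: float) -> float:
    # Never batch more aggressively than the configured floor allows.
    return max(min_batch_delay, min_inference_t / 2)

resolve_batch_delay(0.002, 0.010)  # -> 0.005, inference time dominates
resolve_batch_delay(0.002, 0.001)  # -> 0.002, the floor applies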

             verbose=logger.level <= 10,
             lengths_via_tokenize=self._engine_args.lengths_via_tokenize,
libs/infinity_emb/infinity_emb/inference/batch_handler.py (35 changes: 18 additions & 17 deletions)
@@ -9,7 +9,7 @@
 import time
 from concurrent.futures import ThreadPoolExecutor
 from queue import Queue
-from typing import Any, Optional, Sequence, Union
+from typing import Any, Optional, Sequence, Union, TYPE_CHECKING

 import numpy as np

Expand All @@ -33,11 +33,17 @@
ReRankSingle,
get_inner_item,
)
from infinity_emb.transformer.abstract import BaseTransformer

from infinity_emb.transformer.audio.utils import resolve_audios
from infinity_emb.transformer.utils import get_lengths_with_tokenize
from infinity_emb.transformer.vision.utils import resolve_images

if TYPE_CHECKING:
from infinity_emb.transformer.abstract import BaseTypeHint


QUEUE_TIMEOUT = 0.5


class ShutdownReadOnly:
def __init__(self, shutdown: threading.Event) -> None:
@@ -58,7 +64,7 @@ def submit(self, *args, **kwargs):
 class BatchHandler:
     def __init__(
         self,
-        model_replicas: list[BaseTransformer],
+        model_replicas: list["BaseTypeHint"],
         max_batch_size: int,
         max_queue_wait: int = MANAGER.queue_size,
         batch_delay: float = 5e-3,
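A note on the import change above: guarding the `BaseTypeHint` import behind `TYPE_CHECKING` keeps the annotation visible to static type checkers without paying for (or cycling through) the import at runtime; the annotation is then written as a string. A minimal self-contained sketch of the idiom, with `load_replicas` as a made-up example function:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers, never at runtime.
    from infinity_emb.transformer.abstract import BaseTypeHint


def load_replicas(models: list["BaseTypeHint"]) -> int:
    # The quoted annotation is not evaluated at runtime, so no
    # runtime import of the abstract module is required here.
    return len(models)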
@@ -89,7 +95,7 @@ def __init__(
         self._shutdown = threading.Event()
         self._threadpool = ThreadPoolExecutor()
         self._queue_prio = CustomFIFOQueue()
-        self._result_queue: Queue = Queue(4)
+        self._result_queue: Queue = Queue(8)
         # cache
         cache = (
             Cache(
@@ -360,7 +366,7 @@ async def _collect_from_model(
             except queue.Empty:
                 # instead use async await to get
                 try:
-                    post_batch = await to_thread(result_queue.get, tp, timeout=0.5)
+                    post_batch = await to_thread(result_queue.get, tp, timeout=QUEUE_TIMEOUT)
                 except queue.Empty:
                     # in case of timeout start again
                     continue
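For context, `to_thread` here offloads the blocking `Queue.get` onto the threadpool so the event loop stays responsive while still honoring `QUEUE_TIMEOUT`. A rough sketch of such a helper, assuming it wraps `run_in_executor` (the real implementation in infinity_emb may differ):

import asyncio
import functools
from concurrent.futures import ThreadPoolExecutor

async def to_thread(func, tp: ThreadPoolExecutor, **kwargs):
    # Run the blocking callable on the threadpool; exceptions such as
    # queue.Empty propagate back through the awaited future.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(tp, functools.partial(func, **kwargs))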
@@ -413,7 +419,7 @@ def __init__(
         self,
         max_batch_size: int,
         shutdown: ShutdownReadOnly,
-        model: BaseTransformer,
+        model: "BaseTypeHint",
         threadpool: ThreadPoolExecutorReadOnly,
         input_q: CustomFIFOQueue,
         output_q: Queue,
@@ -468,12 +474,7 @@ def _preprocess_batch(self):
             # decision to attempt to pop a batch
             # -> will happen if a single datapoint is available

-            batches = self._queue_prio.pop_optimal_batches(
-                self._max_batch_size, latest_first=False
-            )
-            if not batches:
-                # not a single sentence available / len=0, wait for more
-                continue
+            batches = self._queue_prio.pop_optimal_batches(self._max_batch_size)
             # optimal batch has been selected ->
             # lets tokenize it and move tensors to GPU.
             for batch in batches:
@@ -494,7 +495,7 @@
                 # while-loop just for shutdown
                 while not self._shutdown.is_set():
                     try:
-                        self._feature_queue.put((feat, batch), timeout=0.5)
+                        self._feature_queue.put((feat, batch), timeout=QUEUE_TIMEOUT)
                         break
                     except queue.Full:
                         continue
Expand All @@ -511,7 +512,7 @@ def _core_batch(self):
try:
while not self._shutdown.is_set():
try:
core_batch = self._feature_queue.get(timeout=0.5)
core_batch = self._feature_queue.get(timeout=QUEUE_TIMEOUT)
except queue.Empty:
continue
(feat, batch) = core_batch
@@ -523,7 +524,7 @@
                 # while-loop just for shutdown
                 while not self._shutdown.is_set():
                     try:
-                        self._postprocess_queue.put((embed, batch), timeout=0.5)
+                        self._postprocess_queue.put((embed, batch), timeout=QUEUE_TIMEOUT)
                         break
                     except queue.Full:
                         continue
Expand All @@ -537,7 +538,7 @@ def _postprocess_batch(self):
try:
while not self._shutdown.is_set():
try:
post_batch = self._postprocess_queue.get(timeout=0.5)
post_batch = self._postprocess_queue.get(timeout=QUEUE_TIMEOUT)
except queue.Empty:
# instead use async await to get
continue
@@ -557,7 +558,7 @@
                 # while-loop just for shutdown
                 while not self._shutdown.is_set():
                     try:
-                        self._output_q.put((results, batch), timeout=0.5)
+                        self._output_q.put((results, batch), timeout=QUEUE_TIMEOUT)
                         break
                     except queue.Full:
                         continue
libs/infinity_emb/infinity_emb/inference/queue.py (35 changes: 14 additions & 21 deletions)
@@ -3,7 +3,7 @@

 import asyncio
 import threading
-from typing import Optional, Union
+from typing import Optional, Generator

 from infinity_emb.inference.caching_layer import Cache
 from infinity_emb.primitives import (
@@ -34,7 +34,7 @@ def extend(self, items: list[PrioritizedQueueItem]):

     def pop_optimal_batches(
         self, size: int, max_n_batches: int = 4, timeout=0.2, **kwargs
-    ) -> Union[list[list[QueueItemInner]], None]:
+    ) -> Generator[list[QueueItemInner], None, None]:
         """
         pop batch `up to size` + `continuous (sorted)` from queue
@@ -52,35 +52,28 @@
         """
         if not self._queue:
             if not self._sync_event.wait(timeout):
-                return None
+                return

-        # slice as many batches as possible
-        n_batches = min(max_n_batches, max(1, len(self._queue) // size))
-        size_batches = size * n_batches
+        # Determine the number of batches to process
+        # n_batches = min(max_n_batches, max(1, len(self._queue) // size))
+        size_batches = size * max_n_batches

         with self._lock_queue_event:
             new_items_l = self._queue[:size_batches]
             self._queue = self._queue[size_batches:]
             if not self._queue:
                 self._sync_event.clear()

-        if n_batches > 1:
-            # sort the sentences by len ->
-            # optimal padding per batch
+        if len(new_items_l) > size:
+            # Sort the items for optimal batching
             new_items_l.sort()

-        new_items: list[list[QueueItemInner]] = []
-        for i in range(n_batches):
-            mini_batch = new_items_l[size * i : size * (i + 1)]
-            mini_batch_e: list[QueueItemInner] = [
-                mi.item for mi in mini_batch if not mi.item.future.done()
-            ]
-            if mini_batch_e:
-                new_items.append(mini_batch_e)
-        if new_items:
-            return new_items
-        else:
-            return None
+        new_items: list[QueueItemInner] = [
+            mi.item for mi in new_items_l if not mi.item.future.done()
+        ]
+
+        for i in range(0, len(new_items), size):
+            yield new_items[i : i + size]


 class ResultKVStoreFuture:
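With `pop_optimal_batches` now a generator, callers iterate it directly, and an empty queue simply yields nothing; this is why the `if not batches: continue` guard in `_preprocess_batch` could be dropped. A toy, self-contained sketch of the new contract (not the real queue code):

from typing import Generator

def pop_batches(items: list[int], size: int) -> Generator[list[int], None, None]:
    # Sort once so neighboring items land in the same batch (optimal padding).
    items = sorted(items)
    for i in range(0, len(items), size):
        yield items[i : i + size]

for batch in pop_batches([5, 1, 4, 2, 3], size=2):
    print(batch)  # [1, 2], then [3, 4], then [5]

for batch in pop_batches([], size=2):
    print("never reached")  # empty input -> zero iterations, no None check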
libs/infinity_emb/infinity_emb/transformer/abstract.py (5 changes: 5 additions & 0 deletions)
@@ -220,6 +220,11 @@ def warmup(self, *, batch_size: int = 64, n_tokens=1) -> tuple[float, float, str]:
         return run_warmup(self, inp)


+BaseTypeHint = Union[
+    BaseTransformer, BaseEmbedder, BaseTIMM, BaseAudioEmbedModel, BaseClassifer, BaseCrossEncoder
+]
+
+
 def run_warmup(model, inputs) -> tuple[float, float, str]:
Review comment (Contributor):
style: model parameter should be typed with BaseTypeHint
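A sketch of the signature the comment asks for; since `BaseTypeHint` is defined just above `run_warmup` in this module, it can be referenced directly:

# Suggested typing from the review comment (sketch only):
def run_warmup(model: BaseTypeHint, inputs) -> tuple[float, float, str]:
    ...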

     inputs_formated = [i.content.to_input() for i in inputs]
     start = perf_counter()