Add num workers option for token pooling (#190)
* feat: add `num_workers` option for token pooling

* docs: fix docstring for token pooling
tonywu71 authored Feb 11, 2025
1 parent 70614de · commit 1b070df
Showing 1 changed file with 16 additions and 9 deletions.
colpali_engine/compression/token_pooling.py (25 changes: 16 additions & 9 deletions)
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
-from typing import Dict, List, Tuple, Union, cast
+from typing import Dict, List, Optional, Tuple, Union, cast

import numpy as np
import torch
@@ -44,14 +44,14 @@ def pool_embeddings(

class HierarchicalTokenPooler(BaseTokenPooler):
"""
-Hierarchical token pooling of multi-vector embeddings based on the similarity between tokens.
+Hierarchical token pooling of multi-vector embeddings based on the similarity between token embeddings.
"""

def __init__(self, pool_factor: int):
"""
Args:
pool_factor: An integer factor that determines the maximum number of clusters as
-    max_clusters = max(token_length // pool_factor, 1).
+    `max_clusters = max(token_length // pool_factor, 1)`.
"""
self.pool_factor = pool_factor

@@ -62,14 +62,14 @@ def pool_embeddings(
padding: bool = False,
padding_value: float = 0.0,
padding_side: str = "left",
+num_workers: Optional[int] = None,
) -> Union[Union[torch.Tensor, List[torch.Tensor]], TokenPoolingOutput]:
"""
Return the pooled embeddings.
Args:
-embeddings: A list of 2D tensors (token_length, embedding_dim) where each tensor can have its own
-    token_length, or a 3D tensor of shape (batch_size, token_length, embedding_dim) with
-    optional padding.
+embeddings: A list of 2D tensors (token_length, embedding_dim) where each tensor can have its own token
+    length, or a 3D tensor of shape (batch_size, token_length, embedding_dim) with padding.
return_dict: Whether or not to return a `TokenPoolingOutput` object (with the cluster id to token indices
mapping) instead of just the pooled embeddings.
padding: Whether or not to unbind the padded 3D tensor into a list of 2D tensors. Does nothing if the input
@@ -78,7 +78,13 @@
padding_side: The side where the padding was applied in the 3D tensor.
Returns:
-A list of pooled embeddings or `TokenPoolingOutput` objects.
+If the `embeddings` input is:
+- A list of 2D tensors: Returns a list of 2D tensors (token_length, embedding_dim) where each tensor can
+  have its own token_length.
+- A 3D tensor: A 3D tensor of shape (batch_size, token_length, embedding_dim) with padding.
+
+If `return_dict` is True, the pooled embeddings are returned within a `TokenPoolingOutput` object, along
+with the cluster id to token indices mapping.
"""
if isinstance(embeddings, list) and not embeddings:
return TokenPoolingOutput(pooled_embeddings=[], cluster_id_to_indices=[])
@@ -99,7 +105,7 @@
else:
embeddings = list(embeddings.unbind(dim=0))

-with ThreadPoolExecutor() as executor:
+with ThreadPoolExecutor(num_workers) as executor:
# NOTE: We opted for a thread-based pool because most of the heavy lifting is done in C-level libraries
# (NumPy, Torch, and SciPy) which usually release the GIL.
results = list(executor.map(self._pool_single_embedding, embeddings))
@@ -133,7 +139,8 @@ def _pool_single_embedding(self, embedding: torch.Tensor) -> Tuple[torch.Tensor,
embedding: A tensor of shape (token_length, embedding_dim).
Returns:
-A pooled embedding tensor or a `TokenPoolingOutput` object.
+pooled_embedding: A tensor of shape (num_clusters, embedding_dim).
+cluster_id_to_indices: A dictionary mapping the cluster id to token indices.
"""
if embedding.dim() != 2:
raise ValueError("The input tensor must be a 2D tensor.")
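Below is a minimal usage sketch of the new option (not part of the commit itself). It assumes `HierarchicalTokenPooler` is imported from the module path shown above; the tensor sizes and `pool_factor` value are purely illustrative.

    import torch

    from colpali_engine.compression.token_pooling import HierarchicalTokenPooler

    # Two documents with different token lengths and 128-dim embeddings (illustrative sizes).
    embeddings = [
        torch.randn(256, 128),
        torch.randn(1024, 128),
    ]

    # pool_factor=3 caps each document at max(token_length // 3, 1) clusters (85 and 341 here).
    pooler = HierarchicalTokenPooler(pool_factor=3)

    outputs = pooler.pool_embeddings(
        embeddings,
        return_dict=True,  # also return the cluster id to token indices mapping
        num_workers=4,     # new in #190: caps the ThreadPoolExecutor at 4 threads
    )

    for pooled in outputs.pooled_embeddings:
        print(pooled.shape)  # (num_clusters, embedding_dim)

A padded 3D tensor of shape (batch_size, token_length, embedding_dim) can also be passed directly; see the `padding`, `padding_value`, and `padding_side` arguments documented above. Leaving `num_workers` unset passes None through to `ThreadPoolExecutor`, which keeps the previous default thread count.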
