[Fix] Postpone cuda import to the calling site (#231)
This PR postpones CUDA initialization (`torch.cuda.is_available()`) to the calling site to avoid the `re-initializing CUDA in forked subprocess` error.
Ubospica authored Mar 7, 2025
1 parent ccc355c commit 3018d72
Showing 1 changed file with 5 additions and 10 deletions.
python/xgrammar/matcher.py (15 changes: 5 additions & 10 deletions)
@@ -10,15 +10,11 @@
 
 from .base import XGRObject, _core
 from .compiler import CompiledGrammar
-from .kernels import apply_token_bitmask_inplace_kernels
 
 """The dtype of the bitmask: int32."""
 bitmask_dtype = torch.int32
 
 
-_is_cuda_available = torch.cuda.is_available()
-
-
 def get_bitmask_shape(batch_size: int, vocab_size: int) -> Tuple[int, int]:
     """Return the shape of the bitmask: (batch_size, ceil(vocab_size / 32))."""
     return (batch_size, math.ceil(vocab_size / 32))
@@ -50,12 +46,7 @@ def allocate_token_bitmask(batch_size: int, vocab_size: int) -> torch.Tensor:
         The shape of the bitmask.
     """
     # In CUDA, use pinned memory to speed up data transfer from CPU to GPU
-    return torch.full(
-        get_bitmask_shape(batch_size, vocab_size),
-        _FULL_MASK,
-        dtype=bitmask_dtype,
-        pin_memory=_is_cuda_available,
-    )
+    return torch.full(get_bitmask_shape(batch_size, vocab_size), _FULL_MASK, dtype=bitmask_dtype)
 
 
 def reset_token_bitmask(bitmask: torch.Tensor) -> None:
@@ -121,6 +112,10 @@ def apply_token_bitmask_inplace(
         A list of indices to specify which logits in the batch to apply the bitmask to. Should be
         unique. If None, apply the bitmask to all logits in the batch.
     """
+    # Fix (from vLLM team): postpone the import of apply_token_bitmask_inplace_kernels to the
+    # calling site to avoid re-initializing CUDA in forked subprocess.
+    from .kernels import apply_token_bitmask_inplace_kernels
+
     # dispatch to different implementations based on the device of logits and bitmask
     if bitmask.device != logits.device:
         raise ValueError(

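For readers hitting this error downstream, the sketch below illustrates the deferred-initialization pattern the diff applies. It is a minimal, hypothetical stand-in rather than the actual xgrammar source: the function body is illustrative, and only the placement of the CUDA probe matters.

# sketch.py -- illustrative stand-in, not the actual xgrammar code
import torch

# Before this fix (problematic): a module-level probe runs as soon as the package is
# imported. If the importing process later forks, the child can hit the
# "re-initializing CUDA in forked subprocess" error when it first touches the GPU.
#
#   _is_cuda_available = torch.cuda.is_available()


def apply_token_bitmask_inplace(logits: torch.Tensor, bitmask: torch.Tensor) -> None:
    # After the fix: anything that may initialize CUDA is resolved at the calling
    # site. In the real patch this is where the deferred
    # `from .kernels import apply_token_bitmask_inplace_kernels` now lives.
    use_cuda = logits.is_cuda and torch.cuda.is_available()
    if use_cuda:
        ...  # dispatch to the in-place CUDA kernel
    else:
        ...  # fall back to the CPU implementation

The same reasoning explains why the module-level `_is_cuda_available` assignment is deleted in the first hunk: evaluating the check lazily keeps importing the package free of CUDA side effects in fork-based multiprocessing setups such as vLLM workers.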