Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: default api port and attention selector #634

Merged
merged 1 commit into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions aphrodite/attention/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
is_openvino, is_tpu, is_xpu)
from aphrodite.platforms import current_platform

APHRODITE_ATTENTION_BACKEND = "APHRODITE_ATTENTION_BACKEND"
APHRODITE_ATTENTION_BACKEND = os.getenv("APHRODITE_ATTENTION_BACKEND", None)


class _Backend(enum.Enum):
Expand Down Expand Up @@ -42,7 +42,9 @@ def get_env_variable_attn_backend() -> Optional[_Backend]:
'''
Get the backend override specified by the Aphrodite attention
backend environment variable, if one is specified.

Returns:

* _Backend enum value if an override is specified
* None otherwise
'''
Expand All @@ -64,9 +66,12 @@ def get_env_variable_attn_backend() -> Optional[_Backend]:
def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
'''
Force all attention operations to use a specified backend.

Passing `None` for the argument re-enables automatic
backend selection.,

Arguments:

* attn_backend: backend selection (None to revert to auto)
'''
global forced_attn_backend
Expand All @@ -92,15 +97,14 @@ def get_attn_backend(
block_size: int,
is_blocksparse: bool = False,
) -> Type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""

if is_blocksparse:
logger.info("Using BlocksparseFlashAttention backend.")
from aphrodite.attention.backends.blocksparse_attn import (
BlocksparseFlashAttentionBackend)
return BlocksparseFlashAttentionBackend
"""Determine which attention backend to use and only import
the selected backend module.
"""

backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
sliding_window, dtype, kv_cache_dtype,
block_size)
Expand All @@ -120,12 +124,12 @@ def get_attn_backend(
return ROCmFlashAttentionBackend
elif backend == _Backend.TORCH_SDPA:
assert is_cpu(), RuntimeError(
"Torch SDPA backend is only used for CPU devices.")
"Torch SDPA backend is only used for the CPU device.")
logger.info("Using Torch SDPA backend.")
from aphrodite.attention.backends.torch_sdpa import TorchSDPABackend
return TorchSDPABackend
elif backend == _Backend.OPENVINO:
logger.info("Using OpenVINO attention backend.")
logger.info("Using OpenVINO Attention backend.")
from aphrodite.attention.backends.openvino import (
OpenVINOAttentionBackend)
return OpenVINOAttentionBackend
Expand Down Expand Up @@ -157,7 +161,6 @@ def which_attn_to_use(
block_size: int,
) -> _Backend:
"""Returns which flash attention backend to use."""

# Default case.
selected_backend = _Backend.FLASH_ATTN

Expand All @@ -174,7 +177,8 @@ def which_attn_to_use(
# Check the environment variable and override if specified
backend_by_env_var: Optional[str] = APHRODITE_ATTENTION_BACKEND
if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var.upper())
selected_backend = backend_name_to_enum(backend_by_env_var)

if is_cpu():
if selected_backend != _Backend.TORCH_SDPA:
logger.info(f"Cannot use {selected_backend} backend on CPU.")
Expand Down Expand Up @@ -266,9 +270,13 @@ def global_force_attn_backend_context_manager(
context manager, reverting the global attention backend
override to its prior state upon exiting the context
manager.

Arguments:

* attn_backend: attention backend to force

Returns:

* Generator
'''

Expand Down
3 changes: 2 additions & 1 deletion aphrodite/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,8 @@ def get_distributed_init_method(ip: str, port: int) -> str:
def get_open_port(port: Optional[int] = None) -> int:
if port is None:
# Default behavior here is to return a port for multi-gpu communication
port = int(os.getenv("APHRODITE_PORT", 2242))
port = int(os.getenv("APHRODITE_PORT", 0)
) if "APHRODITE_PORT" in os.environ else None
if port is not None:
while True:
try:
Expand Down
Loading