# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from functools import cache
from typing import NamedTuple, cast, get_args

import torch

from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.backends.registry import (
    MAMBA_TYPE_TO_BACKEND_MAP,
    MambaAttentionBackendEnum,
)
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname

logger = init_logger(__name__)


class AttentionSelectorConfig(NamedTuple):
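    """Per-layer attention properties used to select a backend.

    Being a NamedTuple keeps instances hashable, so they can serve as part of
    the memoization key in ``_cached_get_attn_backend``.
    """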

    head_size: int
    dtype: torch.dtype
    kv_cache_dtype: CacheDType | None
    block_size: int | None
    use_mla: bool = False
    has_sink: bool = False
    use_sparse: bool = False
    use_mm_prefix: bool = False
    attn_type: str = AttentionType.DECODER

    def __repr__(self):
        return (
            f"AttentionSelectorConfig(head_size={self.head_size}, "
            f"dtype={self.dtype}, "
            f"kv_cache_dtype={self.kv_cache_dtype}, "
            f"block_size={self.block_size}, "
            f"use_mla={self.use_mla}, "
            f"has_sink={self.has_sink}, "
            f"use_sparse={self.use_sparse}, "
            f"use_mm_prefix={self.use_mm_prefix}, "
            f"attn_type={self.attn_type})"
        )


def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str | None,
    block_size: int | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    use_mm_prefix: bool = False,
    attn_type: str | None = None,
) -> type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it."""

    if kv_cache_dtype is not None:
        valid_cache_dtypes = get_args(CacheDType)
        assert kv_cache_dtype in valid_cache_dtypes, (
            f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
            f"Valid values are: {valid_cache_dtypes}"
        )
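
    # Imported here rather than at module level, presumably to avoid an import
    # cycle with vllm.config.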
    from vllm.config import get_current_vllm_config

    vllm_config = get_current_vllm_config()
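    # Backend preference carried by the attention config; the platform resolves
    # it to a concrete implementation in _cached_get_attn_backend (assumption:
    # it may be unset, in which case the platform falls back to its default).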
    backend_enum = vllm_config.attention_config.backend

    attn_selector_config = AttentionSelectorConfig(
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
        block_size=block_size,
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
        use_mm_prefix=use_mm_prefix,
        attn_type=attn_type or AttentionType.DECODER,
    )

    return _cached_get_attn_backend(
        backend=backend_enum,
        attn_selector_config=attn_selector_config,
    )


@cache
def _cached_get_attn_backend(
    backend,
    attn_selector_config: AttentionSelectorConfig,
) -> type[AttentionBackend]:
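    # Memoized on the (backend, attn_selector_config) pair; AttentionSelectorConfig
    # is a hashable NamedTuple, so identical layer configurations resolve the
    # backend class only once per process.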
    from vllm.platforms import current_platform

    attention_cls = current_platform.get_attn_backend_cls(
        backend,
        attn_selector_config=attn_selector_config,
    )
    if not attention_cls:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}"
        )
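    # The platform returns the backend class as a fully qualified name; resolve
    # it here so the backend module is only imported once it has been selected.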
    backend = resolve_obj_by_qualname(attention_cls)

    # Adjust kv cache layout if the selected backend requires a specific one
    required_layout = backend.get_required_kv_cache_layout()
    if required_layout is not None:
        from vllm.v1.attention.backends.utils import set_kv_cache_layout

        set_kv_cache_layout(required_layout)
        logger.info(
            "Using %s KV cache layout for %s backend.",
            required_layout,
            backend.get_name(),
        )

    return backend


def get_mamba_attn_backend(
    mamba_type: str,
) -> type[AttentionBackend]:
    """Select which mamba attention backend to use and lazily import it."""
    return _cached_get_mamba_attn_backend(mamba_type)


@cache
def _cached_get_mamba_attn_backend(
    mamba_type: str,
) -> type[AttentionBackend]:
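    # Map the mamba_type tag to a registered backend class; unknown tags raise
    # a ValueError listing the valid backends.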
    assert mamba_type and isinstance(mamba_type, str)

    try:
        backend_name = MAMBA_TYPE_TO_BACKEND_MAP[mamba_type]
        selected_backend = MambaAttentionBackendEnum[backend_name]
    except KeyError as e:
        raise ValueError(
            f"Invalid mamba attention backend for mamba_type '{mamba_type}'. Valid "
            f"backends are: {list(MambaAttentionBackendEnum.__members__.keys())}"
        ) from e

    return selected_backend.get_class()