[Misc] Refactor get_kv_cache_spec into AttentionLayerBase (#26587)
Signed-off-by: NickLucche <nlucches@redhat.com>
parent ab4be40fc5
commit b26b70bec4
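In short, this change moves KV-cache-spec construction out of GPUModelRunner.get_kv_cache_spec(), which used to branch on every layer class, and into a get_kv_cache_spec(vllm_config) hook declared on AttentionLayerBase and implemented by each layer type (Attention, MLAAttention, ChunkedLocalAttention, CrossAttention, EncoderOnlyAttention, MambaBase, DeepseekV32IndexerCache). The runner then only collects the per-layer results. Below is a rough sketch of the resulting flow, condensed from the hunks that follow; collect_kv_cache_specs is a hypothetical name used for illustration, and the real logic lives in the GPUModelRunner hunks at the end of the diff.

from vllm.config import VllmConfig
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.kv_cache_interface import KVCacheSpec


def collect_kv_cache_specs(
    layers: dict[str, AttentionLayerBase], vllm_config: VllmConfig
) -> dict[str, KVCacheSpec]:
    # Each layer now reports its own spec; layers that need no KV cache
    # (e.g. encoder-only attention) return None and are skipped.
    # KV-sharing layers are still filtered out by the runner itself.
    specs: dict[str, KVCacheSpec] = {}
    for name, layer in layers.items():
        if spec := layer.get_kv_cache_spec(vllm_config):
            specs[name] = spec
    return specs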
@@ -16,6 +16,7 @@ from vllm.attention.backends.registry import _Backend, backend_name_to_enum
 from vllm.attention.selector import get_attn_backend
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
 from vllm.config import CacheConfig, get_current_vllm_config
+from vllm.config.vllm import VllmConfig
 from vllm.distributed.kv_transfer import (
     get_kv_transfer_group,
     has_kv_transfer_group,

@@ -34,7 +35,16 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils import (
+    direct_register_custom_op,
+    kv_cache_dtype_str_to_dtype,
+)
+from vllm.v1.kv_cache_interface import (
+    FullAttentionSpec,
+    KVCacheSpec,
+    MLAAttentionSpec,
+    SlidingWindowSpec,
+)
 
 FP8_DTYPE = current_platform.fp8_dtype()
 logger = init_logger(__name__)

@@ -152,6 +162,7 @@ class Attention(nn.Module, AttentionLayerBase):
         else:
             sliding_window = None
 
+        vllm_config = get_current_vllm_config()
         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
             block_size = cache_config.block_size

@@ -160,6 +171,9 @@ class Attention(nn.Module, AttentionLayerBase):
             kv_cache_dtype = "auto"
             block_size = 16
             calculate_kv_scales = False
+        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
+            kv_cache_dtype, vllm_config.model_config
+        )
         if num_kv_heads is None:
             num_kv_heads = num_heads
         assert num_heads % num_kv_heads == 0, (

@@ -256,7 +270,7 @@ class Attention(nn.Module, AttentionLayerBase):
         self.use_direct_call = not current_platform.opaque_attention_op()
 
         self.use_output = self.attn_backend.accept_output_buffer
-        compilation_config = get_current_vllm_config().compilation_config
+        compilation_config = vllm_config.compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self

@@ -276,9 +290,7 @@ class Attention(nn.Module, AttentionLayerBase):
         # this variable will not be accessed if use_direct_call is True
         self.kv_cache = [
             torch.tensor([])
-            for _ in range(
-                get_current_vllm_config().parallel_config.pipeline_parallel_size
-            )
+            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
         ]
 
         # Initialize q/k/v range constants.

@@ -394,6 +406,30 @@ class Attention(nn.Module, AttentionLayerBase):
     def get_attn_backend(self) -> type[AttentionBackend]:
         return self.attn_backend
 
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        # Block size may get updated after model loading, refresh it
+        block_size = vllm_config.cache_config.block_size
+        # Should not be called for enc-dec or encoder-only attention.
+        assert self.attn_type == AttentionType.DECODER
+        if self.sliding_window is not None:
+            assert not vllm_config.model_config.use_mla, (
+                "MLA is not supported for slidingwindow"
+            )
+            return SlidingWindowSpec(
+                block_size=block_size,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                dtype=self.kv_cache_torch_dtype,
+                sliding_window=self.sliding_window,
+            )
+        else:
+            return FullAttentionSpec(
+                block_size=block_size,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                dtype=self.kv_cache_torch_dtype,
+            )
+
 
 class MultiHeadAttention(nn.Module):
     """Multi-headed attention without any cache, used for ViT."""

@@ -749,6 +785,18 @@ class MLAAttention(nn.Module, AttentionLayerBase):
     def get_attn_backend(self) -> type[AttentionBackend]:
         return self.attn_backend
 
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        kv_cache_dtype = kv_cache_dtype_str_to_dtype(
+            self.kv_cache_dtype, vllm_config.model_config
+        )
+        return MLAAttentionSpec(
+            block_size=vllm_config.cache_config.block_size,
+            num_kv_heads=1,
+            head_size=self.head_size,
+            dtype=kv_cache_dtype,
+            cache_dtype_str=vllm_config.cache_config.cache_dtype,
+        )
+
 
 def wait_for_kv_layer_from_connector(layer_name: str):
     if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
@@ -9,6 +9,7 @@ from vllm import envs
 from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
 from vllm.attention.selector import get_attn_backend
 from vllm.config import CacheConfig
+from vllm.config.vllm import VllmConfig
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport,

@@ -16,6 +17,7 @@ from vllm.v1.attention.backends.utils import (
     make_local_attention_virtual_batches,
     subclass_attention_backend,
 )
+from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec, KVCacheSpec
 
 from ..layer import Attention
 

@@ -67,6 +69,7 @@ class ChunkedLocalAttention(Attention):
         kv_sharing_target_layer_name: str | None = None,
         prefix: str = "",
     ):
+        self.attention_chunk_size = attention_chunk_size
         dtype = torch.get_default_dtype()
         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype

@@ -99,3 +102,13 @@ class ChunkedLocalAttention(Attention):
             kv_sharing_target_layer_name=kv_sharing_target_layer_name,
             attn_backend=attn_backend,
         )
+
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        assert self.attention_chunk_size
+        return ChunkedLocalAttentionSpec(
+            block_size=vllm_config.cache_config.block_size,
+            num_kv_heads=self.num_kv_heads,
+            head_size=self.head_size,
+            dtype=self.kv_cache_torch_dtype,
+            attention_chunk_size=self.attention_chunk_size,
+        )
@@ -21,7 +21,7 @@ from vllm.v1.attention.backends.utils import (
     CommonAttentionMetadata,
     subclass_attention_backend,
 )
-from vllm.v1.kv_cache_interface import CrossAttentionSpec
+from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec
 
 logger = init_logger(__name__)
 

@@ -174,3 +174,11 @@ class CrossAttention(Attention):
             attn_type=AttentionType.ENCODER_DECODER,
             **kwargs,
         )
+
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        return CrossAttentionSpec(
+            block_size=vllm_config.cache_config.block_size,
+            num_kv_heads=self.num_kv_heads,
+            head_size=self.head_size,
+            dtype=self.kv_cache_torch_dtype,
+        )
@@ -14,10 +14,12 @@ from vllm.attention.backends.abstract import (
 from vllm.attention.layer import Attention
 from vllm.attention.selector import get_attn_backend
 from vllm.config import CacheConfig
+from vllm.config.vllm import VllmConfig
 from vllm.v1.attention.backends.utils import (
     CommonAttentionMetadata,
     subclass_attention_backend,
 )
+from vllm.v1.kv_cache_interface import KVCacheSpec
 
 
 @functools.lru_cache

@@ -98,3 +100,7 @@ class EncoderOnlyAttention(Attention):
             attn_type=AttentionType.ENCODER_ONLY,
             **kwargs,
         )
+
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        # Does not need KV cache
+        return None
@@ -5,6 +5,9 @@
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING
 
+from vllm.config import VllmConfig
+from vllm.v1.kv_cache_interface import KVCacheSpec
+
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend
 

@@ -22,3 +25,11 @@ class AttentionLayerBase(ABC):
     def get_attn_backend(self) -> type["AttentionBackend"]:
         """Get the attention backend class for this layer."""
         pass
+
+    @abstractmethod
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
+        """
+        Get the KV cache spec for this layer.
+        May be None if the layer does not need KV cache.
+        """
+        pass
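With the hook declared on the abstract base above, any layer that registers itself in static_forward_context is now responsible for describing its own cache needs. The following is a minimal hypothetical subclass for illustration only: MyPagedAttention and its attributes are assumptions, not part of this commit, and the constructor arguments mirror the Attention.get_kv_cache_spec hunk earlier in the diff.

from vllm.config import VllmConfig
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec


class MyPagedAttention(AttentionLayerBase):
    # Assumed: __init__ sets self.attn_backend, self.num_kv_heads,
    # self.head_size and self.kv_cache_torch_dtype, as the built-in
    # Attention layer does.
    def get_attn_backend(self):
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
        # Return None instead if the layer needs no KV cache
        # (see the EncoderOnlyAttention hunk above).
        return FullAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            dtype=self.kv_cache_torch_dtype,
        )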
@@ -6,7 +6,9 @@ from typing import TYPE_CHECKING
 
 import torch
 
+from vllm.config import VllmConfig
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend

@@ -40,3 +42,30 @@ class MambaBase(AttentionLayerBase):
     def get_attn_backend(self) -> type["AttentionBackend"]:
         """Get the attention backend class for this Mamba layer."""
         pass
+
+    @abstractmethod
+    def get_state_dtype(self) -> tuple[torch.dtype, ...]:
+        pass
+
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
+        if (
+            vllm_config.speculative_config is not None
+            and vllm_config.model_config.hf_config.model_type not in ["qwen3_next"]
+        ):
+            raise NotImplementedError(
+                "Mamba with speculative decoding is not supported yet."
+            )
+        mamba_block_size = vllm_config.cache_config.mamba_block_size
+        page_size_padded = vllm_config.cache_config.mamba_page_size_padded
+        return MambaSpec(
+            shapes=self.get_state_shape(),
+            dtypes=self.get_state_dtype(),
+            block_size=mamba_block_size,
+            page_size_padded=page_size_padded,
+            mamba_type=self.mamba_type,
+            num_speculative_blocks=(
+                vllm_config.speculative_config.num_speculative_tokens
+                if vllm_config.speculative_config
+                else 0
+            ),
+        )
@@ -481,7 +481,7 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self
 
-    def get_kv_cache_spec(self) -> KVCacheSpec:
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
         return MLAAttentionSpec(  # Only has one vector instead of K + V
             block_size=self.cache_config.block_size,
             num_kv_heads=1,
@@ -137,6 +137,15 @@ def set_default_torch_num_threads(num_threads: int):
         torch.set_num_threads(old_num_threads)
 
 
+def kv_cache_dtype_str_to_dtype(
+    kv_cache_dtype: str, model_config: ModelConfig
+) -> torch.dtype:
+    if kv_cache_dtype == "auto":
+        # Model config may not be specified for unit tests, default to float16
+        return model_config.dtype if model_config else torch.half
+    return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
+
+
 T = TypeVar("T")
 U = TypeVar("U")
 
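For context on the helper just above: it centralizes the cache-dtype resolution that GPUModelRunner.__init__ previously inlined (see the hunk near the end of the diff). A hedged usage sketch, assuming a ModelConfig instance named model_config is in scope:

from vllm.utils import kv_cache_dtype_str_to_dtype

# "auto" defers to the model's dtype (float16 when no model config is given,
# as in some unit tests); any other cache-dtype string is resolved through
# STR_DTYPE_TO_TORCH_DTYPE.
kv_dtype = kv_cache_dtype_str_to_dtype("auto", model_config)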
@@ -948,7 +948,7 @@ class EagleProposer:
                 indexer_layers[first_layer]
                 .get_attn_backend()
                 .get_builder_cls()(
-                    indexer_layers[first_layer].get_kv_cache_spec(),
+                    indexer_layers[first_layer].get_kv_cache_spec(self.vllm_config),
                     self.indexer_layer_names,
                     self.vllm_config,
                     self.device,
@@ -19,8 +19,6 @@ from tqdm import tqdm
 import vllm.envs as envs
 from vllm.attention import Attention, AttentionType
 from vllm.attention.backends.abstract import AttentionBackend, MultipleOf
-from vllm.attention.layer import MLAAttention
-from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.cuda_graph import CUDAGraphWrapper
 from vllm.compilation.monitor import set_cudagraph_capturing_enabled

@@ -44,10 +42,8 @@ from vllm.distributed.parallel_state import (
 from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
-from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
 from vllm.model_executor.models.interfaces import (
     SupportsMultiModal,
     is_mixture_of_experts,

@@ -73,11 +69,11 @@ from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import (
-    STR_DTYPE_TO_TORCH_DTYPE,
     cdiv,
     check_use_alibi,
     get_dtype_size,
     is_pin_memory_available,
+    kv_cache_dtype_str_to_dtype,
     length_from_prompt_token_ids_or_embeds,
     round_up,
     supports_dynamo,

@@ -106,7 +102,6 @@ from vllm.v1.kv_cache_interface import (
     KVCacheGroupSpec,
     KVCacheSpec,
     MambaSpec,
-    MLAAttentionSpec,
     SlidingWindowSpec,
     UniformTypeKVCacheSpecs,
 )

@@ -239,10 +234,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.device = device
         self.pin_memory = is_pin_memory_available()
         self.dtype = self.model_config.dtype
-        if cache_config.cache_dtype == "auto":
-            self.kv_cache_dtype = self.dtype
-        else:
-            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
+        self.kv_cache_dtype = kv_cache_dtype_str_to_dtype(
+            cache_config.cache_dtype, self.model_config
+        )
 
         self.is_pooling_model = model_config.runner_type == "pooling"
         self.enable_prompt_embeds = model_config.enable_prompt_embeds

@@ -4577,16 +4571,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         format. Layers that do not need KV cache are not included.
         """
 
-        block_size = self.vllm_config.cache_config.block_size
-        use_mla = self.vllm_config.model_config.use_mla
-        cache_dtype_str = self.vllm_config.cache_config.cache_dtype
         kv_cache_spec: dict[str, KVCacheSpec] = {}
         attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
         for layer_name, attn_module in attn_layers.items():
-            if isinstance(attn_module, Attention):
-                if (
-                    kv_tgt_layer := attn_module.kv_sharing_target_layer_name
-                ) is not None:
+            if isinstance(attn_module, Attention) and (
+                kv_tgt_layer := attn_module.kv_sharing_target_layer_name
+            ):
                 # The layer doesn't need its own KV cache and will use that of
                 # the target layer. We skip creating a KVCacheSpec for it, so
                 # that KV cache management logic will act as this layer does

@@ -4596,90 +4586,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 # or enable more requests to be processed simultaneously.
                 self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
                 continue
 
-                # TODO(lucas): move the attention specs into the model layers like
-                # the attention backends
-                if attn_module.attn_type == AttentionType.DECODER:
-                    if attn_module.sliding_window is not None:
-                        assert not use_mla, "MLA is not supported for slidingwindow"
-                        kv_cache_spec[layer_name] = SlidingWindowSpec(
-                            block_size=block_size,
-                            num_kv_heads=attn_module.num_kv_heads,
-                            head_size=attn_module.head_size,
-                            dtype=self.kv_cache_dtype,
-                            sliding_window=attn_module.sliding_window,
-                        )
-                    elif self.attention_chunk_size is not None and isinstance(
-                        attn_module, ChunkedLocalAttention
-                    ):
-                        kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
-                            block_size=block_size,
-                            num_kv_heads=attn_module.num_kv_heads,
-                            head_size=attn_module.head_size,
-                            dtype=self.kv_cache_dtype,
-                            attention_chunk_size=self.attention_chunk_size,
-                        )
-                    else:
-                        kv_cache_spec[layer_name] = FullAttentionSpec(
-                            block_size=block_size,
-                            num_kv_heads=attn_module.num_kv_heads,
-                            head_size=attn_module.head_size,
-                            dtype=self.kv_cache_dtype,
-                        )
-                elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
-                    kv_cache_spec[layer_name] = CrossAttentionSpec(
-                        block_size=block_size,
-                        num_kv_heads=attn_module.num_kv_heads,
-                        head_size=attn_module.head_size,
-                        dtype=self.kv_cache_dtype,
-                    )
-                elif attn_module.attn_type in (
-                    AttentionType.ENCODER,
-                    AttentionType.ENCODER_ONLY,
-                ):
-                    # encoder-only attention does not need KV cache.
-                    continue
-                else:
-                    raise ValueError(f"Unknown attention type: {attn_module.attn_type}")
-
-            elif isinstance(attn_module, MLAAttention):
-                kv_cache_spec[layer_name] = MLAAttentionSpec(
-                    block_size=block_size,
-                    num_kv_heads=1,
-                    head_size=attn_module.head_size,
-                    dtype=self.kv_cache_dtype,
-                    cache_dtype_str=cache_dtype_str,
-                )
-
-            elif isinstance(attn_module, MambaBase):
-                if (
-                    self.vllm_config.speculative_config is not None
-                    and self.vllm_config.model_config.hf_config.model_type
-                    not in ["qwen3_next"]
-                ):
-                    raise NotImplementedError(
-                        "Mamba with speculative decoding is not supported yet."
-                    )
-                mamba_block_size = self.vllm_config.cache_config.mamba_block_size
-                page_size_padded = self.vllm_config.cache_config.mamba_page_size_padded
-                kv_cache_spec[layer_name] = MambaSpec(
-                    shapes=attn_module.get_state_shape(),
-                    dtypes=attn_module.get_state_dtype(),
-                    block_size=mamba_block_size,
-                    page_size_padded=page_size_padded,
-                    mamba_type=attn_module.mamba_type,
-                    num_speculative_blocks=(
-                        self.speculative_config.num_speculative_tokens
-                        if self.speculative_config
-                        else 0
-                    ),
-                )
-
-        ds_indexer_layers = get_layers_from_vllm_config(
-            self.vllm_config, DeepseekV32IndexerCache
-        )
-        for layer_name, ds_indexer_module in ds_indexer_layers.items():
-            kv_cache_spec[layer_name] = ds_indexer_module.get_kv_cache_spec()
-
+            # Skip modules that don't need KV cache (eg encoder-only attention)
+            if spec := attn_module.get_kv_cache_spec(self.vllm_config):
+                kv_cache_spec[layer_name] = spec
 
         return kv_cache_spec
 