[Misc] Refactor get_kv_cache_spec into AttentionLayerBase (#26587)

Signed-off-by: NickLucche <nlucches@redhat.com>
Authored by Nicolò Lucchesi on 2025-10-18 15:51:21 +02:00, committed by GitHub
parent ab4be40fc5
commit b26b70bec4
GPG Key ID: B5690EEEBB952194
10 changed files with 151 additions and 118 deletions

View File

@@ -16,6 +16,7 @@ from vllm.attention.backends.registry import _Backend, backend_name_to_enum
from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config
+from vllm.config.vllm import VllmConfig
from vllm.distributed.kv_transfer import (
    get_kv_transfer_group,
    has_kv_transfer_group,
@@ -34,7 +35,16 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils import (
+    direct_register_custom_op,
+    kv_cache_dtype_str_to_dtype,
+)
+from vllm.v1.kv_cache_interface import (
+    FullAttentionSpec,
+    KVCacheSpec,
+    MLAAttentionSpec,
+    SlidingWindowSpec,
+)

FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)
@@ -152,6 +162,7 @@ class Attention(nn.Module, AttentionLayerBase):
        else:
            sliding_window = None

+        vllm_config = get_current_vllm_config()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
@@ -160,6 +171,9 @@ class Attention(nn.Module, AttentionLayerBase):
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False
+        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
+            kv_cache_dtype, vllm_config.model_config
+        )
        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, (
@@ -256,7 +270,7 @@ class Attention(nn.Module, AttentionLayerBase):
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer
-        compilation_config = get_current_vllm_config().compilation_config
+        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
@@ -276,9 +290,7 @@ class Attention(nn.Module, AttentionLayerBase):
        # this variable will not be accessed if use_direct_call is True
        self.kv_cache = [
            torch.tensor([])
-            for _ in range(
-                get_current_vllm_config().parallel_config.pipeline_parallel_size
-            )
+            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]

        # Initialize q/k/v range constants.
@@ -394,6 +406,30 @@ class Attention(nn.Module, AttentionLayerBase):
    def get_attn_backend(self) -> type[AttentionBackend]:
        return self.attn_backend

+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        # Block size may get updated after model loading, refresh it
+        block_size = vllm_config.cache_config.block_size
+        # Should not be called for enc-dec or encoder-only attention.
+        assert self.attn_type == AttentionType.DECODER
+        if self.sliding_window is not None:
+            assert not vllm_config.model_config.use_mla, (
+                "MLA is not supported for slidingwindow"
+            )
+            return SlidingWindowSpec(
+                block_size=block_size,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                dtype=self.kv_cache_torch_dtype,
+                sliding_window=self.sliding_window,
+            )
+        else:
+            return FullAttentionSpec(
+                block_size=block_size,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                dtype=self.kv_cache_torch_dtype,
+            )
+

class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT."""
@@ -749,6 +785,18 @@ class MLAAttention(nn.Module, AttentionLayerBase):
    def get_attn_backend(self) -> type[AttentionBackend]:
        return self.attn_backend

+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        kv_cache_dtype = kv_cache_dtype_str_to_dtype(
+            self.kv_cache_dtype, vllm_config.model_config
+        )
+        return MLAAttentionSpec(
+            block_size=vllm_config.cache_config.block_size,
+            num_kv_heads=1,
+            head_size=self.head_size,
+            dtype=kv_cache_dtype,
+            cache_dtype_str=vllm_config.cache_config.cache_dtype,
+        )
+

def wait_for_kv_layer_from_connector(layer_name: str):
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
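
Note (illustration, not part of the diff): a decoder `Attention` layer now answers for itself with either a `FullAttentionSpec` or a `SlidingWindowSpec`. The snippet below only constructs the two variants with made-up sizes; the class names and keyword arguments are the ones used in the hunk above, everything else is illustrative.

```python
# Illustration only: arbitrary sizes, spec classes as used in this PR.
import torch

from vllm.v1.kv_cache_interface import FullAttentionSpec, SlidingWindowSpec

full_spec = FullAttentionSpec(
    block_size=16, num_kv_heads=8, head_size=128, dtype=torch.bfloat16
)
swa_spec = SlidingWindowSpec(
    block_size=16,
    num_kv_heads=8,
    head_size=128,
    dtype=torch.bfloat16,
    sliding_window=4096,
)
print(type(full_spec).__name__, type(swa_spec).__name__)
```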

View File

@@ -9,6 +9,7 @@ from vllm import envs
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
+from vllm.config.vllm import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport,
@@ -16,6 +17,7 @@ from vllm.v1.attention.backends.utils import (
    make_local_attention_virtual_batches,
    subclass_attention_backend,
)
+from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec, KVCacheSpec

from ..layer import Attention
@@ -67,6 +69,7 @@ class ChunkedLocalAttention(Attention):
        kv_sharing_target_layer_name: str | None = None,
        prefix: str = "",
    ):
+        self.attention_chunk_size = attention_chunk_size
        dtype = torch.get_default_dtype()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
@@ -99,3 +102,13 @@ class ChunkedLocalAttention(Attention):
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            attn_backend=attn_backend,
        )
+
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        assert self.attention_chunk_size
+        return ChunkedLocalAttentionSpec(
+            block_size=vllm_config.cache_config.block_size,
+            num_kv_heads=self.num_kv_heads,
+            head_size=self.head_size,
+            dtype=self.kv_cache_torch_dtype,
+            attention_chunk_size=self.attention_chunk_size,
+        )

View File

@@ -21,7 +21,7 @@ from vllm.v1.attention.backends.utils import (
    CommonAttentionMetadata,
    subclass_attention_backend,
)
-from vllm.v1.kv_cache_interface import CrossAttentionSpec
+from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec

logger = init_logger(__name__)
@@ -174,3 +174,11 @@ class CrossAttention(Attention):
            attn_type=AttentionType.ENCODER_DECODER,
            **kwargs,
        )
+
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        return CrossAttentionSpec(
+            block_size=vllm_config.cache_config.block_size,
+            num_kv_heads=self.num_kv_heads,
+            head_size=self.head_size,
+            dtype=self.kv_cache_torch_dtype,
+        )

View File

@@ -14,10 +14,12 @@ from vllm.attention.backends.abstract import (
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
+from vllm.config.vllm import VllmConfig
from vllm.v1.attention.backends.utils import (
    CommonAttentionMetadata,
    subclass_attention_backend,
)
+from vllm.v1.kv_cache_interface import KVCacheSpec


@functools.lru_cache
@@ -98,3 +100,7 @@ class EncoderOnlyAttention(Attention):
            attn_type=AttentionType.ENCODER_ONLY,
            **kwargs,
        )
+
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        # Does not need KV cache
+        return None

View File

@@ -5,6 +5,9 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

+from vllm.config import VllmConfig
+from vllm.v1.kv_cache_interface import KVCacheSpec
+
if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend
@@ -22,3 +25,11 @@ class AttentionLayerBase(ABC):
    def get_attn_backend(self) -> type["AttentionBackend"]:
        """Get the attention backend class for this layer."""
        pass
+
+    @abstractmethod
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
+        """
+        Get the KV cache spec for this layer.
+        May be None if the layer does not need KV cache.
+        """
+        pass
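
Note (not part of the diff): the essence of the new contract, reduced to a self-contained toy so it runs without vLLM installed. All names below are stand-ins rather than vLLM classes; only the idea mirrors the abstract method added above — every layer reports its own cache spec, and `None` means the layer needs no cache.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class ToySpec:
    """Stand-in for KVCacheSpec: just enough to show the flow."""
    block_size: int
    num_kv_heads: int


class ToyLayerBase(ABC):
    @abstractmethod
    def get_kv_cache_spec(self, block_size: int) -> ToySpec | None:
        """Return this layer's cache description, or None if no cache is needed."""


class ToyDecoderLayer(ToyLayerBase):
    def __init__(self, num_kv_heads: int) -> None:
        self.num_kv_heads = num_kv_heads

    def get_kv_cache_spec(self, block_size: int) -> ToySpec | None:
        return ToySpec(block_size=block_size, num_kv_heads=self.num_kv_heads)


class ToyEncoderOnlyLayer(ToyLayerBase):
    def get_kv_cache_spec(self, block_size: int) -> ToySpec | None:
        return None  # encoder-only attention keeps no KV cache


layers: dict[str, ToyLayerBase] = {
    "model.layers.0.attn": ToyDecoderLayer(num_kv_heads=8),
    "encoder.layers.0.attn": ToyEncoderOnlyLayer(),
}
# The runner-side collection collapses to a single comprehension.
specs = {
    name: spec
    for name, layer in layers.items()
    if (spec := layer.get_kv_cache_spec(block_size=16)) is not None
}
assert list(specs) == ["model.layers.0.attn"]
```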

View File

@@ -6,7 +6,9 @@ from typing import TYPE_CHECKING

import torch

+from vllm.config import VllmConfig
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend
@@ -40,3 +42,30 @@ class MambaBase(AttentionLayerBase):
    def get_attn_backend(self) -> type["AttentionBackend"]:
        """Get the attention backend class for this Mamba layer."""
        pass
+
+    @abstractmethod
+    def get_state_dtype(self) -> tuple[torch.dtype, ...]:
+        pass
+
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
+        if (
+            vllm_config.speculative_config is not None
+            and vllm_config.model_config.hf_config.model_type not in ["qwen3_next"]
+        ):
+            raise NotImplementedError(
+                "Mamba with speculative decoding is not supported yet."
+            )
+        mamba_block_size = vllm_config.cache_config.mamba_block_size
+        page_size_padded = vllm_config.cache_config.mamba_page_size_padded
+        return MambaSpec(
+            shapes=self.get_state_shape(),
+            dtypes=self.get_state_dtype(),
+            block_size=mamba_block_size,
+            page_size_padded=page_size_padded,
+            mamba_type=self.mamba_type,
+            num_speculative_blocks=(
+                vllm_config.speculative_config.num_speculative_tokens
+                if vllm_config.speculative_config
+                else 0
+            ),
+        )

View File

@@ -481,7 +481,7 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

-    def get_kv_cache_spec(self) -> KVCacheSpec:
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        return MLAAttentionSpec(  # Only has one vector instead of K + V
            block_size=self.cache_config.block_size,
            num_kv_heads=1,

View File

@@ -137,6 +137,15 @@ def set_default_torch_num_threads(num_threads: int):
        torch.set_num_threads(old_num_threads)


+def kv_cache_dtype_str_to_dtype(
+    kv_cache_dtype: str, model_config: ModelConfig
+) -> torch.dtype:
+    if kv_cache_dtype == "auto":
+        # Model config may not be specified for unit tests, default to float16
+        return model_config.dtype if model_config else torch.half
+    return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
+
+
T = TypeVar("T")
U = TypeVar("U")
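
Note (usage sketch, not from the diff): how the new helper is expected to behave, per the implementation above. The stub config is a stand-in for `ModelConfig` (only `.dtype` is read on the `"auto"` path), and `"fp8"` is assumed to be a key of `STR_DTYPE_TO_TORCH_DTYPE`.

```python
import torch

from vllm.utils import kv_cache_dtype_str_to_dtype


class _StubModelConfig:
    # Only the .dtype attribute is consulted when kv_cache_dtype == "auto".
    dtype = torch.bfloat16


# "auto" defers to the model dtype, or float16 when no config is given.
assert kv_cache_dtype_str_to_dtype("auto", _StubModelConfig()) is torch.bfloat16
assert kv_cache_dtype_str_to_dtype("auto", None) is torch.half
# Any other string is looked up in the STR_DTYPE_TO_TORCH_DTYPE table.
print(kv_cache_dtype_str_to_dtype("fp8", _StubModelConfig()))
```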

View File

@@ -948,7 +948,7 @@ class EagleProposer:
            indexer_layers[first_layer]
            .get_attn_backend()
            .get_builder_cls()(
-                indexer_layers[first_layer].get_kv_cache_spec(),
+                indexer_layers[first_layer].get_kv_cache_spec(self.vllm_config),
                self.indexer_layer_names,
                self.vllm_config,
                self.device,

View File

@@ -19,8 +19,6 @@ from tqdm import tqdm
import vllm.envs as envs
from vllm.attention import Attention, AttentionType
from vllm.attention.backends.abstract import AttentionBackend, MultipleOf
-from vllm.attention.layer import MLAAttention
-from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
@@ -44,10 +42,8 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
-from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.interfaces import (
    SupportsMultiModal,
    is_mixture_of_experts,
@@ -73,11 +69,11 @@ from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (
-    STR_DTYPE_TO_TORCH_DTYPE,
    cdiv,
    check_use_alibi,
    get_dtype_size,
    is_pin_memory_available,
+    kv_cache_dtype_str_to_dtype,
    length_from_prompt_token_ids_or_embeds,
    round_up,
    supports_dynamo,
@@ -106,7 +102,6 @@ from vllm.v1.kv_cache_interface import (
    KVCacheGroupSpec,
    KVCacheSpec,
    MambaSpec,
-    MLAAttentionSpec,
    SlidingWindowSpec,
    UniformTypeKVCacheSpecs,
)
@@ -239,10 +234,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        self.device = device
        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
-        if cache_config.cache_dtype == "auto":
-            self.kv_cache_dtype = self.dtype
-        else:
-            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
+        self.kv_cache_dtype = kv_cache_dtype_str_to_dtype(
+            cache_config.cache_dtype, self.model_config
+        )

        self.is_pooling_model = model_config.runner_type == "pooling"
        self.enable_prompt_embeds = model_config.enable_prompt_embeds
@@ -4577,16 +4571,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        format. Layers that do not need KV cache are not included.
        """

-        block_size = self.vllm_config.cache_config.block_size
-        use_mla = self.vllm_config.model_config.use_mla
-        cache_dtype_str = self.vllm_config.cache_config.cache_dtype
        kv_cache_spec: dict[str, KVCacheSpec] = {}
        attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
        for layer_name, attn_module in attn_layers.items():
-            if isinstance(attn_module, Attention):
-                if (
-                    kv_tgt_layer := attn_module.kv_sharing_target_layer_name
-                ) is not None:
+            if isinstance(attn_module, Attention) and (
+                kv_tgt_layer := attn_module.kv_sharing_target_layer_name
+            ):
                # The layer doesn't need its own KV cache and will use that of
                # the target layer. We skip creating a KVCacheSpec for it, so
                # that KV cache management logic will act as this layer does
                # not exist, and doesn't allocate KV cache for the layer, allow
@@ -4596,90 +4586,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                # or enable more requests to be processed simultaneously.
                self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
                continue
-                # TODO(lucas): move the attention specs into the model layers like
-                # the attention backends
-                if attn_module.attn_type == AttentionType.DECODER:
-                    if attn_module.sliding_window is not None:
-                        assert not use_mla, "MLA is not supported for slidingwindow"
-                        kv_cache_spec[layer_name] = SlidingWindowSpec(
-                            block_size=block_size,
-                            num_kv_heads=attn_module.num_kv_heads,
-                            head_size=attn_module.head_size,
-                            dtype=self.kv_cache_dtype,
-                            sliding_window=attn_module.sliding_window,
-                        )
-                    elif self.attention_chunk_size is not None and isinstance(
-                        attn_module, ChunkedLocalAttention
-                    ):
-                        kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
-                            block_size=block_size,
-                            num_kv_heads=attn_module.num_kv_heads,
-                            head_size=attn_module.head_size,
-                            dtype=self.kv_cache_dtype,
-                            attention_chunk_size=self.attention_chunk_size,
-                        )
-                    else:
-                        kv_cache_spec[layer_name] = FullAttentionSpec(
-                            block_size=block_size,
-                            num_kv_heads=attn_module.num_kv_heads,
-                            head_size=attn_module.head_size,
-                            dtype=self.kv_cache_dtype,
-                        )
-                elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
-                    kv_cache_spec[layer_name] = CrossAttentionSpec(
-                        block_size=block_size,
-                        num_kv_heads=attn_module.num_kv_heads,
-                        head_size=attn_module.head_size,
-                        dtype=self.kv_cache_dtype,
-                    )
-                elif attn_module.attn_type in (
-                    AttentionType.ENCODER,
-                    AttentionType.ENCODER_ONLY,
-                ):
-                    # encoder-only attention does not need KV cache.
-                    continue
-                else:
-                    raise ValueError(f"Unknown attention type: {attn_module.attn_type}")
-            elif isinstance(attn_module, MLAAttention):
-                kv_cache_spec[layer_name] = MLAAttentionSpec(
-                    block_size=block_size,
-                    num_kv_heads=1,
-                    head_size=attn_module.head_size,
-                    dtype=self.kv_cache_dtype,
-                    cache_dtype_str=cache_dtype_str,
-                )
-            elif isinstance(attn_module, MambaBase):
-                if (
-                    self.vllm_config.speculative_config is not None
-                    and self.vllm_config.model_config.hf_config.model_type
-                    not in ["qwen3_next"]
-                ):
-                    raise NotImplementedError(
-                        "Mamba with speculative decoding is not supported yet."
-                    )
-                mamba_block_size = self.vllm_config.cache_config.mamba_block_size
-                page_size_padded = self.vllm_config.cache_config.mamba_page_size_padded
-                kv_cache_spec[layer_name] = MambaSpec(
-                    shapes=attn_module.get_state_shape(),
-                    dtypes=attn_module.get_state_dtype(),
-                    block_size=mamba_block_size,
-                    page_size_padded=page_size_padded,
-                    mamba_type=attn_module.mamba_type,
-                    num_speculative_blocks=(
-                        self.speculative_config.num_speculative_tokens
-                        if self.speculative_config
-                        else 0
-                    ),
-                )
-        ds_indexer_layers = get_layers_from_vllm_config(
-            self.vllm_config, DeepseekV32IndexerCache
-        )
-        for layer_name, ds_indexer_module in ds_indexer_layers.items():
-            kv_cache_spec[layer_name] = ds_indexer_module.get_kv_cache_spec()
+            # Skip modules that don't need KV cache (eg encoder-only attention)
+            if spec := attn_module.get_kv_cache_spec(self.vllm_config):
+                kv_cache_spec[layer_name] = spec

        return kv_cache_spec
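
Note (reviewer sketch, not the actual `GPUModelRunner` code): with every `AttentionLayerBase` subclass owning its spec, the runner-side collection boils down to asking each layer in turn. A minimal free-standing version, leaving out the KV-sharing short-circuit kept in the real loop above:

```python
from vllm.config import VllmConfig
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.kv_cache_interface import KVCacheSpec


def collect_kv_cache_specs(
    layers: dict[str, AttentionLayerBase], vllm_config: VllmConfig
) -> dict[str, KVCacheSpec]:
    """Ask each layer for its own spec; layers needing no cache return None."""
    specs: dict[str, KVCacheSpec] = {}
    for name, layer in layers.items():
        if spec := layer.get_kv_cache_spec(vllm_config):
            specs[name] = spec
    return specs
```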