diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index e288770f2fcb..f23fc9ba5288 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -16,6 +16,7 @@ from vllm.attention.backends.registry import _Backend, backend_name_to_enum
 from vllm.attention.selector import get_attn_backend
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
 from vllm.config import CacheConfig, get_current_vllm_config
+from vllm.config.vllm import VllmConfig
 from vllm.distributed.kv_transfer import (
     get_kv_transfer_group,
     has_kv_transfer_group,
@@ -34,7 +35,16 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils import (
+    direct_register_custom_op,
+    kv_cache_dtype_str_to_dtype,
+)
+from vllm.v1.kv_cache_interface import (
+    FullAttentionSpec,
+    KVCacheSpec,
+    MLAAttentionSpec,
+    SlidingWindowSpec,
+)

 FP8_DTYPE = current_platform.fp8_dtype()
 logger = init_logger(__name__)
@@ -152,6 +162,7 @@ class Attention(nn.Module, AttentionLayerBase):
         else:
             sliding_window = None

+        vllm_config = get_current_vllm_config()
         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
             block_size = cache_config.block_size
@@ -160,6 +171,9 @@ class Attention(nn.Module, AttentionLayerBase):
             kv_cache_dtype = "auto"
             block_size = 16
             calculate_kv_scales = False
+        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
+            kv_cache_dtype, vllm_config.model_config
+        )
         if num_kv_heads is None:
             num_kv_heads = num_heads
         assert num_heads % num_kv_heads == 0, (
@@ -256,7 +270,7 @@ class Attention(nn.Module, AttentionLayerBase):
         self.use_direct_call = not current_platform.opaque_attention_op()

         self.use_output = self.attn_backend.accept_output_buffer
-        compilation_config = get_current_vllm_config().compilation_config
+        compilation_config = vllm_config.compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self
@@ -276,9 +290,7 @@ class Attention(nn.Module, AttentionLayerBase):
         # this variable will not be accessed if use_direct_call is True
         self.kv_cache = [
             torch.tensor([])
-            for _ in range(
-                get_current_vllm_config().parallel_config.pipeline_parallel_size
-            )
+            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
         ]

         # Initialize q/k/v range constants.
@@ -394,6 +406,30 @@ class Attention(nn.Module, AttentionLayerBase):
     def get_attn_backend(self) -> type[AttentionBackend]:
         return self.attn_backend

+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        # Block size may get updated after model loading, refresh it
+        block_size = vllm_config.cache_config.block_size
+        # Should not be called for enc-dec or encoder-only attention.
+        assert self.attn_type == AttentionType.DECODER
+        if self.sliding_window is not None:
+            assert not vllm_config.model_config.use_mla, (
+                "MLA is not supported for sliding window"
+            )
+            return SlidingWindowSpec(
+                block_size=block_size,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                dtype=self.kv_cache_torch_dtype,
+                sliding_window=self.sliding_window,
+            )
+        else:
+            return FullAttentionSpec(
+                block_size=block_size,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                dtype=self.kv_cache_torch_dtype,
+            )
+

 class MultiHeadAttention(nn.Module):
     """Multi-headed attention without any cache, used for ViT."""
@@ -749,6 +785,18 @@ class MLAAttention(nn.Module, AttentionLayerBase):
     def get_attn_backend(self) -> type[AttentionBackend]:
         return self.attn_backend

+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        kv_cache_dtype = kv_cache_dtype_str_to_dtype(
+            self.kv_cache_dtype, vllm_config.model_config
+        )
+        return MLAAttentionSpec(
+            block_size=vllm_config.cache_config.block_size,
+            num_kv_heads=1,
+            head_size=self.head_size,
+            dtype=kv_cache_dtype,
+            cache_dtype_str=vllm_config.cache_config.cache_dtype,
+        )
+

 def wait_for_kv_layer_from_connector(layer_name: str):
     if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
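`Attention.get_kv_cache_spec()` now reproduces, per layer, the dispatch the runner used to do globally: sliding-window decoder layers report a sliding-window spec, everything else a full-attention spec, both built from the `kv_cache_torch_dtype` cached in `__init__`. Below is a standalone sketch of that dispatch using stand-in dataclasses rather than the real `vllm.v1.kv_cache_interface` types; it is illustrative only, not part of the diff.

```python
# Stand-in sketch (plain Python + torch), not vLLM code: mirrors the branch
# inside the new Attention.get_kv_cache_spec() above.
from dataclasses import dataclass

import torch


@dataclass
class _FullSpec:  # stand-in for FullAttentionSpec
    block_size: int
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype


@dataclass
class _SlidingWindowSpec(_FullSpec):  # stand-in for SlidingWindowSpec
    sliding_window: int = 0


def spec_for_decoder_layer(block_size, num_kv_heads, head_size, dtype,
                           sliding_window=None):
    # Sliding-window decoder layers advertise a sliding-window spec;
    # everything else falls back to a full-attention spec.
    if sliding_window is not None:
        return _SlidingWindowSpec(block_size, num_kv_heads, head_size, dtype,
                                  sliding_window=sliding_window)
    return _FullSpec(block_size, num_kv_heads, head_size, dtype)


print(spec_for_decoder_layer(16, 8, 128, torch.float16, sliding_window=4096))
print(spec_for_decoder_layer(16, 8, 128, torch.float16))
```

The MLA variant follows the same pattern but always reports a single KV head and carries the cache dtype string, matching the `MLAAttentionSpec` construction above.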
diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py
index d1f9a0437aa6..18422404d08f 100644
--- a/vllm/attention/layers/chunked_local_attention.py
+++ b/vllm/attention/layers/chunked_local_attention.py
@@ -9,6 +9,7 @@ from vllm import envs
 from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
 from vllm.attention.selector import get_attn_backend
 from vllm.config import CacheConfig
+from vllm.config.vllm import VllmConfig
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport,
@@ -16,6 +17,7 @@ from vllm.v1.attention.backends.utils import (
     make_local_attention_virtual_batches,
     subclass_attention_backend,
 )
+from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec, KVCacheSpec

 from ..layer import Attention

@@ -67,6 +69,7 @@ class ChunkedLocalAttention(Attention):
         kv_sharing_target_layer_name: str | None = None,
         prefix: str = "",
     ):
+        self.attention_chunk_size = attention_chunk_size
         dtype = torch.get_default_dtype()
         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
@@ -99,3 +102,13 @@ class ChunkedLocalAttention(Attention):
             kv_sharing_target_layer_name=kv_sharing_target_layer_name,
             attn_backend=attn_backend,
         )
+
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        assert self.attention_chunk_size
+        return ChunkedLocalAttentionSpec(
+            block_size=vllm_config.cache_config.block_size,
+            num_kv_heads=self.num_kv_heads,
+            head_size=self.head_size,
+            dtype=self.kv_cache_torch_dtype,
+            attention_chunk_size=self.attention_chunk_size,
+        )
diff --git a/vllm/attention/layers/cross_attention.py b/vllm/attention/layers/cross_attention.py
index b07ffcc5ffeb..a40a66308a66 100644
--- a/vllm/attention/layers/cross_attention.py
+++ b/vllm/attention/layers/cross_attention.py
@@ -21,7 +21,7 @@ from vllm.v1.attention.backends.utils import (
     CommonAttentionMetadata,
     subclass_attention_backend,
 )
-from vllm.v1.kv_cache_interface import CrossAttentionSpec
+from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec

 logger = init_logger(__name__)

@@ -174,3 +174,11 @@ class CrossAttention(Attention):
             attn_type=AttentionType.ENCODER_DECODER,
             **kwargs,
         )
+
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        return CrossAttentionSpec(
+            block_size=vllm_config.cache_config.block_size,
+            num_kv_heads=self.num_kv_heads,
+            head_size=self.head_size,
+            dtype=self.kv_cache_torch_dtype,
+        )
diff --git a/vllm/attention/layers/encoder_only_attention.py b/vllm/attention/layers/encoder_only_attention.py
index 1a47135d03a7..8d2a046757fe 100644
--- a/vllm/attention/layers/encoder_only_attention.py
+++ b/vllm/attention/layers/encoder_only_attention.py
@@ -14,10 +14,12 @@ from vllm.attention.backends.abstract import (
 from vllm.attention.layer import Attention
 from vllm.attention.selector import get_attn_backend
 from vllm.config import CacheConfig
+from vllm.config.vllm import VllmConfig
 from vllm.v1.attention.backends.utils import (
     CommonAttentionMetadata,
     subclass_attention_backend,
 )
+from vllm.v1.kv_cache_interface import KVCacheSpec


 @functools.lru_cache
@@ -98,3 +100,7 @@ class EncoderOnlyAttention(Attention):
             attn_type=AttentionType.ENCODER_ONLY,
             **kwargs,
         )
+
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
+        # Encoder-only attention does not need a KV cache.
+        return None
diff --git a/vllm/model_executor/layers/attention_layer_base.py b/vllm/model_executor/layers/attention_layer_base.py
index fa74c20840da..ffbef470b186 100644
--- a/vllm/model_executor/layers/attention_layer_base.py
+++ b/vllm/model_executor/layers/attention_layer_base.py
@@ -5,6 +5,9 @@ from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING

+from vllm.config import VllmConfig
+from vllm.v1.kv_cache_interface import KVCacheSpec
+
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend


@@ -22,3 +25,11 @@ class AttentionLayerBase(ABC):
     def get_attn_backend(self) -> type["AttentionBackend"]:
         """Get the attention backend class for this layer."""
         pass
+
+    @abstractmethod
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
+        """
+        Get the KV cache spec for this layer.
+        May be None if the layer does not need KV cache.
+        """
+        pass
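With `get_kv_cache_spec()` abstract on `AttentionLayerBase`, every layer registered in the static forward context has to answer the question itself; returning `None` opts the layer out of KV-cache allocation, as `EncoderOnlyAttention` does above. A minimal sketch of a custom cache-less layer against this interface follows; the class name is hypothetical and backend selection is stubbed out, assuming a vLLM install that carries this patch.

```python
# Hypothetical example layer, not part of the diff: satisfies the new
# AttentionLayerBase contract by declaring that it needs no KV cache.
from vllm.config import VllmConfig
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.kv_cache_interface import KVCacheSpec


class NoKVCacheLayer(AttentionLayerBase):
    def get_attn_backend(self):
        # Backend selection is out of scope for this sketch.
        raise NotImplementedError

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
        # None tells GPUModelRunner.get_kv_cache_spec() (see the
        # gpu_model_runner.py hunk below) to skip allocation for this layer.
        return None
```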
diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py
index 6da62b5426bb..e68b09b4d81f 100644
--- a/vllm/model_executor/layers/mamba/abstract.py
+++ b/vllm/model_executor/layers/mamba/abstract.py
@@ -6,7 +6,9 @@ from typing import TYPE_CHECKING

 import torch

+from vllm.config import VllmConfig
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec

 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend
@@ -40,3 +42,30 @@ class MambaBase(AttentionLayerBase):
     def get_attn_backend(self) -> type["AttentionBackend"]:
         """Get the attention backend class for this Mamba layer."""
         pass
+
+    @abstractmethod
+    def get_state_dtype(self) -> tuple[torch.dtype, ...]:
+        pass
+
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
+        if (
+            vllm_config.speculative_config is not None
+            and vllm_config.model_config.hf_config.model_type not in ["qwen3_next"]
+        ):
+            raise NotImplementedError(
+                "Mamba with speculative decoding is not supported yet."
+            )
+        mamba_block_size = vllm_config.cache_config.mamba_block_size
+        page_size_padded = vllm_config.cache_config.mamba_page_size_padded
+        return MambaSpec(
+            shapes=self.get_state_shape(),
+            dtypes=self.get_state_dtype(),
+            block_size=mamba_block_size,
+            page_size_padded=page_size_padded,
+            mamba_type=self.mamba_type,
+            num_speculative_blocks=(
+                vllm_config.speculative_config.num_speculative_tokens
+                if vllm_config.speculative_config
+                else 0
+            ),
+        )
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 5b55b685dacf..8ad85357aee5 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -481,7 +481,7 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self

-    def get_kv_cache_spec(self) -> KVCacheSpec:
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
         return MLAAttentionSpec(  # Only has one vector instead of K + V
             block_size=self.cache_config.block_size,
             num_kv_heads=1,
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index f5569154be7f..dd83f9fc9655 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -137,6 +137,15 @@ def set_default_torch_num_threads(num_threads: int):
         torch.set_num_threads(old_num_threads)


+def kv_cache_dtype_str_to_dtype(
+    kv_cache_dtype: str, model_config: ModelConfig
+) -> torch.dtype:
+    if kv_cache_dtype == "auto":
+        # Model config may not be specified for unit tests, default to float16
+        return model_config.dtype if model_config else torch.half
+    return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
+
+
 T = TypeVar("T")
 U = TypeVar("U")

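For reference, the helper's behavior as implemented above: `"auto"` defers to `model_config.dtype`, falls back to `float16` when no `ModelConfig` is supplied (the unit-test case), and any explicit string is resolved through `STR_DTYPE_TO_TORCH_DTYPE`. A quick usage sketch follows, assuming the mapping exposes the usual `"bfloat16"` key and a vLLM install carrying this patch.

```python
# Usage sketch for the new helper (not part of the diff).
import torch

from vllm.utils import kv_cache_dtype_str_to_dtype

# No ModelConfig available (e.g. in a unit test): "auto" falls back to float16.
assert kv_cache_dtype_str_to_dtype("auto", None) == torch.half
# Explicit strings are looked up in STR_DTYPE_TO_TORCH_DTYPE.
assert kv_cache_dtype_str_to_dtype("bfloat16", None) == torch.bfloat16
```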
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 60418e53c174..206b13ad5164 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -948,7 +948,7 @@ class EagleProposer:
             indexer_layers[first_layer]
             .get_attn_backend()
             .get_builder_cls()(
-                indexer_layers[first_layer].get_kv_cache_spec(),
+                indexer_layers[first_layer].get_kv_cache_spec(self.vllm_config),
                 self.indexer_layer_names,
                 self.vllm_config,
                 self.device,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index a9874b164f42..72fc00f6edcc 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -19,8 +19,6 @@ from tqdm import tqdm
 import vllm.envs as envs
 from vllm.attention import Attention, AttentionType
 from vllm.attention.backends.abstract import AttentionBackend, MultipleOf
-from vllm.attention.layer import MLAAttention
-from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.cuda_graph import CUDAGraphWrapper
 from vllm.compilation.monitor import set_cudagraph_capturing_enabled
@@ -44,10 +42,8 @@ from vllm.distributed.parallel_state import (
 from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
-from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
 from vllm.model_executor.models.interfaces import (
     SupportsMultiModal,
     is_mixture_of_experts,
@@ -73,11 +69,11 @@ from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import (
-    STR_DTYPE_TO_TORCH_DTYPE,
     cdiv,
     check_use_alibi,
     get_dtype_size,
     is_pin_memory_available,
+    kv_cache_dtype_str_to_dtype,
     length_from_prompt_token_ids_or_embeds,
     round_up,
     supports_dynamo,
@@ -106,7 +102,6 @@ from vllm.v1.kv_cache_interface import (
     KVCacheGroupSpec,
     KVCacheSpec,
     MambaSpec,
-    MLAAttentionSpec,
     SlidingWindowSpec,
     UniformTypeKVCacheSpecs,
 )
@@ -239,10 +234,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.device = device
         self.pin_memory = is_pin_memory_available()
         self.dtype = self.model_config.dtype
-        if cache_config.cache_dtype == "auto":
-            self.kv_cache_dtype = self.dtype
-        else:
-            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
+        self.kv_cache_dtype = kv_cache_dtype_str_to_dtype(
+            cache_config.cache_dtype, self.model_config
+        )

         self.is_pooling_model = model_config.runner_type == "pooling"
         self.enable_prompt_embeds = model_config.enable_prompt_embeds
@@ -4577,109 +4571,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             format. Layers that do not need KV cache are not included.
         """

-        block_size = self.vllm_config.cache_config.block_size
-        use_mla = self.vllm_config.model_config.use_mla
-        cache_dtype_str = self.vllm_config.cache_config.cache_dtype
         kv_cache_spec: dict[str, KVCacheSpec] = {}
         attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
         for layer_name, attn_module in attn_layers.items():
-            if isinstance(attn_module, Attention):
-                if (
-                    kv_tgt_layer := attn_module.kv_sharing_target_layer_name
-                ) is not None:
-                    # The layer doesn't need its own KV cache and will use that of
-                    # the target layer. We skip creating a KVCacheSpec for it, so
-                    # that KV cache management logic will act as this layer does
-                    # not exist, and doesn't allocate KV cache for the layer. This
-                    # enables the memory saving of cross-layer kv sharing, allowing
-                    # a given amount of memory to accommodate longer context lengths
-                    # or enable more requests to be processed simultaneously.
-                    self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
-                    continue
-
-                # TODO(lucas): move the attention specs into the model layers like
-                # the attention backends
-                if attn_module.attn_type == AttentionType.DECODER:
-                    if attn_module.sliding_window is not None:
-                        assert not use_mla, "MLA is not supported for slidingwindow"
-                        kv_cache_spec[layer_name] = SlidingWindowSpec(
-                            block_size=block_size,
-                            num_kv_heads=attn_module.num_kv_heads,
-                            head_size=attn_module.head_size,
-                            dtype=self.kv_cache_dtype,
-                            sliding_window=attn_module.sliding_window,
-                        )
-                    elif self.attention_chunk_size is not None and isinstance(
-                        attn_module, ChunkedLocalAttention
-                    ):
-                        kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
-                            block_size=block_size,
-                            num_kv_heads=attn_module.num_kv_heads,
-                            head_size=attn_module.head_size,
-                            dtype=self.kv_cache_dtype,
-                            attention_chunk_size=self.attention_chunk_size,
-                        )
-                    else:
-                        kv_cache_spec[layer_name] = FullAttentionSpec(
-                            block_size=block_size,
-                            num_kv_heads=attn_module.num_kv_heads,
-                            head_size=attn_module.head_size,
-                            dtype=self.kv_cache_dtype,
-                        )
-                elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
-                    kv_cache_spec[layer_name] = CrossAttentionSpec(
-                        block_size=block_size,
-                        num_kv_heads=attn_module.num_kv_heads,
-                        head_size=attn_module.head_size,
-                        dtype=self.kv_cache_dtype,
-                    )
-                elif attn_module.attn_type in (
-                    AttentionType.ENCODER,
-                    AttentionType.ENCODER_ONLY,
-                ):
-                    # encoder-only attention does not need KV cache.
-                    continue
-                else:
-                    raise ValueError(f"Unknown attention type: {attn_module.attn_type}")
-
-            elif isinstance(attn_module, MLAAttention):
-                kv_cache_spec[layer_name] = MLAAttentionSpec(
-                    block_size=block_size,
-                    num_kv_heads=1,
-                    head_size=attn_module.head_size,
-                    dtype=self.kv_cache_dtype,
-                    cache_dtype_str=cache_dtype_str,
-                )
-
-            elif isinstance(attn_module, MambaBase):
-                if (
-                    self.vllm_config.speculative_config is not None
-                    and self.vllm_config.model_config.hf_config.model_type
-                    not in ["qwen3_next"]
-                ):
-                    raise NotImplementedError(
-                        "Mamba with speculative decoding is not supported yet."
-                    )
-                mamba_block_size = self.vllm_config.cache_config.mamba_block_size
-                page_size_padded = self.vllm_config.cache_config.mamba_page_size_padded
-                kv_cache_spec[layer_name] = MambaSpec(
-                    shapes=attn_module.get_state_shape(),
-                    dtypes=attn_module.get_state_dtype(),
-                    block_size=mamba_block_size,
-                    page_size_padded=page_size_padded,
-                    mamba_type=attn_module.mamba_type,
-                    num_speculative_blocks=(
-                        self.speculative_config.num_speculative_tokens
-                        if self.speculative_config
-                        else 0
-                    ),
-                )
-
-        ds_indexer_layers = get_layers_from_vllm_config(
-            self.vllm_config, DeepseekV32IndexerCache
-        )
-        for layer_name, ds_indexer_module in ds_indexer_layers.items():
-            kv_cache_spec[layer_name] = ds_indexer_module.get_kv_cache_spec()
+            if isinstance(attn_module, Attention) and (
+                kv_tgt_layer := attn_module.kv_sharing_target_layer_name
+            ):
+                # The layer doesn't need its own KV cache and will use that of
+                # the target layer. We skip creating a KVCacheSpec for it, so
+                # that KV cache management logic will act as this layer does
+                # not exist, and doesn't allocate KV cache for the layer. This
+                # enables the memory saving of cross-layer kv sharing, allowing
+                # a given amount of memory to accommodate longer context lengths
+                # or enable more requests to be processed simultaneously.
+                self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
+                continue
+            # Skip modules that don't need KV cache (e.g. encoder-only attention)
+            if spec := attn_module.get_kv_cache_spec(self.vllm_config):
+                kv_cache_spec[layer_name] = spec

         return kv_cache_spec
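Taken together, the runner-side `get_kv_cache_spec()` shrinks to bookkeeping: record kv-sharing targets, ask each layer for its spec, and drop layers that return `None`. Below is a standalone mirror of that contract with toy stand-in layers (plain Python, no vLLM imports), purely to illustrate the control flow the new loop relies on.

```python
# Toy mirror of the new collection loop; the layer and spec classes are
# stand-ins, not vLLM types.
class _FullSpec:
    def __init__(self, block_size: int):
        self.block_size = block_size


class _DecoderLayer:
    kv_sharing_target_layer_name = None

    def get_kv_cache_spec(self, cfg):
        return _FullSpec(block_size=cfg["block_size"])


class _EncoderOnlyLayer:
    kv_sharing_target_layer_name = None

    def get_kv_cache_spec(self, cfg):
        return None  # no KV cache needed


layers = {
    "model.layers.0.attn": _DecoderLayer(),
    "encoder.layers.0.attn": _EncoderOnlyLayer(),
}
cfg = {"block_size": 16}
specs = {
    name: spec
    for name, layer in layers.items()
    if (spec := layer.get_kv_cache_spec(cfg))
}
# Only the decoder layer ends up with a KV cache spec.
assert list(specs) == ["model.layers.0.attn"]
```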