diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index a2b27ec678e76..eb044d50b5c33 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -17,9 +17,7 @@ from tqdm import tqdm
 from typing_extensions import TypeAlias
 
 import vllm.envs as envs
-from vllm.attention import Attention, AttentionType
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.cuda_graph import CUDAGraphWrapper
 from vllm.compilation.monitor import set_cudagraph_capturing_enabled
@@ -34,7 +32,6 @@ from vllm.forward_context import (BatchDescriptor, DPMetadata,
                                   set_forward_context)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
 from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
@@ -67,8 +64,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         CrossAttentionSpec,
                                         EncoderOnlyAttentionSpec,
                                         FullAttentionSpec, KVCacheConfig,
-                                        KVCacheSpec, MambaSpec,
-                                        SlidingWindowSpec)
+                                        KVCacheSpec, SlidingWindowSpec)
 # yapf: enable
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              DraftTokenIds, LogprobsLists, LogprobsTensors,
@@ -3530,105 +3526,6 @@ class GPUModelRunner(KVCacheInitializerMixin, LoRAModelRunnerMixin,
     def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
         return itertools.chain.from_iterable(self.attn_groups)
 
-    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
-        """
-        Generates the KVCacheSpec by parsing the kv cache format from each
-        Attention module in the static forward context.
-        Returns:
-            KVCacheSpec: A dictionary mapping layer names to their KV cache
-            format. Layers that do not need KV cache are not included.
-        """
-
-        block_size = self.vllm_config.cache_config.block_size
-        use_mla = self.vllm_config.model_config.use_mla
-        kv_cache_spec: dict[str, KVCacheSpec] = {}
-        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
-        for layer_name, attn_module in attn_layers.items():
-            if (kv_tgt_layer :=
-                    attn_module.kv_sharing_target_layer_name) is not None:
-                # The layer doesn't need its own KV cache and will use that of
-                # the target layer. We skip creating a KVCacheSpec for it, so
-                # that KV cache management logic will act as this layer does
-                # not exist, and doesn't allocate KV cache for the layer. This
-                # enables the memory saving of cross-layer kv sharing, allowing
-                # a given amount of memory to accommodate longer context lengths
-                # or enable more requests to be processed simultaneously.
-                self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
-                continue
-
-            # TODO(lucas): move the attention specs into the model layers like
-            # the attention backends
-            if attn_module.attn_type == AttentionType.DECODER:
-                if attn_module.sliding_window is not None:
-                    kv_cache_spec[layer_name] = SlidingWindowSpec(
-                        block_size=block_size,
-                        num_kv_heads=attn_module.num_kv_heads,
-                        head_size=attn_module.head_size,
-                        dtype=self.kv_cache_dtype,
-                        sliding_window=attn_module.sliding_window,
-                        use_mla=use_mla)
-                elif self.attention_chunk_size is not None \
-                        and isinstance(attn_module, ChunkedLocalAttention):
-                    kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
-                        block_size=block_size,
-                        num_kv_heads=attn_module.num_kv_heads,
-                        head_size=attn_module.head_size,
-                        dtype=self.kv_cache_dtype,
-                        attention_chunk_size=self.attention_chunk_size,
-                        use_mla=use_mla)
-                else:
-                    kv_cache_spec[layer_name] = FullAttentionSpec(
-                        block_size=block_size,
-                        num_kv_heads=attn_module.num_kv_heads,
-                        head_size=attn_module.head_size,
-                        dtype=self.kv_cache_dtype,
-                        use_mla=use_mla)
-            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
-                kv_cache_spec[layer_name] = CrossAttentionSpec(
-                    block_size=block_size,
-                    num_kv_heads=attn_module.num_kv_heads,
-                    head_size=attn_module.head_size,
-                    dtype=self.kv_cache_dtype,
-                    use_mla=use_mla)
-            elif attn_module.attn_type in (AttentionType.ENCODER,
-                                           AttentionType.ENCODER_ONLY):
-                # encoder-only attention does not need KV cache.
-                continue
-            else:
-                raise ValueError(
-                    f"Unknown attention type: {attn_module.attn_type}")
-
-        mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
-        if len(mamba_layers) > 0:
-            if (self.vllm_config.speculative_config is not None
-                    and self.vllm_config.model_config.hf_config.model_type
-                    not in ["qwen3_next"]):
-                raise NotImplementedError(
-                    "Mamba with speculative decoding is not supported yet.")
-            if self.vllm_config.cache_config.enable_prefix_caching:
-                raise NotImplementedError(
-                    "Prefix caching is not supported for Mamba yet.")
-            max_model_len = self.vllm_config.model_config.max_model_len
-
-            page_size_padded = (
-                self.vllm_config.cache_config.mamba_page_size_padded)
-
-            # Set block_size to max_model_len, so that mamba model will always
-            # have only one block in the KV cache.
-            for layer_name, mamba_module in mamba_layers.items():
-                kv_cache_spec[layer_name] = MambaSpec(
-                    shapes=mamba_module.get_state_shape(),
-                    dtypes=mamba_module.get_state_dtype(),
-                    block_size=max_model_len,
-                    page_size_padded=page_size_padded,
-                    mamba_type=mamba_module.mamba_type,
-                    num_speculative_blocks=(
-                        self.speculative_config.num_speculative_tokens
-                        if self.speculative_config else 0),
-                )
-
-        return kv_cache_spec
-
     def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
         # This is a short term mitigation for issue mentioned in
         # https://github.com/vllm-project/vllm/issues/22754.
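
The relocated method leans on the mixin's existing access pattern: `KVCacheInitializerMixin` holds no state of its own and reads runner attributes through `self._runner()`, which is typed as the `_KVCacheInitializerSelf` Protocol. Any attribute the moved code touches therefore has to be declared on that Protocol, which is why the next file's diff adds `attention_chunk_size` to it. Below is a minimal, self-contained sketch of that Protocol-backed mixin pattern; `_RunnerView`, `KVCacheSpecMixin`, and `ToyModelRunner` are made-up stand-ins, not vLLM classes.

```python
from __future__ import annotations

from typing import Protocol, cast


class _RunnerView(Protocol):
    """Stand-in for _KVCacheInitializerSelf: attributes the mixin expects
    the concrete runner class to provide."""
    block_size: int
    attention_chunk_size: int | None


class KVCacheSpecMixin:
    """Stand-in for KVCacheInitializerMixin: stateless, reads the runner."""

    def _runner(self) -> _RunnerView:
        # The mixin is always combined with a runner class, so `self` already
        # carries these attributes; the cast only narrows the static type.
        return cast(_RunnerView, self)

    def describe(self) -> str:
        runner = self._runner()
        chunked = runner.attention_chunk_size is not None
        return f"block_size={runner.block_size}, chunked_local={chunked}"


class ToyModelRunner(KVCacheSpecMixin):
    def __init__(self) -> None:
        self.block_size = 16
        self.attention_chunk_size = None


if __name__ == "__main__":
    print(ToyModelRunner().describe())  # block_size=16, chunked_local=False
```

The cast is purely for static typing: at runtime the mixin's methods simply read attributes that the concrete runner set in its own `__init__`, so forgetting to declare one on the Protocol breaks type checking rather than execution.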
diff --git a/vllm/v1/worker/kv_cache_initializer_mixin.py b/vllm/v1/worker/kv_cache_initializer_mixin.py
index d3860f8701f0f..9d34249d26a84 100644
--- a/vllm/v1/worker/kv_cache_initializer_mixin.py
+++ b/vllm/v1/worker/kv_cache_initializer_mixin.py
@@ -9,17 +9,24 @@ from typing import Any, Protocol, cast
 
 import torch
 
 from vllm.attention import Attention, AttentionType
+from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
 from vllm.config import get_layers_from_vllm_config
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.utils import get_dtype_size
+# yapf: disable
 from vllm.v1.kv_cache_interface import (AttentionSpec,
+                                        ChunkedLocalAttentionSpec,
+                                        CrossAttentionSpec,
                                         EncoderOnlyAttentionSpec,
-                                        KVCacheConfig, KVCacheGroupSpec,
-                                        KVCacheSpec, MambaSpec)
+                                        FullAttentionSpec, KVCacheConfig,
+                                        KVCacheGroupSpec, KVCacheSpec,
+                                        MambaSpec, SlidingWindowSpec)
+# yapf: enable
 from vllm.v1.spec_decode.eagle import EagleProposer
 from vllm.v1.worker.gpu_input_batch import InputBatch
@@ -41,6 +48,7 @@ class _KVCacheInitializerSelf(Protocol):
     is_pooling_model: bool
     shared_kv_cache_layers: dict[str, str]
     kv_sharing_fast_prefill_eligible_layers: set[str]
+    attention_chunk_size: int
     runner_only_attn_layers: set[str]
     kv_cache_dtype: torch.dtype
     kv_cache_config: KVCacheConfig
@@ -373,3 +381,104 @@ class KVCacheInitializerMixin:
                 " the softmax lse for decode, but the impl "
                 f"{layer_impl.__class__.__name__} "
                 "does not return the softmax lse for decode.")
+
+    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
+        """
+        Generates the KVCacheSpec by parsing the kv cache format from each
+        Attention module in the static forward context.
+        Returns:
+            KVCacheSpec: A dictionary mapping layer names to their KV cache
+            format. Layers that do not need KV cache are not included.
+        """
+        runner = self._runner()
+        block_size = runner.vllm_config.cache_config.block_size
+        use_mla = runner.vllm_config.model_config.use_mla
+        kv_cache_spec: dict[str, KVCacheSpec] = {}
+        attn_layers = get_layers_from_vllm_config(runner.vllm_config,
+                                                  Attention)
+        for layer_name, attn_module in attn_layers.items():
+            if (kv_tgt_layer :=
+                    attn_module.kv_sharing_target_layer_name) is not None:
+                # The layer doesn't need its own KV cache and will use that of
+                # the target layer. We skip creating a KVCacheSpec for it, so
+                # that KV cache management logic will act as this layer does
+                # not exist, and doesn't allocate KV cache for the layer. This
+                # enables the memory saving of cross-layer kv sharing, allowing
+                # a given amount of memory to accommodate longer context lengths
+                # or enable more requests to be processed simultaneously.
+                runner.shared_kv_cache_layers[layer_name] = kv_tgt_layer
+                continue
+
+            # TODO(lucas): move the attention specs into the model layers like
+            # the attention backends
+            if attn_module.attn_type == AttentionType.DECODER:
+                if attn_module.sliding_window is not None:
+                    kv_cache_spec[layer_name] = SlidingWindowSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=runner.kv_cache_dtype,
+                        sliding_window=attn_module.sliding_window,
+                        use_mla=use_mla)
+                elif runner.attention_chunk_size is not None \
+                        and isinstance(attn_module, ChunkedLocalAttention):
+                    kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=runner.kv_cache_dtype,
+                        attention_chunk_size=runner.attention_chunk_size,
+                        use_mla=use_mla)
+                else:
+                    kv_cache_spec[layer_name] = FullAttentionSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=runner.kv_cache_dtype,
+                        use_mla=use_mla)
+            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
+                kv_cache_spec[layer_name] = CrossAttentionSpec(
+                    block_size=block_size,
+                    num_kv_heads=attn_module.num_kv_heads,
+                    head_size=attn_module.head_size,
+                    dtype=runner.kv_cache_dtype,
+                    use_mla=use_mla)
+            elif attn_module.attn_type in (AttentionType.ENCODER,
+                                           AttentionType.ENCODER_ONLY):
+                # encoder-only attention does not need KV cache.
+                continue
+            else:
+                raise ValueError(
+                    f"Unknown attention type: {attn_module.attn_type}")
+
+        mamba_layers = get_layers_from_vllm_config(runner.vllm_config,
+                                                   MambaBase)
+        if len(mamba_layers) > 0:
+            if (runner.vllm_config.speculative_config is not None
+                    and runner.vllm_config.model_config.hf_config.model_type
+                    not in ["qwen3_next"]):
+                raise NotImplementedError(
+                    "Mamba with speculative decoding is not supported yet.")
+            if runner.vllm_config.cache_config.enable_prefix_caching:
+                raise NotImplementedError(
+                    "Prefix caching is not supported for Mamba yet.")
+            max_model_len = runner.vllm_config.model_config.max_model_len
+
+            page_size_padded = (
+                runner.vllm_config.cache_config.mamba_page_size_padded)
+
+            # Set block_size to max_model_len, so that mamba model will always
+            # have only one block in the KV cache.
+            for layer_name, mamba_module in mamba_layers.items():
+                kv_cache_spec[layer_name] = MambaSpec(
+                    shapes=mamba_module.get_state_shape(),
+                    dtypes=mamba_module.get_state_dtype(),
+                    block_size=max_model_len,
+                    page_size_padded=page_size_padded,
+                    mamba_type=mamba_module.mamba_type,
+                    num_speculative_blocks=(
+                        runner.speculative_config.num_speculative_tokens
+                        if runner.speculative_config else 0),
+                )
+
+        return kv_cache_spec
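
For orientation, the branch structure of the moved `get_kv_cache_spec` can be summarized outside vLLM. The sketch below is illustrative only: `ToyAttentionLayer` and `pick_spec` are made-up stand-ins for the real attention modules and spec classes, but the dispatch order mirrors the method above, with decoder layers mapping to sliding-window, chunked-local, or full-attention specs, encoder-decoder layers to a cross-attention spec, encoder-only layers skipped, and anything else rejected.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class ToyAttentionLayer:
    """Stand-in for a vLLM attention module; only the fields read below."""
    attn_type: str  # "decoder", "encoder_decoder", "encoder", "encoder_only"
    sliding_window: int | None = None
    is_chunked_local: bool = False


def pick_spec(layer: ToyAttentionLayer,
              attention_chunk_size: int | None) -> str | None:
    """Mirror the per-layer dispatch order of get_kv_cache_spec."""
    if layer.attn_type == "decoder":
        if layer.sliding_window is not None:
            return "SlidingWindowSpec"
        if attention_chunk_size is not None and layer.is_chunked_local:
            return "ChunkedLocalAttentionSpec"
        return "FullAttentionSpec"
    if layer.attn_type == "encoder_decoder":
        return "CrossAttentionSpec"
    if layer.attn_type in ("encoder", "encoder_only"):
        return None  # encoder-only attention needs no KV cache
    raise ValueError(f"Unknown attention type: {layer.attn_type}")


if __name__ == "__main__":
    print(pick_spec(ToyAttentionLayer("decoder", sliding_window=4096), None))
    print(pick_spec(ToyAttentionLayer("decoder", is_chunked_local=True), 8192))
    print(pick_spec(ToyAttentionLayer("encoder_only"), None))
```

Mamba layers, handled by the second half of the real method, are omitted from this sketch; their `MambaSpec` uses `block_size=max_model_len` so each Mamba layer always occupies exactly one block.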