From 315e3f654a49831ad401c90555cf8ffab2f5e489 Mon Sep 17 00:00:00 2001 From: yuantao <2422264527@qq.com> Date: Wed, 26 Nov 2025 11:34:23 +0800 Subject: [PATCH] Refactor code, make attn backend focus on diffkv and move sink logic to GPUModelRunner Signed-off-by: yuantao <2422264527@qq.com> --- vllm/attention/backends/registry.py | 4 +- vllm/attention/layer.py | 16 +- vllm/model_executor/models/openpangu.py | 81 +++-- ...lash_sink_attn.py => flash_diffkv_attn.py} | 293 ++++++++++-------- vllm/v1/core/single_type_kv_cache_manager.py | 6 +- vllm/v1/kv_cache_interface.py | 12 +- vllm/v1/worker/gpu_input_batch.py | 3 +- vllm/v1/worker/gpu_model_runner.py | 67 ++++ 8 files changed, 277 insertions(+), 205 deletions(-) rename vllm/v1/attention/backends/{flash_sink_attn.py => flash_diffkv_attn.py} (83%) diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index e69f1b7ce25e0..596622fe95b16 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -42,8 +42,8 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta): """ FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" - FLASH_SINK_ATTN = ( - "vllm.v1.attention.backends.flash_sink_attn.FlashSinkAttentionBackend" + FLASH_DIFFKV_ATTN = ( + "vllm.v1.attention.backends.flash_diffkv_attn.FlashDiffkvAttentionBackend" ) TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 376101e55e285..629f93981af09 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -285,8 +285,7 @@ class Attention(nn.Module, AttentionLayerBase): kv_sharing_target_layer_name, **extra_impl_args, ) - backend_name = self.attn_backend.get_name() - self.backend = AttentionBackendEnum.__members__.get(backend_name) + self.backend = AttentionBackendEnum[self.attn_backend.get_name()] self.dtype = dtype # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how @@ -902,20 +901,10 @@ def unified_attention_with_output( value: torch.Tensor, output: torch.Tensor, layer_name: str, - sink_key: torch.Tensor | None = None, - sink_value: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> None: attn_metadata, self, kv_cache = get_attention_context(layer_name) - kwargs = {} - if sink_key is not None or sink_value is not None: - assert sink_key is not None and sink_value is not None, ( - "Currently, it is only supported when " - "sink_key and sink_value are both not None" - ) - kwargs["sink_key"] = sink_key - kwargs["sink_value"] = sink_value self.impl.forward( self, @@ -927,7 +916,6 @@ def unified_attention_with_output( output=output, output_scale=output_scale, output_block_scale=output_block_scale, - **kwargs, ) @@ -937,8 +925,6 @@ def unified_attention_with_output_fake( value: torch.Tensor, output: torch.Tensor, layer_name: str, - sink_key: torch.Tensor | None = None, - sink_value: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> None: diff --git a/vllm/model_executor/models/openpangu.py b/vllm/model_executor/models/openpangu.py index 0486032645ad2..1fe96f71dab64 100644 --- a/vllm/model_executor/models/openpangu.py +++ b/vllm/model_executor/models/openpangu.py @@ -81,12 +81,12 @@ from vllm.model_executor.models.utils import ( ) from vllm.model_executor.utils import 
set_weight_attrs from vllm.sequence import IntermediateTensors -from vllm.v1.attention.backends.flash_sink_attn import FlashSinkAttentionBackend +from vllm.transformers_utils.config import set_default_rope_theta +from vllm.v1.attention.backends.flash_diffkv_attn import FlashDiffkvAttentionBackend from vllm.v1.kv_cache_interface import ( - FullSinkAttentionSpec, + FullDiffkvAttentionSpec, KVCacheSpec, ) -from vllm.transformers_utils.config import set_default_rope_theta def check_ffn_act_fn(act_fn: str): @@ -96,7 +96,7 @@ def check_ffn_act_fn(act_fn: str): ) -class AttentionWithSink(Attention): +class DiffkvAttention(Attention): def __init__( self, num_heads: int, @@ -138,9 +138,6 @@ class AttentionWithSink(Attention): query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - # For attention with sink, we have sink k, v - sink_key: torch.Tensor | None = None, - sink_value: torch.Tensor | None = None, output_shape: torch.Size | None = None, ) -> torch.Tensor: """ @@ -194,8 +191,6 @@ class AttentionWithSink(Attention): self_kv_cache, attn_metadata, output=output, - sink_key=sink_key, - sink_value=sink_value, ) else: torch.ops.vllm.unified_attention_with_output( @@ -204,13 +199,11 @@ class AttentionWithSink(Attention): value, output, self.layer_name, - sink_key=sink_key, - sink_value=sink_value, ) return output.view(-1, hidden_size) else: raise ValueError( - "Unsupport Error, currently only flash_sink_attn " + "Unsupport Error, currently only flash_diffkv_attn " "backend with output buffer is supported" ) @@ -221,7 +214,7 @@ class AttentionWithSink(Attention): assert self.attn_type == AttentionType.DECODER # Only support for full attention now. assert self.sliding_window is None - return FullSinkAttentionSpec( + return FullDiffkvAttentionSpec( block_size=block_size, num_kv_heads=self.num_kv_heads, head_size=self.head_size, @@ -682,15 +675,14 @@ class OpenPanguEmbeddedAttention(nn.Module): ) -class OpenPanguSinkAttention(nn.Module): +class OpenPanguDiffkvAttention(nn.Module): def __init__( self, config: PretrainedConfig, hidden_size: int, num_heads: int, num_kv_heads: int, - rope_theta: float = 10000, - rope_scaling: dict[str, Any] | None = None, + rope_parameters: dict[str, Any] | None = None, max_position_embeddings: int = 8192, quant_config: QuantizationConfig | None = None, bias: bool = False, @@ -739,7 +731,6 @@ class OpenPanguSinkAttention(nn.Module): self.k_size = self.num_kv_heads * self.head_dim self.v_size = self.num_kv_heads * self.v_channels self.scaling = self.head_dim**-0.5 - self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings self.param_sink_number = getattr(config, "param_sink_number", 0) @@ -770,7 +761,7 @@ class OpenPanguSinkAttention(nn.Module): self.k_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) self._init_rotary_emb( - config, rope_scaling=rope_scaling, quant_config=quant_config + config, rope_parameters=rope_parameters, quant_config=quant_config ) if hasattr(config, "interleaved_sliding_window"): @@ -788,10 +779,8 @@ class OpenPanguSinkAttention(nn.Module): else: sliding_window = None - FlashSinkAttentionBackend.set_cache_head_size_ratio( - (self.head_dim + self.v_channels) / self.head_dim - ) - self.attn = AttentionWithSink( + FlashDiffkvAttentionBackend.set_head_size_v(self.v_channels) + self.attn = DiffkvAttention( self.num_heads, self.head_dim, self.v_channels, @@ -802,7 +791,7 @@ class OpenPanguSinkAttention(nn.Module): per_layer_sliding_window=sliding_window, attn_type=attn_type, prefix=f"{prefix}.attn", - 
attn_backend=FlashSinkAttentionBackend, + attn_backend=FlashDiffkvAttentionBackend, ) if self.param_sink_number > 0: @@ -904,13 +893,6 @@ class OpenPanguSinkAttention(nn.Module): q = q.view(-1, self.q_size) k = k.view(-1, self.k_size) - param_sink_key = self.param_sink_key - if ( - self.param_sink_number > 0 - and hasattr(self, "k_layernorm") - and self.k_layernorm is not None - ): - param_sink_key = self.k_layernorm(param_sink_key) attn_output = self.attn( q, @@ -919,23 +901,14 @@ class OpenPanguSinkAttention(nn.Module): output_shape=torch.Size( [q.shape[0], q.shape[1] // self.head_dim * self.v_channels] ), - **( - dict( - sink_key=param_sink_key, - sink_value=self.param_sink_value, - ) - if self.param_sink_number > 0 - else {} - ), ) - attn_output = attn_output.reshape(-1, self.num_heads * self.v_channels) output, _ = self.o_proj(attn_output) return output def _init_rotary_emb( self, config: PretrainedConfig, - rope_scaling: dict[str, Any] | None, + rope_parameters: dict[str, Any] | None, quant_config: QuantizationConfig | None, ) -> None: is_neox_style = False @@ -944,11 +917,24 @@ class OpenPanguSinkAttention(nn.Module): self.head_dim, rotary_dim=self.qk_rope_dim, max_position=self.max_position_embeddings, - base=self.rope_theta, - rope_scaling=rope_scaling, + rope_parameters=rope_parameters, is_neox_style=is_neox_style, ) + def get_sink_kv(self) -> dict[str, torch.Tensor]: + if self.param_sink_number == 0: + raise ValueError("No sink_key and sink_value when param_sink_number == 0") + + if hasattr(self, "k_layernorm") and self.k_layernorm is not None: + param_sink_key = self.k_layernorm(self.param_sink_key) + else: + param_sink_key = self.param_sink_key + + return { + "sink_key": param_sink_key, + "sink_value": self.param_sink_value, + } + class OpenPanguDecoderLayer(nn.Module): def __init__( @@ -1011,15 +997,20 @@ class OpenPanguDecoderLayer(nn.Module): f"is_causal={config.is_causal} is not support " "for attention with sink" ) - self.self_attn = OpenPanguSinkAttention( + rope_parameters = getattr(config, "rope_scaling", None) + if rope_parameters is None: + rope_parameters = { + "rope_type": "default", + "rope_theta": config.rope_theta, + } + self.self_attn = OpenPanguDiffkvAttention( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=getattr( config, "num_key_value_heads", config.num_attention_heads ), - rope_theta=rope_theta, - rope_scaling=getattr(config, "rope_scaling", None), + rope_parameters=rope_parameters, max_position_embeddings=max_position_embeddings, quant_config=quant_config, bias=attention_bias, diff --git a/vllm/v1/attention/backends/flash_sink_attn.py b/vllm/v1/attention/backends/flash_diffkv_attn.py similarity index 83% rename from vllm/v1/attention/backends/flash_sink_attn.py rename to vllm/v1/attention/backends/flash_diffkv_attn.py index e532a69dcb4a9..acd9cbcb4cabf 100644 --- a/vllm/v1/attention/backends/flash_sink_attn.py +++ b/vllm/v1/attention/backends/flash_diffkv_attn.py @@ -16,6 +16,7 @@ from vllm.attention.backends.abstract import ( is_quantized_kv_cache, ) from vllm.attention.layer import Attention +from vllm.attention.ops.common import cp_lse_ag_out_rs from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.attention.ops.triton_reshape_and_cache_flash import ( triton_reshape_and_cache_flash_diffkv, @@ -32,8 +33,9 @@ if is_flash_attn_varlen_func_available(): flash_attn_varlen_func, get_scheduler_metadata, ) -from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.config 
import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config from vllm.config.cache import CacheDType +from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, @@ -54,24 +56,39 @@ from .flash_attn import FlashAttentionMetadata logger = init_logger(__name__) -class FlashSinkAttentionBackend(AttentionBackend): +class FlashDiffkvAttentionBackend(AttentionBackend): accept_output_buffer: bool = True supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] - # NOTE(tdoublep): while in principle, FA supports - # MultipleOf(16), these are the block sizes that do not - # suffer from the NaN propagation problem described here: - # https://github.com/Dao-AILab/flash-attention/issues/1974 - supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64] # TODO: Remove hard code - cache_head_size_ratio: float = 2.0 + head_size_v: int = 128 + + @staticmethod + def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: + vllm_config = get_current_vllm_config() + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + if ( + model_config + and model_config.is_hybrid + and ( + cache_config.mamba_ssm_cache_dtype == "float32" + or cache_config.mamba_cache_dtype == "float32" + ) + ): + # NOTE(tdoublep): while in principle, FA supports + # MultipleOf(16), these are the block sizes that do not + # suffer from the NaN propagation problem described here: + # https://github.com/Dao-AILab/flash-attention/issues/1974 + return [16, 32, 64] + return [MultipleOf(16)] @staticmethod def get_name() -> str: - return "FLASH_SINK_ATTN" + return "FLASH_DIFFKV_ATTN" @classmethod def supports_attn_type(cls, attn_type: str) -> bool: - """FlashSinkAttention supports all attention types.""" + """FlashDiffkvAttention supports all attention types.""" from vllm.attention import AttentionType return attn_type in ( @@ -82,16 +99,16 @@ class FlashSinkAttentionBackend(AttentionBackend): ) @staticmethod - def get_impl_cls() -> type["FlashSinkAttentionImpl"]: - return FlashSinkAttentionImpl + def get_impl_cls() -> type["FlashDiffkvAttentionImpl"]: + return FlashDiffkvAttentionImpl @staticmethod - def get_builder_cls() -> type["FlashSinkAttentionMetadataBuilder"]: - return FlashSinkAttentionMetadataBuilder + def get_builder_cls() -> type["FlashDiffkvAttentionMetadataBuilder"]: + return FlashDiffkvAttentionMetadataBuilder @classmethod - def set_cache_head_size_ratio(cls, ratio: float) -> None: - cls.cache_head_size_ratio = ratio + def set_head_size_v(cls, head_size_v: int) -> None: + cls.head_size_v = head_size_v @staticmethod def get_kv_cache_shape( @@ -107,16 +124,24 @@ class FlashSinkAttentionBackend(AttentionBackend): num_blocks, block_size, num_kv_heads, - int(head_size * FlashSinkAttentionBackend.cache_head_size_ratio), + head_size + FlashDiffkvAttentionBackend.head_size_v, ) @staticmethod - def get_kv_cache_stride_order() -> tuple[int, ...]: + def get_kv_cache_stride_order( + include_num_layers_dimension: bool = False, + ) -> tuple[int, ...]: # `stride_order` indicates the permutation that gets # us from `get_kv_cache_shape` to the actual memory layout we want. 
cache_layout = get_kv_cache_layout() - if cache_layout == "NHD": + if cache_layout == "NHD" and include_num_layers_dimension: + # (num_blocks, num_layers, block_size, num_kv_heads, head_size) + return (0, 1, 2, 3, 4) + elif cache_layout == "NHD": stride_order = (0, 1, 2, 3) + elif cache_layout == "HND" and include_num_layers_dimension: + # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size) + return (2, 3, 0, 1, 4) elif cache_layout == "HND": stride_order = (0, 2, 1, 3) else: @@ -131,8 +156,8 @@ class FlashSinkAttentionBackend(AttentionBackend): raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") @classmethod - def get_supported_head_sizes(cls) -> list[int]: - return [32, 64, 96, 128, 160, 192, 224, 256] + def supports_head_size(cls, head_size: int) -> bool: + return head_size % 8 == 0 and head_size <= 256 @classmethod def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: @@ -176,12 +201,12 @@ def _get_sliding_window_configs( sliding_window_configs: set[tuple[int, int] | None] = set() layers = get_layers_from_vllm_config(vllm_config, Attention) for layer in layers.values(): - assert isinstance(layer.impl, FlashSinkAttentionImpl) + assert isinstance(layer.impl, FlashDiffkvAttentionImpl) sliding_window_configs.add(layer.impl.sliding_window) return sliding_window_configs -class FlashSinkAttentionMetadataBuilder( +class FlashDiffkvAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata] ): # FA3: @@ -242,8 +267,8 @@ class FlashSinkAttentionMetadataBuilder( self.dcp_world_size = 1 self.dcp_rank = 0 - self.dcp_kv_cache_interleave_size = ( - self.parallel_config.dcp_kv_cache_interleave_size + self.cp_kv_cache_interleave_size = ( + self.parallel_config.cp_kv_cache_interleave_size ) self.use_full_cuda_graph = ( @@ -322,7 +347,7 @@ class FlashSinkAttentionMetadataBuilder( ): cache_dtype = self.cache_config.cache_dtype if cache_dtype.startswith("fp8"): - qkv_dtype = FlashSinkAttentionBackend.get_fp8_dtype_for_flashattn( + qkv_dtype = FlashDiffkvAttentionBackend.get_fp8_dtype_for_flashattn( cache_dtype ) else: @@ -365,7 +390,7 @@ class FlashSinkAttentionMetadataBuilder( dcp_context_kv_lens_cpu, self.dcp_world_size, self.dcp_rank, - self.dcp_kv_cache_interleave_size, + self.cp_kv_cache_interleave_size, ) dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device) max_dcp_context_kv_len = dcp_context_kv_lens.max().item() @@ -450,7 +475,7 @@ class FlashSinkAttentionMetadataBuilder( return use_cascade_attention(*args, **kwargs) -class FlashSinkAttentionImpl(AttentionImpl): +class FlashDiffkvAttentionImpl(AttentionImpl): can_return_lse_for_decode: bool = True def __init__( @@ -466,11 +491,9 @@ class FlashSinkAttentionImpl(AttentionImpl): attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: str | None = None, sinks: torch.Tensor | None = None, - head_size_v: int | None = None, ) -> None: self.num_heads = num_heads self.head_size = head_size - self.head_size_v = head_size_v self.scale = float(scale) self.num_kv_heads = num_kv_heads if alibi_slopes is not None: @@ -512,7 +535,7 @@ class FlashSinkAttentionImpl(AttentionImpl): ) def supports_quant_query_input(self) -> bool: - return False + return True def forward( self, @@ -525,10 +548,8 @@ class FlashSinkAttentionImpl(AttentionImpl): output: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, - sink_key: torch.Tensor | None = None, - sink_value: torch.Tensor | None = None, ) -> torch.Tensor: - 
"""Forward pass with FlashSinkAttention. + """Forward pass with FlashAttention. Args: query: shape = [num_tokens, num_heads, head_size] @@ -537,8 +558,6 @@ class FlashSinkAttentionImpl(AttentionImpl): kv_cache: shape = [num_blocks, block_size, num_kv_heads, head_size + head_size_v] attn_metadata: Metadata for attention. - sink_key: shape = [sink_len, num_kv_heads, head_size] - sink_value: shape = [sink_len, num_kv_heads, head_size_v] Returns: shape = [num_tokens, num_heads * head_size_v] NOTE: FP8 quantization, flash-attn expect the size of @@ -546,14 +565,11 @@ class FlashSinkAttentionImpl(AttentionImpl): We use torch's .expand() to avoid duplicating values """ assert output is not None, "Output tensor must be provided." - assert sink_key is not None and sink_value is not None, ( - "sink_key and sink_value must be provided" - ) if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not supported yet " - "for FlashSinkAttentionImpl" + "fused output quantization is not yet supported for" + "FlashDiffkvAttentionImpl" ) if attn_metadata is None: @@ -572,7 +588,6 @@ class FlashSinkAttentionImpl(AttentionImpl): # performance to make sure it does not introduce any overhead. num_actual_tokens = attn_metadata.num_actual_tokens - sink_len = sink_key.shape[0] # Handle encoder attention differently - no KV cache needed if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): @@ -585,11 +600,11 @@ class FlashSinkAttentionImpl(AttentionImpl): output[:num_actual_tokens], attn_metadata, layer, - sink_key, - sink_value, ) # For decoder and cross-attention, use KV cache as before + key_cache = kv_cache[..., : self.head_size] + value_cache = kv_cache[..., self.head_size :] # key and value may be None in the case of cross attention. They are # calculated once based on the output from the encoder and then cached @@ -606,29 +621,6 @@ class FlashSinkAttentionImpl(AttentionImpl): # and value[:num_actual_tokens] because the reshape_and_cache_flash # op uses the slot_mapping's shape to determine the number of # actual tokens. 
- - # store sink_key and sink_value in head blocks - key_cache = kv_cache[..., : self.head_size] - value_cache = kv_cache[..., self.head_size :] - block_size = key_cache.shape[1] - assert sink_len % block_size == 0 - num_sink_blocks = sink_len // block_size - sink_kv_slot_mapping = torch.arange( - block_size, - sink_len + block_size, - device=attn_metadata.slot_mapping.device, - dtype=attn_metadata.slot_mapping.dtype, - ) - triton_reshape_and_cache_flash_diffkv( - sink_key, - sink_value, - kv_cache, - sink_kv_slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - triton_reshape_and_cache_flash_diffkv( key, value, @@ -641,7 +633,7 @@ class FlashSinkAttentionImpl(AttentionImpl): if self.kv_cache_dtype.startswith("fp8"): # queries are quantized in the attention layer - dtype = FlashSinkAttentionBackend.get_fp8_dtype_for_flashattn( + dtype = FlashDiffkvAttentionBackend.get_fp8_dtype_for_flashattn( self.kv_cache_dtype ) key_cache = key_cache.view(dtype) @@ -649,29 +641,28 @@ class FlashSinkAttentionImpl(AttentionImpl): if not attn_metadata.use_cascade: cu_seqlens_q = attn_metadata.query_start_loc - seqused_k = attn_metadata.seq_lens + sink_len + seqused_k = attn_metadata.seq_lens max_seqlen_q = attn_metadata.max_query_len - max_seqlen_k = attn_metadata.max_seq_len + sink_len + max_seqlen_k = attn_metadata.max_seq_len block_table = attn_metadata.block_table scheduler_metadata = attn_metadata.scheduler_metadata - sink_block_table = torch.arange( - 1, - num_sink_blocks + 1, - device=block_table.device, - dtype=block_table.dtype, - ) - sink_block_table = sink_block_table[None, :].expand( - block_table.shape[0], -1 - ) - block_table = torch.cat((sink_block_table, block_table), dim=1) descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads) if self.dcp_world_size > 1: - raise ValueError( - "Decode context parallel is not supported yet " - f"for dcp_world_size = {self.dcp_world_size}" + self._forward_with_dcp( + query[:num_actual_tokens], + key[:num_actual_tokens], + value[:num_actual_tokens], + key_cache, + value_cache, + output[:num_actual_tokens], + attn_metadata, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), ) + return output else: flash_attn_varlen_func( q=query[:num_actual_tokens], @@ -724,10 +715,89 @@ class FlashSinkAttentionImpl(AttentionImpl): k_descale=layer._k_scale, v_descale=layer._v_scale, s_aux=self.sinks, - sink_len=sink_len, ) return output + def _forward_with_dcp( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + output: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + q_descale: torch.Tensor | None = None, + k_descale: torch.Tensor | None = None, + v_descale: torch.Tensor | None = None, + ) -> torch.Tensor: + cu_seqlens_q = attn_metadata.query_start_loc + max_seqlen_q = attn_metadata.max_query_len + block_table = attn_metadata.block_table + + query = query.contiguous() + query_across_dcp = get_dcp_group().all_gather(query, dim=1) + context_attn_out, context_lse = flash_attn_varlen_func( + q=query_across_dcp, + k=key_cache, + v=value_cache, + out=None, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=attn_metadata.dcp_context_kv_lens, + max_seqlen_k=attn_metadata.max_dcp_context_kv_len, + softmax_scale=self.scale, + causal=False, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + 
softcap=self.logits_soft_cap, + return_softmax_lse=True, + scheduler_metadata=attn_metadata.scheduler_metadata, + fa_version=self.vllm_flash_attn_version, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + ) + # FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ] + context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs( + context_attn_out, + context_lse.transpose(0, 1), + get_dcp_group(), + return_lse=True, + ) + context_lse_cor = context_lse_cor.transpose(0, 1).contiguous() + + query_attn_out, query_lse = flash_attn_varlen_func( + q=query, + k=key, + v=value, + out=None, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + cu_seqlens_k=cu_seqlens_q, + max_seqlen_k=max_seqlen_q, + softmax_scale=self.scale, + causal=attn_metadata.causal, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + softcap=self.logits_soft_cap, + return_softmax_lse=True, + fa_version=self.vllm_flash_attn_version, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + ) + assert context_attn_out_cor.shape == query_attn_out.shape + assert context_lse_cor.shape == query_lse.shape + merge_attn_states( + output, + context_attn_out_cor, + context_lse_cor, + query_attn_out, + query_lse, + ) + def _forward_encoder_attention( self, query: torch.Tensor, @@ -736,20 +806,16 @@ class FlashSinkAttentionImpl(AttentionImpl): output: torch.Tensor, attn_metadata: FlashAttentionMetadata, layer: torch.nn.Module, - sink_key: torch.Tensor, - sink_value: torch.Tensor, ) -> torch.Tensor: """Forward pass for encoder attention without KV cache. Args: query: shape = [num_encoder_tokens, num_heads, head_size] key: shape = [num_encoder_tokens, num_kv_heads, head_size] - value: shape = [num_encoder_tokens, num_kv_heads, head_size_v] + value: shape = [num_encoder_tokens, num_kv_heads, head_size] output: shape = [num_encoder_tokens, num_heads, head_size] attn_metadata: Encoder attention metadata layer: The attention layer - sink_key: shape = [sink_len, num_kv_heads, head_size] - sink_value: shape = [sink_len, num_kv_heads, head_size_v] """ # For encoder attention, process FP8 quantization if needed if self.kv_cache_dtype.startswith("fp8"): @@ -758,40 +824,10 @@ class FlashSinkAttentionImpl(AttentionImpl): ) # Use encoder-specific metadata for sequence information - sink_len = sink_key.shape[0] - key_list = [] - value_list = [] - for seq_id in range(attn_metadata.block_table.shape[0]): - seq_start = attn_metadata.query_start_loc[seq_id] - seq_end = attn_metadata.query_start_loc[seq_id + 1] - key_list.append( - torch.cat( - [ - sink_key, - key[seq_start:seq_end], - ], - dim=0, - ) - ) - value_list.append( - torch.cat( - [ - sink_value, - value[seq_start:seq_end], - ], - dim=0, - ) - ) - key = torch.cat(key_list, dim=0).contiguous() - value = torch.cat(value_list, dim=0).contiguous() - cu_seqlens_q = attn_metadata.query_start_loc - cu_seqlens_k = attn_metadata.seq_lens + sink_len - cu_seqlens_k = torch.nn.functional.pad( - torch.cumsum(cu_seqlens_k, dim=-1), [1, 0], value=0 - ).to(torch.int32) + cu_seqlens_k = attn_metadata.query_start_loc max_seqlen_q = attn_metadata.max_query_len - max_seqlen_k = attn_metadata.max_seq_len + sink_len + max_seqlen_k = attn_metadata.max_query_len descale_shape = ( cu_seqlens_q.shape[0] - 1, # type: ignore[union-attr] @@ -926,14 +962,12 @@ def cascade_attention( k_descale: torch.Tensor | None = None, v_descale: torch.Tensor | None = None, s_aux: torch.Tensor | None = None, - sink_len: int | None = None, ) -> torch.Tensor: assert 
alibi_slopes is None, "Cascade attention does not support ALiBi." # TODO: Support sliding window. assert sliding_window == (-1, -1), ( "Cascade attention does not support sliding window." ) - assert sink_len is not None, "sink_len must be provided." num_tokens = query.shape[0] block_size = key_cache.shape[-3] @@ -942,22 +976,15 @@ def cascade_attention( assert num_common_kv_blocks > 0 descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2]) - num_sink_blocks = sink_len // block_size - sink_block_table = torch.arange( - 1, num_sink_blocks + 1, device=block_table.device, dtype=block_table.dtype - ) - sink_block_table = sink_block_table[None, :].expand(block_table.shape[0], -1) - block_table = torch.cat((sink_block_table, block_table), dim=1) - # Process shared prefix. prefix_output, prefix_lse = flash_attn_varlen_func( q=query, k=key_cache, v=value_cache, cu_seqlens_q=cu_prefix_query_lens, - seqused_k=prefix_kv_lens + sink_len, + seqused_k=prefix_kv_lens, max_seqlen_q=num_tokens, - max_seqlen_k=common_prefix_len + sink_len, + max_seqlen_k=common_prefix_len, softmax_scale=softmax_scale, causal=False, window_size=sliding_window, @@ -989,7 +1016,7 @@ def cascade_attention( softmax_scale=softmax_scale, causal=True, window_size=sliding_window, - block_table=block_table[:, num_sink_blocks + num_common_kv_blocks :], + block_table=block_table[:, num_common_kv_blocks:], softcap=logits_soft_cap, return_softmax_lse=True, scheduler_metadata=suffix_scheduler_metadata, diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index ee5ae21d02843..6267ac0e71f7f 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -12,7 +12,7 @@ from vllm.v1.kv_cache_interface import ( ChunkedLocalAttentionSpec, CrossAttentionSpec, FullAttentionSpec, - FullSinkAttentionSpec, + FullDiffkvAttentionSpec, KVCacheSpec, MambaSpec, MLAAttentionSpec, @@ -311,7 +311,7 @@ class FullAttentionManager(SingleTypeKVCacheManager): ) -> tuple[list[KVCacheBlock], ...]: assert isinstance( kv_cache_spec, - (FullAttentionSpec, FullSinkAttentionSpec, ChunkedLocalAttentionSpec), + (FullAttentionSpec, FullDiffkvAttentionSpec, ChunkedLocalAttentionSpec), ), ( "FullAttentionManager can only be used for full attention " "and chunked local attention groups" @@ -733,7 +733,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager): spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, - FullSinkAttentionSpec: FullAttentionManager, + FullDiffkvAttentionSpec: FullAttentionManager, MLAAttentionSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager, ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index aa3ca82a5d4a3..1b130300b2218 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -159,7 +159,7 @@ class FullAttentionSpec(AttentionSpec): @dataclass(frozen=True) -class FullSinkAttentionSpec(AttentionSpec): +class FullDiffkvAttentionSpec(AttentionSpec): head_size_v: int sliding_window: int | None = None attention_chunk_size: int | None = None @@ -170,7 +170,7 @@ class FullSinkAttentionSpec(AttentionSpec): window attention are regarded as full attention in KV cache manager (blocks are allocated for all tokens), while computed as sliding window attention in model runner. - In this case, we use FullSinkAttentionSpec and record the sliding window size. 
+ In this case, we use FullDiffkvAttentionSpec and record the sliding window size. Default to None for not using sliding window attention. """ @@ -198,12 +198,12 @@ class FullSinkAttentionSpec(AttentionSpec): @classmethod def merge(cls, specs: list[Self]) -> Self: """ - Merge a list of FullSinkAttentionSpec objects into a single - FullSinkAttentionSpec object. + Merge a list of FullDiffkvAttentionSpec objects into a single + FullDiffkvAttentionSpec object. """ - assert all(isinstance(spec, FullSinkAttentionSpec) for spec in specs), ( + assert all(isinstance(spec, FullDiffkvAttentionSpec) for spec in specs), ( "All attention layers in the same KV cache group must be " - "FullSinkAttentionSpec." + "FullDiffkvAttentionSpec." ) sliding_window = set( diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index e7991baeaa1b8..5ec918654677c 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -90,6 +90,7 @@ class InputBatch: is_pooling_model: bool = False, num_speculative_tokens: int = 0, cp_kv_cache_interleave_size: int = 1, + sink_len: int = 0, ): self.is_pooling_model = is_pooling_model self.is_spec_decode = is_spec_decode @@ -136,7 +137,7 @@ class InputBatch: # Block table. self.block_table = MultiGroupBlockTable( max_num_reqs=max_num_reqs, - max_model_len=max_model_len, + max_model_len=max_model_len + sink_len, max_num_batched_tokens=max_num_batched_tokens, pin_memory=pin_memory, device=device, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0ce6c4a3204b0..c24e0561215ad 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -25,6 +25,9 @@ from vllm.attention.backends.abstract import ( AttentionMetadata, MultipleOf, ) +from vllm.attention.ops.triton_reshape_and_cache_flash import ( + triton_reshape_and_cache_flash_diffkv, +) from vllm.compilation.counter import compilation_counter from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.compilation.monitor import set_cudagraph_capturing_enabled @@ -321,6 +324,10 @@ class GPUModelRunner( self.num_query_heads = model_config.get_num_attention_heads(parallel_config) self.hidden_size = model_config.get_hidden_size() self.attention_chunk_size = model_config.attention_chunk_size + self.sink_len = getattr( + self.vllm_config.model_config.hf_config, "param_sink_number", 0 + ) + assert self.sink_len % self.cache_config.block_size == 0 # Only relevant for models using ALiBi (e.g, MPT) self.use_alibi = model_config.uses_alibi @@ -443,6 +450,7 @@ class GPUModelRunner( logitsprocs_need_output_token_ids=bool(custom_logitsprocs), is_pooling_model=self.is_pooling_model, cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size, + sink_len=self.sink_len, ) self.use_async_scheduling = self.scheduler_config.async_scheduling @@ -1590,6 +1598,28 @@ class GPUModelRunner( # graph mode. 
            blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1)
+
+        # Adjust blk_table_tensor and seq_lens in place so that the attention
+        # backend accounts for the sink_key/sink_value entries stored in kv_caches.
+        if self.sink_len > 0:
+            seq_lens[:] = seq_lens + self.sink_len
+            seq_lens_cpu[:] = seq_lens_cpu + self.sink_len
+            max_seq_len = max_seq_len + self.sink_len
+            sink_block_table = torch.arange(
+                1,
+                self.sink_len // self.cache_config.block_size + 1,
+                device=blk_table_tensor.device,
+                dtype=blk_table_tensor.dtype,
+            )
+            sink_block_table = sink_block_table[None, :].expand(
+                blk_table_tensor.shape[0], -1
+            )
+            num_sink_blocks = sink_block_table.shape[1]
+            blk_table_tensor_clone = blk_table_tensor.clone()
+            blk_table_tensor[:, num_sink_blocks:] = blk_table_tensor_clone[
+                :, :-num_sink_blocks
+            ]
+            blk_table_tensor[:, :num_sink_blocks] = sink_block_table
+
         common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=query_start_loc,
             query_start_loc_cpu=query_start_loc_cpu,
@@ -1624,6 +1654,8 @@ class GPUModelRunner(
                 if cascade_attn_prefix_lens
                 else 0
             )
+            if self.sink_len > 0:
+                cascade_attn_prefix_len = cascade_attn_prefix_len + self.sink_len
 
             builder = attn_group.get_metadata_builder()
             extra_attn_metadata_args = {}
@@ -4838,6 +4870,7 @@ class GPUModelRunner(
             logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids,
             is_pooling_model=self.is_pooling_model,
             num_speculative_tokens=self.num_spec_tokens,
+            sink_len=self.sink_len,
         )
 
     def _allocate_kv_cache_tensors(
@@ -5165,6 +5198,7 @@ class GPUModelRunner(
         kv_caches = self.initialize_kv_cache_tensors(
             kv_cache_config, kernel_block_sizes
         )
+        self.prepare_sink_kv_cache(kv_caches)
 
         if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
@@ -5269,3 +5303,36 @@ class GPUModelRunner(
         self.transfer_event.record()
         self.transfer_event.synchronize()
         return pinned.tolist()
+
+    def prepare_sink_kv_cache(self, kv_caches: dict[str, torch.Tensor]) -> None:
+        if self.sink_len == 0:
+            return
+
+        def find_module_by_name(model, target_name: str):
+            for name, module in model.named_modules():
+                if name == target_name:
+                    return module
+            raise KeyError(f"Module '{target_name}' not found")
+
+        for layer_name, kv_cache in kv_caches.items():
+            layer_prefix = layer_name.rsplit(".", 1)[0]
+            self_attn_module = find_module_by_name(self.model, layer_prefix)
+            if not hasattr(self_attn_module, "get_sink_kv"):
+                continue
+            else:
+                sink_kv = self_attn_module.get_sink_kv()
+                sink_kv_slot_mapping = torch.arange(
+                    self.vllm_config.cache_config.block_size,
+                    self.sink_len + self.vllm_config.cache_config.block_size,
+                    device=torch.cuda.current_device(),
+                    dtype=torch.long,
+                )
+                triton_reshape_and_cache_flash_diffkv(
+                    sink_kv["sink_key"],
+                    sink_kv["sink_value"],
+                    kv_cache,
+                    sink_kv_slot_mapping,
+                    self_attn_module.attn.kv_cache_dtype,
+                    self_attn_module.attn._k_scale,
+                    self_attn_module.attn._v_scale,
+                )
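
To illustrate the sink-block bookkeeping that the patch adds to _prepare_inputs and prepare_sink_kv_cache, the following minimal standalone sketch walks through the same tensor manipulation on toy data. The request count, block size, sink length, and block ids below are invented for illustration only; nothing beyond torch is assumed.

    import torch

    # Toy sizes: two requests, 16-token blocks, and a 32-token sink prefix,
    # so the sink occupies physical blocks 1..2 of every layer's KV cache.
    block_size = 16
    sink_len = 32
    num_sink_blocks = sink_len // block_size

    seq_lens = torch.tensor([40, 7])               # real context length per request
    blk_table = torch.tensor([[10, 11, 12, 0],     # per-request physical block ids
                              [20,  0,  0, 0]])

    # 1) Each sequence is treated as sink_len tokens longer, so the attention
    #    kernel also reads the sink slots written once at startup.
    seq_lens = seq_lens + sink_len                 # tensor([72, 39])

    # 2) Prepend the shared sink blocks and shift the real blocks right,
    #    mirroring the in-place update of blk_table_tensor in _prepare_inputs.
    sink_blocks = torch.arange(1, num_sink_blocks + 1)   # tensor([1, 2])
    shifted = blk_table.clone()
    shifted[:, num_sink_blocks:] = blk_table[:, :-num_sink_blocks]
    shifted[:, :num_sink_blocks] = sink_blocks
    # shifted == tensor([[ 1,  2, 10, 11],
    #                    [ 1,  2, 20,  0]])

The shift drops the last num_sink_blocks entries of each row, which is consistent with InputBatch sizing its block tables with max_model_len + sink_len. The sink KV itself is written once per layer into physical blocks 1..num_sink_blocks, since the slot mapping built in prepare_sink_kv_cache starts at block_size, leaving block 0 untouched.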