From 8de4315229f6da4eb9d29cbceb2849033ff3418a Mon Sep 17 00:00:00 2001 From: yuantao <2422264527@qq.com> Date: Sat, 15 Nov 2025 12:00:40 +0800 Subject: [PATCH 01/16] Add support for openpangu_pro_moe_v2, which characterized by its different kv head size and sink kv in attention. Signed-off-by: yuantao <2422264527@qq.com> --- docs/models/supported_models.md | 1 + tests/models/registry.py | 5 + vllm/attention/layer.py | 45 +- .../ops/triton_reshape_and_cache_flash.py | 133 +++ vllm/model_executor/models/openpangu.py | 465 +++++++- vllm/model_executor/models/registry.py | 1 + vllm/v1/attention/backends/flash_sink_attn.py | 1005 +++++++++++++++++ vllm/v1/core/sched/scheduler.py | 5 + vllm/v1/core/single_type_kv_cache_manager.py | 5 +- vllm/v1/kv_cache_interface.py | 92 ++ 10 files changed, 1744 insertions(+), 13 deletions(-) create mode 100644 vllm/v1/attention/backends/flash_sink_attn.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 9cdf644c3cc52..2ce1cd9a943bf 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -427,6 +427,7 @@ th { | `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | | `OuroForCausalLM` | ouro | `ByteDance/Ouro-1.4B`, `ByteDance/Ouro-2.6B`, etc. | ✅︎ | | | `PanguEmbeddedForCausalLM` |openPangu-Embedded-7B | `FreedomIntelligence/openPangu-Embedded-7B-V1.1` | ✅︎ | ✅︎ | +| `PanguProMoEV2ForCausalLM` |openpangu-pro-moe-v2 | | ✅︎ | ✅︎ | | `PanguUltraMoEForCausalLM` |openpangu-ultra-moe-718b-model | `FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1` | ✅︎ | ✅︎ | | `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ | | `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. 
| ✅︎ | ✅︎ | diff --git a/tests/models/registry.py b/tests/models/registry.py index 644d0619215fb..b8b9dc9c43799 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -383,6 +383,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "PanguEmbeddedForCausalLM": _HfExamplesInfo( "FreedomIntelligence/openPangu-Embedded-7B-V1.1", trust_remote_code=True ), + "PanguProMoEV2ForCausalLM": _HfExamplesInfo( + "", + trust_remote_code=True, + is_available_online=False, + ), "PanguUltraMoEForCausalLM": _HfExamplesInfo( "FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1", trust_remote_code=True, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 37f9a4b383ce9..deabf7a0b0770 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -939,21 +939,42 @@ def unified_attention_with_output( value: torch.Tensor, output: torch.Tensor, layer_name: str, + sink_key: torch.Tensor | None = None, + sink_value: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> None: attn_metadata, self, kv_cache = get_attention_context(layer_name) - self.impl.forward( - self, - query, - key, - value, - kv_cache, - attn_metadata, - output=output, - output_scale=output_scale, - output_block_scale=output_block_scale, - ) + if sink_key is None and sink_value is None: + self.impl.forward( + self, + query, + key, + value, + kv_cache, + attn_metadata, + output=output, + output_scale=output_scale, + output_block_scale=output_block_scale, + ) + else: + assert sink_key is not None and sink_value is not None, ( + "Currently, it is only supported when " + "sink_key and sink_value are both not None" + ) + self.impl.forward( + self, + query, + key, + value, + kv_cache, + attn_metadata, + output=output, + output_scale=output_scale, + output_block_scale=output_block_scale, + sink_key=sink_key, + sink_value=sink_value, + ) def unified_attention_with_output_fake( @@ -962,6 +983,8 @@ def unified_attention_with_output_fake( 
value: torch.Tensor, output: torch.Tensor, layer_name: str, + sink_key: torch.Tensor | None = None, + sink_value: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> None: diff --git a/vllm/attention/ops/triton_reshape_and_cache_flash.py b/vllm/attention/ops/triton_reshape_and_cache_flash.py index 5d2ba154ae018..d79e209e303b0 100644 --- a/vllm/attention/ops/triton_reshape_and_cache_flash.py +++ b/vllm/attention/ops/triton_reshape_and_cache_flash.py @@ -182,3 +182,136 @@ def triton_reshape_and_cache_flash( num_warps=num_warps, num_stages=num_stages, ) + + +@triton.jit +def reshape_and_cache_kernel_flash_diffkv( + kv_ptr, # [num_tokens, num_heads, head_size + head_size_v] + kv_cache_ptr, # [num_blocks, block_size, num_heads, head_size + head_size_v] + slot_mapping_ptr, # [num_tokens] + k_scale, # float32 + v_scale, # float32 + # strides + kv_stride: tl.int64, + block_stride: tl.int64, + page_stride: tl.int64, + num_heads: tl.constexpr, + head_size_kv: tl.constexpr, + block_size: tl.constexpr, + # FP8 flags + FP8_KV_CACHE: tl.constexpr, + # tune parameters + TILE_SIZE: tl.constexpr, +): + token_idx = tl.program_id(axis=0) + slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64) + if slot_idx < 0: + # Padding token that should be ignored. 
+ return + + tile_i = tl.program_id(axis=1) + tile_offs = tl.arange(0, TILE_SIZE) + tile_pos = tile_i * TILE_SIZE + tile_offs + + block_idx = slot_idx // block_size + block_offset = slot_idx % block_size + + src_kv_idx = token_idx * kv_stride + + tgt_idx = block_idx * block_stride + block_offset * page_stride + + # [TILE_SIZE] + kv_tile = tl.load( + kv_ptr + src_kv_idx + tile_pos, mask=tile_pos < (num_heads * head_size_kv) + ) + + tl.store( + kv_cache_ptr + tgt_idx + tile_pos, + kv_tile, + mask=tile_pos < (num_heads * head_size_kv), + ) + return + + +def triton_reshape_and_cache_flash_diffkv( + key: torch.Tensor, # [num_tokens, num_heads, head_size] + value: torch.Tensor, # [num_tokens, num_heads, head_size_v] + # [num_blocks, block_size, num_heads, head_size + head_size_v] + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, # [num_tokens] + kv_cache_dtype: str, # "auto", "fp8" + k_scale: torch.Tensor, # float32 + v_scale: torch.Tensor, # float32 +): + kv = torch.cat([key, value], dim=-1).contiguous() + num_heads = kv.shape[1] + head_size_kv = kv.shape[2] + block_size = kv_cache.shape[1] + n = num_heads * head_size_kv + + kv_stride = kv.stride()[0] + block_stride = kv_cache.stride()[0] + page_stride = kv_cache.stride()[1] + + assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), ( + f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}." + ) + kv_cache_torch_dtype = ( + current_platform.fp8_dtype() + if kv_cache_dtype.startswith("fp8") + else kv_cache.dtype + ) + + if kv_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith("fp8"): + # to avoid erounous implicit cast in triton kernel (tl.store to uint8) + # (e.g. 
explicit cast to fp8e4m3fnuz is not supported in triton 3.4) + kv_cache = kv_cache.view(kv_cache_torch_dtype) + assert kv_cache_dtype != torch.uint8, ( + "explicit fp8 cast and store to " + "uint8 is not supported by triton reshape_and_cache_flash_diffkv" + ) + + FP8_KV_CACHE = kv_cache_dtype.startswith("fp8") + assert not FP8_KV_CACHE, ( + "unsupported dtype of KV cache tensor, got " + "{kv_cache_torch_dtype}. Supported kv cache dtypes: bfloat16, float16, float32." + ) + + # heuristics instead of autotuning + TILE_SIZE = min(2048, triton.next_power_of_2(n)) + if current_platform.is_rocm() or current_platform.is_xpu(): + num_stages = 4 + num_warps = 8 + else: # cuda + num_stages = 10 + num_warps = 16 + if torch.cuda.get_device_capability(key.device)[0] < 9: + TILE_SIZE = min(512, TILE_SIZE) + + # TODO(ngl): maybe replace with static launch grid to avoid overhead if + # using cudagraphs + grid = lambda meta: ( + slot_mapping.shape[0], + triton.cdiv(n, meta["TILE_SIZE"]), + ) + + reshape_and_cache_kernel_flash_diffkv[grid]( + kv_ptr=kv, + kv_cache_ptr=kv_cache, + slot_mapping_ptr=slot_mapping, + k_scale=k_scale, + v_scale=v_scale, + # strides + kv_stride=kv_stride, + block_stride=block_stride, + page_stride=page_stride, + num_heads=num_heads, + head_size_kv=head_size_kv, + block_size=block_size, + # FP8 flags + FP8_KV_CACHE=FP8_KV_CACHE, + # autotune parameters + TILE_SIZE=TILE_SIZE, + num_warps=num_warps, + num_stages=num_stages, + ) diff --git a/vllm/model_executor/models/openpangu.py b/vllm/model_executor/models/openpangu.py index d13a745beffeb..f7249475b49de 100644 --- a/vllm/model_executor/models/openpangu.py +++ b/vllm/model_executor/models/openpangu.py @@ -30,15 +30,18 @@ from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionType +from vllm.attention.backends.abstract import AttentionBackend from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ParallelConfig, 
VllmConfig from vllm.distributed import ( get_ep_group, get_pp_group, + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_gather, ) +from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -76,7 +79,13 @@ from vllm.model_executor.models.utils import ( maybe_prefix, sequence_parallel_chunk, ) +from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors +from vllm.v1.attention.backends.flash_sink_attn import FlashSinkAttentionBackend +from vllm.v1.kv_cache_interface import ( + FullSinkAttentionSpec, + KVCacheSpec, +) def check_ffn_act_fn(act_fn: str): @@ -86,6 +95,140 @@ def check_ffn_act_fn(act_fn: str): ) +class AttentionWithSink(Attention): + def __init__( + self, + num_heads: int, + head_size: int, + head_size_v: int, + scale: float, + num_kv_heads: int | None = None, + alibi_slopes: list[float] | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + logits_soft_cap: float | None = None, + per_layer_sliding_window: int | None = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: str | None = None, + attn_backend: type[AttentionBackend] | None = None, + **extra_impl_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + cache_config, + quant_config, + logits_soft_cap, + per_layer_sliding_window, + prefix, + attn_type, + kv_sharing_target_layer_name, + attn_backend, + **extra_impl_args, + ) + self.head_size_v = head_size_v + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + # For attention with sink, we have sink k, v + sink_key: torch.Tensor | None = None, + sink_value: 
torch.Tensor | None = None, + output_shape: torch.Size | None = None, + ) -> torch.Tensor: + """ + The KV cache is stored inside this class and is accessed via + `self.kv_cache`. + + Attention metadata (`attn_metadata`) is set using a context manager in + the model runner's `execute_model` method. It is accessed via forward + context using + `vllm.forward_context.get_forward_context().attn_metadata`. + """ + if self.calculate_kv_scales: + torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name) + output_dtype = query.dtype + if self.query_quant is not None: + # quantizing with a simple torch operation enables + # torch.compile to fuse this into previous ops + # which reduces overheads during decoding. + # Otherwise queries are quantized using custom ops + # which causes decoding overheads + assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"} + + # check if query quantization is supported + if self.impl.supports_quant_query_input(): + query, _ = self.query_quant(query, self._q_scale) + + if self.use_output: + output_shape = output_shape if output_shape is not None else query.shape + output = torch.empty(output_shape, dtype=output_dtype, device=query.device) + hidden_size = output_shape[-1] + # Reshape the query, key, and value tensors. + # NOTE(woosuk): We do this outside the custom op to minimize the + # CPU overheads from the non-CUDA-graph regions. 
+ query = query.view(-1, self.num_heads, self.head_size) + output = output.view(-1, self.num_heads, self.head_size_v) + if key is not None: + key = key.view(-1, self.num_kv_heads, self.head_size) + if value is not None: + value = value.view(-1, self.num_kv_heads, self.head_size_v) + if self.use_direct_call: + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + self.impl.forward( + self, + query, + key, + value, + self_kv_cache, + attn_metadata, + output=output, + sink_key=sink_key, + sink_value=sink_value, + ) + else: + torch.ops.vllm.unified_attention_with_output( + query, + key, + value, + output, + self.layer_name, + sink_key=sink_key, + sink_value=sink_value, + ) + return output.view(-1, hidden_size) + else: + raise ValueError( + "Unsupport Error, currently only flash_sink_attn " + "backend with output buffer is supported" + ) + + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: + # Block size may get updated after model loading, refresh it + block_size = vllm_config.cache_config.block_size + # Should not be called for enc-dec or encoder-only attention. + assert self.attn_type == AttentionType.DECODER + # Only support for full attention now. 
+ assert self.sliding_window is None + return FullSinkAttentionSpec( + block_size=block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + head_size_v=self.head_size_v, + dtype=self.kv_cache_torch_dtype, + ) + + class OpenPanguMLP(nn.Module): def __init__( self, @@ -153,7 +296,15 @@ class OpenPanguMoE(nn.Module): quant_config=None, prefix=f"{prefix}.gate", ) - self.gate.e_score_correction_bias = None + if ( + hasattr(config, "router_enable_expert_bias") + and config.router_enable_expert_bias + ): + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(self.n_routed_experts, dtype=torch.float32) + ) + else: + self.gate.e_score_correction_bias = None # Load balancing settings. eplb_config = parallel_config.eplb_config @@ -539,6 +690,276 @@ class OpenPanguEmbeddedAttention(nn.Module): ) +class OpenPanguSinkAttention(nn.Module): + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: dict[str, Any] | None = None, + max_position_embeddings: int = 8192, + quant_config: QuantizationConfig | None = None, + bias: bool = False, + bias_o_proj: bool = False, + cache_config: CacheConfig | None = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) + self.hidden_size = hidden_size + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.total_num_heads = num_heads + if self.total_num_heads % self.tp_size != 0: + raise ValueError( + f"total_num_heads {self.total_num_heads} " + f"is not divisible by tp_size {self.tp_size}." 
+ ) + self.num_heads = self.total_num_heads // self.tp_size + self.total_num_kv_heads = num_kv_heads + if ( + self.total_num_kv_heads > self.tp_size + and self.total_num_kv_heads % self.tp_size != 0 + ): + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel ranks. + raise ValueError( + "Number of KV heads is greater than TP size, " + f"but total_num_kv_heads {self.total_num_kv_heads} " + f"is not divisible by tp_size {self.tp_size}." + ) + elif self.total_num_kv_heads < self.tp_size: + # TODO: Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel ranks. + raise ValueError( + f"Number of KV heads {self.total_num_kv_heads} is less than " + f"TP size {self.tp_size}, KV heads replication is not support yet." + ) + self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) + self.qk_nope_dim = getattr(config, "qk_nope_dim", None) + self.qk_rope_dim = getattr(config, "qk_rope_dim", None) + self.v_channels = getattr(config, "v_channels", None) + self.head_dim = self.qk_rope_dim + self.qk_nope_dim + self.q_size = self.num_heads * self.head_dim + self.k_size = self.num_kv_heads * self.head_dim + self.v_size = self.num_kv_heads * self.v_channels + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.param_sink_number = getattr(config, "param_sink_number", 0) + self.param_sink_with_value = getattr(config, "param_sink_with_value", False) + self.param_sink_scalar = getattr(config, "param_sink_scalar", None) + self.param_sink_of_head_num = getattr(config, "param_sink_of_head_dim", False) + + self.qkv_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[ + self.q_size * self.tp_size, + self.k_size * self.tp_size, + self.v_size * self.tp_size, + ], + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + 
input_size=self.total_num_heads * self.v_channels, + output_size=hidden_size, + bias=bias_o_proj, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.k_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + self._init_rotary_emb( + config, rope_scaling=rope_scaling, quant_config=quant_config + ) + + if hasattr(config, "interleaved_sliding_window"): + interleaved_sliding_window = config.interleaved_sliding_window + if isinstance(interleaved_sliding_window, int): + sliding_window = interleaved_sliding_window + elif isinstance(interleaved_sliding_window, list): + sw_idx = layer_idx % len(interleaved_sliding_window) + sliding_window = interleaved_sliding_window[sw_idx] + else: + raise ValueError( + f"{type(interleaved_sliding_window)} " + "for interleaved_sliding_window is not supported." + ) + else: + sliding_window = None + + FlashSinkAttentionBackend.set_cache_head_size_ratio( + (self.head_dim + self.v_channels) / self.head_dim + ) + self.attn = AttentionWithSink( + self.num_heads, + self.head_dim, + self.v_channels, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=sliding_window, + attn_type=attn_type, + prefix=f"{prefix}.attn", + attn_backend=FlashSinkAttentionBackend, + ) + + if self.param_sink_number > 0: + self.param_sink_key = torch.nn.Parameter( + torch.empty( + ( + self.param_sink_number, + self.num_kv_heads, + self.head_dim, + ), + device=torch.cuda.current_device(), + dtype=config.torch_dtype, + ) + ) + set_weight_attrs( + self.param_sink_key, + { + "output_dim": 1, + "weight_loader": self.weight_loader, + }, + ) + + if self.param_sink_with_value: + self.param_sink_value = torch.nn.Parameter( + torch.empty( + ( + self.param_sink_number, + self.num_kv_heads, + self.v_channels, + ), + device=torch.cuda.current_device(), + dtype=config.torch_dtype, + ) + ) + set_weight_attrs( + self.param_sink_value, + { + "output_dim": 1, + "weight_loader": 
self.weight_loader, + }, + ) + else: + self.param_sink_value = torch.zeros( + torch.empty( + ( + self.param_sink_number, + self.num_kv_heads, + self.v_channels, + ), + device=torch.cuda.current_device(), + dtype=config.torch_dtype, + ) + ) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): + output_dim = getattr(param, "output_dim", None) + + is_sharded_weight = getattr(param, "is_sharded_weight", False) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit + + # Special case for GGUF + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.weight_type = loaded_weight.item() + + # Materialize GGUF UninitializedParameter + if is_gguf_weight and isinstance(param, nn.UninitializedParameter): + final_shape = list(loaded_weight.shape) + if output_dim is not None: + assert final_shape[output_dim] % self.tp_size == 0 + final_shape[output_dim] = final_shape[output_dim] // self.tp_size + param.materialize(final_shape, dtype=loaded_weight.dtype) + + param_data = param.data + if output_dim is not None and not is_sharded_weight: + shard_size = param_data.shape[output_dim] + start_idx = self.tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). 
+ if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + k = self.k_layernorm(k.view(-1, self.num_kv_heads, self.head_dim)) + q, k = self.rotary_emb(positions, q, k) + + q = q.view(-1, self.q_size) + k = k.view(-1, self.k_size) + param_sink_key = self.param_sink_key + if ( + self.param_sink_number > 0 + and hasattr(self, "k_layernorm") + and self.k_layernorm is not None + ): + param_sink_key = self.k_layernorm(param_sink_key) + + attn_output = self.attn( + q, + k, + v, + output_shape=torch.Size( + [q.shape[0], q.shape[1] // self.head_dim * self.v_channels] + ), + **( + dict( + sink_key=param_sink_key, + sink_value=self.param_sink_value, + ) + if self.param_sink_number > 0 + else {} + ), + ) + attn_output = attn_output.reshape(-1, self.num_heads * self.v_channels) + output, _ = self.o_proj(attn_output) + return output + + def _init_rotary_emb( + self, + config: PretrainedConfig, + rope_scaling: dict[str, Any] | None, + quant_config: QuantizationConfig | None, + ) -> None: + is_neox_style = False + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.qk_rope_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + ) + + class OpenPanguDecoderLayer(nn.Module): def __init__( self, @@ -567,6 +988,9 @@ class OpenPanguDecoderLayer(nn.Module): and hasattr(config, "v_head_dim") and hasattr(config, "kv_lora_rank") ) + self.use_sink_attention = ( + hasattr(config, "param_sink_number") and config.param_sink_number > 0 + ) if self.use_mla: self.self_attn = OpenPanguMLAAttention( config=config, @@ -585,6 +1009,37 @@ class OpenPanguDecoderLayer(nn.Module): 
quant_config=quant_config, prefix=f"{prefix}.self_attn", ) + elif self.use_sink_attention: + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False + ) + bias_o_proj = attention_bias + if hasattr(config, "qkv_bias"): + attention_bias = config.qkv_bias + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + raise ValueError( + f"is_causal={config.is_causal} is not support " + "for attention with sink" + ) + self.self_attn = OpenPanguSinkAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), + rope_theta=rope_theta, + rope_scaling=getattr(config, "rope_scaling", None), + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + bias_o_proj=bias_o_proj, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + ) else: attention_bias = getattr(config, "attention_bias", False) or getattr( config, "bias", False @@ -916,6 +1371,10 @@ class OpenPanguModel(nn.Module): if name.endswith(".bias") and name not in params_dict: continue name = maybe_remap_kv_scale_name(name, params_dict) + if name.endswith("e_score_correction_bias"): + name = name.replace( + "e_score_correction_bias", "gate.e_score_correction_bias" + ) if name is None: continue if is_pp_missing_parameter(name, self): @@ -1060,3 +1519,7 @@ class PanguEmbeddedForCausalLM(OpenPanguEmbeddedModel): class PanguUltraMoEForCausalLM(OpenPanguMoEModel): pass + + +class PanguProMoEV2ForCausalLM(OpenPanguMoEModel): + pass diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 4af8fa01f562b..c50be12883897 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -150,6 +150,7 @@ _TEXT_GENERATION_MODELS = { "OrionForCausalLM": ("orion", "OrionForCausalLM"), 
"OuroForCausalLM": ("ouro", "OuroForCausalLM"), "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"), + "PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"), "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"), "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"), diff --git a/vllm/v1/attention/backends/flash_sink_attn.py b/vllm/v1/attention/backends/flash_sink_attn.py new file mode 100644 index 0000000000000..580f0f544102c --- /dev/null +++ b/vllm/v1/attention/backends/flash_sink_attn.py @@ -0,0 +1,1005 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention layer with FlashAttention.""" + +from typing import ClassVar + +import numpy as np +import torch + +from vllm import envs +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionType, + MultipleOf, + is_quantized_kv_cache, +) +from vllm.attention.layer import Attention +from vllm.attention.ops.merge_attn_states import merge_attn_states +from vllm.attention.ops.triton_reshape_and_cache_flash import ( + triton_reshape_and_cache_flash_diffkv, +) +from vllm.attention.utils.fa_utils import ( + flash_attn_supports_fp8, + get_flash_attn_version, + is_flash_attn_varlen_func_available, +) + +if is_flash_attn_varlen_func_available(): + from vllm.attention.utils.fa_utils import ( + flash_attn_supports_sinks, + flash_attn_varlen_func, + get_scheduler_metadata, + ) +from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.config.cache import CacheDType +from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) +from vllm.platforms.interface import DeviceCapability +from vllm.utils.math_utils import cdiv +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + 
get_dcp_local_seq_lens, + get_kv_cache_layout, +) +from vllm.v1.kv_cache_interface import AttentionSpec + +from .flash_attn import FlashAttentionMetadata + +logger = init_logger(__name__) + + +class FlashSinkAttentionBackend(AttentionBackend): + accept_output_buffer: bool = True + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + # NOTE(tdoublep): while in principle, FA supports + # MultipleOf(16), these are the block sizes that do not + # suffer from the NaN propagation problem described here: + # https://github.com/Dao-AILab/flash-attention/issues/1974 + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64] + # TODO: Remove hard code + cache_head_size_ratio: float = 2.0 + + @staticmethod + def get_name() -> str: + return "FLASH_SINK_ATTN" + + @classmethod + def supports_attn_type(cls, attn_type: str) -> bool: + """FlashSinkAttention supports all attention types.""" + from vllm.attention import AttentionType + + return attn_type in ( + AttentionType.DECODER, + AttentionType.ENCODER, + AttentionType.ENCODER_ONLY, + AttentionType.ENCODER_DECODER, + ) + + @staticmethod + def get_impl_cls() -> type["FlashSinkAttentionImpl"]: + return FlashSinkAttentionImpl + + @staticmethod + def get_builder_cls() -> type["FlashSinkAttentionMetadataBuilder"]: + return FlashSinkAttentionMetadataBuilder + + @classmethod + def set_cache_head_size_ratio(cls, ratio: float) -> None: + cls.cache_head_size_ratio = ratio + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return ( + num_blocks, + block_size, + num_kv_heads, + int(head_size * FlashSinkAttentionBackend.cache_head_size_ratio), + ) + + @staticmethod + def get_kv_cache_stride_order() -> tuple[int, ...]: + # `stride_order` indicates the permutation that gets + # us 
from `get_kv_cache_shape` to the actual memory layout we want. + cache_layout = get_kv_cache_layout() + if cache_layout == "NHD": + stride_order = (0, 1, 2, 3) + elif cache_layout == "HND": + stride_order = (0, 2, 1, 3) + else: + raise ValueError(f"Unknown cache layout format {cache_layout}.") + return stride_order + + @staticmethod + def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype: + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + return torch.float8_e4m3fn + else: + raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @classmethod + def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: + if kv_cache_dtype is None: + return True + if kv_cache_dtype.startswith("fp8"): + return flash_attn_supports_fp8() + return kv_cache_dtype in ["auto"] + + @classmethod + def supports_sink(cls) -> bool: + if not is_flash_attn_varlen_func_available(): + return False + return flash_attn_supports_sinks() + + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability >= DeviceCapability(8, 0) + + @classmethod + def supports_combination( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: CacheDType | None, + block_size: int, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + device_capability: DeviceCapability, + ) -> str | None: + if has_sink and device_capability < DeviceCapability(9, 0): + return "sink not supported on compute capability < 9.0" + return None + + +def _get_sliding_window_configs( + vllm_config: VllmConfig, +) -> set[tuple[int, int] | None]: + """Get the set of all sliding window configs used in the model.""" + sliding_window_configs: set[tuple[int, int] | None] = set() + layers = get_layers_from_vllm_config(vllm_config, Attention) + for layer in layers.values(): + assert isinstance(layer.impl, FlashSinkAttentionImpl) + 
sliding_window_configs.add(layer.impl.sliding_window)
+    return sliding_window_configs
+
+
+class FlashSinkAttentionMetadataBuilder(
+    AttentionMetadataBuilder[FlashAttentionMetadata]
+):
+    # FA3:
+    # Supports full cudagraphs for all cases.
+    #
+    # FA2:
+    # For FA2, a graph is captured with max_query_len=1, (which is what we
+    # capture by default for num_tokens <= max_num_seqs when there is no
+    # spec-decode) then these graphs will not work for mixed prefill-decode
+    # (unlike FA3). This is due to special max_query_len=1 packed-GQA handling
+    # in FA2.
+    # In summary if we are running with spec decodes the graphs would
+    # work for mixed prefill-decode and uniform-decode. But for non-spec decodes
+    # the graphs would not work for mixed prefill-decode; sorta the inverse
+    # of UNIFORM_SINGLE_TOKEN_DECODE.
+    # There's probably a better way to describe this using `AttentionCGSupport`
+    # but for now just set it to `UNIFORM_BATCH` to get us to drop down
+    # to FULL_AND_PIECEWISE.
+    # TODO(luka, lucas): audit FA2 as part of:
+    # https://github.com/vllm-project/vllm/issues/22945
+    _cudagraph_support = (
+        AttentionCGSupport.ALWAYS
+        if get_flash_attn_version() == 3
+        else AttentionCGSupport.UNIFORM_BATCH
+    )
+
+    def __init__(
+        self,
+        kv_cache_spec: AttentionSpec,
+        layer_names: list[str],
+        vllm_config: VllmConfig,
+        device: torch.device,
+    ):
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
+        self.model_config = vllm_config.model_config
+        self.parallel_config = vllm_config.parallel_config
+        self.cache_config = vllm_config.cache_config
+        self.compilation_config = vllm_config.compilation_config
+
+        self.num_heads_q = self.model_config.get_num_attention_heads(
+            self.parallel_config
+        )
+        self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config)
+        self.kv_cache_dtype = kv_cache_spec.dtype
+        self.headdim = self.model_config.get_head_size()
+        self.block_size = kv_cache_spec.block_size
+
+        self.max_num_splits = 0  # No upper bound on 
the number of splits. + self.aot_schedule = get_flash_attn_version() == 3 + + try: + from vllm.distributed.parallel_state import get_dcp_group + + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 + + self.dcp_kv_cache_interleave_size = ( + self.parallel_config.dcp_kv_cache_interleave_size + ) + + self.use_full_cuda_graph = ( + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + ) + self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size + + if self.use_full_cuda_graph and self.aot_schedule: + self.scheduler_metadata = torch.zeros( + vllm_config.scheduler_config.max_num_seqs + 1, + dtype=torch.int32, + device=self.device, + ) + # When using cuda graph, we need to set the upper bound of the + # number of splits so that large enough intermediate buffers are + # pre-allocated during capture. + self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH + + # Sliding window size to be used with the AOT scheduler will be + # populated on first build() call. + self.aot_sliding_window: tuple[int, int] | None = None + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> FlashAttentionMetadata: + """ + fast_build disables AOT scheduling, used when there will be few + iterations i.e. 
spec-decode
+        """
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+        max_query_len = common_attn_metadata.max_query_len
+        max_seq_len = common_attn_metadata.max_seq_len
+        query_start_loc = common_attn_metadata.query_start_loc
+        seq_lens = common_attn_metadata.seq_lens
+        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+        block_table_tensor = common_attn_metadata.block_table_tensor
+        slot_mapping = common_attn_metadata.slot_mapping
+        causal = common_attn_metadata.causal
+
+        # the overhead of the aot schedule is not worth it for spec-decode
+        aot_schedule = self.aot_schedule and not fast_build
+
+        if self.aot_sliding_window is None:
+            self.aot_sliding_window = (-1, -1)
+            # For the AOT scheduler we need the sliding window value to be
+            # constant for all layers. We have to populate this on the first
+            # build() call, after the layers are constructed (it cannot be
+            # populated in __init__).
+            if aot_schedule:
+                sliding_window_configs = _get_sliding_window_configs(self.vllm_config)
+                if len(sliding_window_configs) == 1:
+                    sliding_window_config = sliding_window_configs.pop()
+                    if sliding_window_config is not None:
+                        self.aot_sliding_window = sliding_window_config
+                elif len(sliding_window_configs) > 1:
+                    self.aot_schedule = False
+                    aot_schedule = False
+
+        max_num_splits = 0  # 0 means use FA3's heuristics, not CG compatible
+        if self.use_full_cuda_graph and num_actual_tokens <= self.max_cudagraph_size:
+            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
+            # usage, because the intermediate buffers of size [num_splits,
+            # num_heads, num_tokens, head_size] are allocated. Therefore,
+            # we only set num_splits when using cuda graphs. 
+ max_num_splits = self.max_num_splits + + if vllm_is_batch_invariant(): + max_num_splits = 1 + + def schedule( + batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal + ): + cache_dtype = self.cache_config.cache_dtype + if cache_dtype.startswith("fp8"): + qkv_dtype = FlashSinkAttentionBackend.get_fp8_dtype_for_flashattn( + cache_dtype + ) + else: + qkv_dtype = self.kv_cache_dtype + if aot_schedule: + return get_scheduler_metadata( + batch_size=batch_size, + max_seqlen_q=max_query_len, + max_seqlen_k=max_seq_len, + num_heads_q=self.num_heads_q * self.dcp_world_size, + num_heads_kv=self.num_heads_kv, + headdim=self.headdim, + cache_seqlens=seqlens, + qkv_dtype=qkv_dtype, + cu_seqlens_q=cu_query_lens, + page_size=self.block_size, + causal=causal, + window_size=self.aot_sliding_window, + num_splits=max_num_splits, + ) + return None + + use_cascade = common_prefix_len > 0 + max_dcp_context_kv_len = 0 + dcp_context_kv_lens = None + + cu_prefix_query_lens = None + prefix_kv_lens = None + suffix_kv_lens = None + prefix_scheduler_metadata = None + + if self.dcp_world_size > 1: + query_kv_lens_cpu = ( + common_attn_metadata.query_start_loc_cpu[1:] + - common_attn_metadata.query_start_loc_cpu[:-1] + ) + dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu + + dcp_context_kv_lens_cpu = get_dcp_local_seq_lens( + dcp_context_kv_lens_cpu, + self.dcp_world_size, + self.dcp_rank, + self.dcp_kv_cache_interleave_size, + ) + dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device) + max_dcp_context_kv_len = dcp_context_kv_lens.max().item() + + scheduler_metadata = schedule( + batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=dcp_context_kv_lens, + max_seq_len=max_dcp_context_kv_len, + causal=False, + ) + elif use_cascade: + cu_prefix_query_lens = torch.tensor( + [0, num_actual_tokens], dtype=torch.int32, device=self.device + ) + prefix_kv_lens = torch.tensor( + [common_prefix_len], dtype=torch.int32, 
device=self.device + ) + suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to( + self.device, non_blocking=True + ) + prefix_scheduler_metadata = schedule( + batch_size=1, + cu_query_lens=cu_prefix_query_lens, + max_query_len=num_actual_tokens, + seqlens=prefix_kv_lens, + max_seq_len=common_prefix_len, + causal=False, + ) + scheduler_metadata = schedule( + batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=suffix_kv_lens, + max_seq_len=max_seq_len - common_prefix_len, + causal=True, + ) + else: + scheduler_metadata = schedule( + batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=seq_lens, + max_seq_len=max_seq_len, + causal=causal, + ) + # For FA3 + full cudagraph + if self.use_full_cuda_graph and scheduler_metadata is not None: + n = scheduler_metadata.shape[0] + self.scheduler_metadata[:n] = scheduler_metadata + # NOTE(woosuk): We should zero out the rest of the scheduler + # metadata to guarantee the correctness. Otherwise, some thread + # blocks may use the invalid scheduler metadata and overwrite the + # output buffer. 
+ self.scheduler_metadata[n:] = 0 + scheduler_metadata = self.scheduler_metadata[:n] + + attn_metadata = FlashAttentionMetadata( + num_actual_tokens=num_actual_tokens, + max_query_len=max_query_len, + query_start_loc=query_start_loc, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table=block_table_tensor, + slot_mapping=slot_mapping, + max_dcp_context_kv_len=max_dcp_context_kv_len, + dcp_context_kv_lens=dcp_context_kv_lens, + use_cascade=use_cascade, + common_prefix_len=common_prefix_len, + scheduler_metadata=scheduler_metadata, + cu_prefix_query_lens=cu_prefix_query_lens, + prefix_kv_lens=prefix_kv_lens, + suffix_kv_lens=suffix_kv_lens, + prefix_scheduler_metadata=prefix_scheduler_metadata, + max_num_splits=max_num_splits, + causal=causal, + ) + return attn_metadata + + def use_cascade_attention(self, *args, **kwargs) -> bool: + return use_cascade_attention(*args, **kwargs) + + +class FlashSinkAttentionImpl(AttentionImpl): + can_return_lse_for_decode: bool = True + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: list[float] | None, + sliding_window: int | None, + kv_cache_dtype: str, + logits_soft_cap: float | None = None, + attn_type: AttentionType = AttentionType.DECODER, + kv_sharing_target_layer_name: str | None = None, + sinks: torch.Tensor | None = None, + head_size_v: int | None = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.head_size_v = head_size_v + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if sliding_window is None: + self.sliding_window = (-1, -1) + elif attn_type == AttentionType.ENCODER_ONLY: + self.sliding_window = (sliding_window - 1, sliding_window - 1) + else: + self.sliding_window = (sliding_window - 1, 0) + self.kv_cache_dtype = kv_cache_dtype + if logits_soft_cap is None: + # In 
flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name + + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.attn_type = attn_type + self.vllm_flash_attn_version = get_flash_attn_version() + # Cache the batch invariant result for use in forward passes + self.batch_invariant_enabled = vllm_is_batch_invariant() + + if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8(): + raise NotImplementedError( + "FlashAttention does not support fp8 kv-cache on this device." + ) + + self.sinks = sinks + if self.sinks is not None: + assert flash_attn_supports_sinks(), ( + "Sinks are only supported in FlashAttention 3" + ) + assert self.sinks.shape[0] == num_heads, ( + "Sinks must have the same number of heads as the number of " + "heads in the layer" + ) + + def supports_quant_query_input(self) -> bool: + return False + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, + sink_key: torch.Tensor | None = None, + sink_value: torch.Tensor | None = None, + ) -> torch.Tensor: + """Forward pass with FlashSinkAttention. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size_v] + kv_cache: shape = + [num_blocks, block_size, num_kv_heads, head_size + head_size_v] + attn_metadata: Metadata for attention. 
+ sink_key: shape = [sink_len, num_kv_heads, head_size] + sink_value: shape = [sink_len, num_kv_heads, head_size_v] + Returns: + shape = [num_tokens, num_heads * head_size_v] + NOTE: FP8 quantization, flash-attn expect the size of + {q,k,v}_descale to be (num_sequences, num_kv_heads). + We use torch's .expand() to avoid duplicating values + """ + assert output is not None, "Output tensor must be provided." + assert sink_key is not None and sink_value is not None, ( + "sink_key and sink_value must be provided" + ) + + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "fused output quantization is not supported yet " + "for FlashSinkAttentionImpl" + ) + + if attn_metadata is None: + # Profiling run. + return output.fill_(0) + + attn_type = self.attn_type + + # IMPORTANT! + # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in + # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead + # in this method. For example, `view` and `slice` (or `[:n]`) operations + # are surprisingly slow even in the case they do not invoke any GPU ops. + # Minimize the PyTorch ops in this method as much as possible. + # Whenever making a change in this method, please benchmark the + # performance to make sure it does not introduce any overhead. + + num_actual_tokens = attn_metadata.num_actual_tokens + sink_len = sink_key.shape[0] + + # Handle encoder attention differently - no KV cache needed + if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): + # For encoder attention, + # we use direct Q, K, V tensors without caching + return self._forward_encoder_attention( + query[:num_actual_tokens], + key[:num_actual_tokens], + value[:num_actual_tokens], + output[:num_actual_tokens], + attn_metadata, + layer, + sink_key, + sink_value, + ) + + # For decoder and cross-attention, use KV cache as before + + # key and value may be None in the case of cross attention. 
They are + # calculated once based on the output from the encoder and then cached + # in KV cache. + if ( + self.kv_sharing_target_layer_name is None + and key is not None + and value is not None + ): + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] + # and value[:num_actual_tokens] because the reshape_and_cache_flash + # op uses the slot_mapping's shape to determine the number of + # actual tokens. + + # store sink_key and sink_value in head blocks + key_cache = kv_cache[..., : self.head_size] + value_cache = kv_cache[..., self.head_size :] + block_size = key_cache.shape[1] + assert sink_len % block_size == 0 + num_sink_blocks = sink_len // block_size + sink_kv_slot_mapping = torch.arange( + sink_len, + device=attn_metadata.slot_mapping.device, + dtype=attn_metadata.slot_mapping.dtype, + ) + triton_reshape_and_cache_flash_diffkv( + sink_key, + sink_value, + kv_cache, + sink_kv_slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + triton_reshape_and_cache_flash_diffkv( + key, + value, + kv_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if self.kv_cache_dtype.startswith("fp8"): + # queries are quantized in the attention layer + dtype = FlashSinkAttentionBackend.get_fp8_dtype_for_flashattn( + self.kv_cache_dtype + ) + key_cache = key_cache.view(dtype) + value_cache = value_cache.view(dtype) + + if not attn_metadata.use_cascade: + cu_seqlens_q = attn_metadata.query_start_loc + seqused_k = attn_metadata.seq_lens + sink_len + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_seq_len + sink_len + block_table = attn_metadata.block_table + scheduler_metadata = attn_metadata.scheduler_metadata + sink_block_table = torch.arange( + 
num_sink_blocks, device=block_table.device, dtype=block_table.dtype + ) + sink_block_table = sink_block_table[None, :].expand( + block_table.shape[0], -1 + ) + block_table = torch.cat((sink_block_table, block_table), dim=1) + + descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads) + + if self.dcp_world_size > 1: + raise ValueError( + "Decode context parallel is not supported yet " + f"for dcp_world_size = {self.dcp_world_size}" + ) + else: + flash_attn_varlen_func( + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=seqused_k, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=attn_metadata.causal, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + softcap=self.logits_soft_cap, + scheduler_metadata=scheduler_metadata, + fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + num_splits=attn_metadata.max_num_splits, + s_aux=self.sinks, + ) + return output + + # Cascade attention (rare case). 
+ cascade_attention( + output[:num_actual_tokens], + query[:num_actual_tokens], + key_cache, + value_cache, + cu_query_lens=attn_metadata.query_start_loc, + max_query_len=attn_metadata.max_query_len, + cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens, + prefix_kv_lens=attn_metadata.prefix_kv_lens, + suffix_kv_lens=attn_metadata.suffix_kv_lens, + max_kv_len=attn_metadata.max_seq_len, + softmax_scale=self.scale, + alibi_slopes=self.alibi_slopes, + sliding_window=self.sliding_window, + logits_soft_cap=self.logits_soft_cap, + block_table=attn_metadata.block_table, + common_prefix_len=attn_metadata.common_prefix_len, + max_num_splits=attn_metadata.max_num_splits, + fa_version=self.vllm_flash_attn_version, + prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata, + suffix_scheduler_metadata=attn_metadata.scheduler_metadata, + q_descale=layer._q_scale, + k_descale=layer._k_scale, + v_descale=layer._v_scale, + s_aux=self.sinks, + sink_len=sink_len, + ) + return output + + def _forward_encoder_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + layer: torch.nn.Module, + sink_key: torch.Tensor, + sink_value: torch.Tensor, + ) -> torch.Tensor: + """Forward pass for encoder attention without KV cache. 
+ + Args: + query: shape = [num_encoder_tokens, num_heads, head_size] + key: shape = [num_encoder_tokens, num_kv_heads, head_size] + value: shape = [num_encoder_tokens, num_kv_heads, head_size_v] + output: shape = [num_encoder_tokens, num_heads, head_size] + attn_metadata: Encoder attention metadata + layer: The attention layer + sink_key: shape = [sink_len, num_kv_heads, head_size] + sink_value: shape = [sink_len, num_kv_heads, head_size_v] + """ + # For encoder attention, process FP8 quantization if needed + if self.kv_cache_dtype.startswith("fp8"): + raise NotImplementedError( + "quantization is not supported for encoder attention" + ) + + # Use encoder-specific metadata for sequence information + sink_len = sink_key.shape[0] + key_list = [] + value_list = [] + for seq_id in range(attn_metadata.block_table.shape[0]): + seq_start = attn_metadata.query_start_loc[seq_id] + seq_end = attn_metadata.query_start_loc[seq_id + 1] + key_list.append( + torch.cat( + [ + sink_key, + key[seq_start:seq_end], + ], + dim=0, + ) + ) + value_list.append( + torch.cat( + [ + sink_value, + value[seq_start:seq_end], + ], + dim=0, + ) + ) + key = torch.cat(key_list, dim=0).contiguous() + value = torch.cat(value_list, dim=0).contiguous() + + cu_seqlens_q = attn_metadata.query_start_loc + cu_seqlens_k = attn_metadata.seq_lens + sink_len + cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum(cu_seqlens_k, dim=-1), [1, 0], value=0 + ).to(torch.int32) + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_seq_len + sink_len + + descale_shape = ( + cu_seqlens_q.shape[0] - 1, # type: ignore[union-attr] + self.num_kv_heads, + ) + + # Call flash attention directly on Q, K, V tensors + flash_attn_varlen_func( + q=query, + k=key, + v=value, + out=output, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=False, # Encoder attention is bidirectional + 
alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + softcap=self.logits_soft_cap, + fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + num_splits=1 if self.batch_invariant_enabled else 0, + ) + + return output + + +def use_cascade_attention( + common_prefix_len: int, + query_lens: np.ndarray, + num_query_heads: int, + num_kv_heads: int, + use_alibi: bool, + use_sliding_window: bool, + use_local_attention: bool, + num_sms: int, + dcp_world_size: int, +) -> bool: + """Decide whether to use cascade attention. + + This function 1) checks whether cascade attention is supported with the + given configuration, and 2) heuristically decides whether using cascade + attention can improve performance. + """ + # Too short common prefix. Probably not worth using cascade attention. + # We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold. + # NOTE(woosuk): This is the common case. We should return False as soon as + # possible to avoid any unnecessary computation. + if common_prefix_len < 256: + return False + # Cascade attention is currently not supported with these variants. + if use_alibi or use_sliding_window or use_local_attention: + return False + # Too few queries. Probably not worth using cascade attention. + # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold. + num_reqs = len(query_lens) + if num_reqs < 8: + return False + # disable cascade attention for DCP + if dcp_world_size > 1: + return False + + # Heuristics to decide whether using cascade attention is beneficial. + # 1. When FlashDecoding is not used for normal attention, cascade attention + # is likely to be faster since it saves memory bandwidth. 
+ num_queries_per_kv = num_query_heads // num_kv_heads + # The criteria for using FlashDecoding can be found in the following link: + # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535 + use_flash_decoding = ( + num_queries_per_kv > 1 + and not use_sliding_window + and not use_alibi + and np.all(query_lens == 1) + ) + if not use_flash_decoding: + # Use cascade attention. + return True + + # 2. When FlashDecoding is used for normal attention, it is not clear + # whether cascade attention is beneficial, because FlashDecoding can + # launch more CTAs than cascade attention. + # We use a simple performance model to compare the two methods. + # NOTE(woosuk): The performance model is very rough and may not be + # accurate. + num_tokens = num_reqs + # NOTE(woosuk): These are default tile sizes. flash-attn might use + # different tile sizes (e.g., 64 or 256) depending on the configuration. + q_tile_size = 128 + kv_tile_size = 128 + num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size) + + cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size) + cascade_waves = cdiv(cascade_ctas, num_sms) + cascade_time = cascade_waves * num_prefix_tiles + + flash_decoding_ctas = ( + num_reqs * num_kv_heads * cdiv(num_queries_per_kv, q_tile_size) + ) + flash_decoding_ctas *= num_prefix_tiles + flash_decoding_time = cdiv(flash_decoding_ctas, num_sms) + + # Use cascade attention if it is faster than FlashDecoding. 
+ return cascade_time < flash_decoding_time + + +def cascade_attention( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + cu_query_lens: torch.Tensor, + max_query_len: int, + cu_prefix_query_lens: torch.Tensor, + prefix_kv_lens: torch.Tensor, + suffix_kv_lens: torch.Tensor, + max_kv_len: int, + softmax_scale: float, + alibi_slopes: torch.Tensor | None, + sliding_window: tuple[int, int], + logits_soft_cap: float, + block_table: torch.Tensor, + common_prefix_len: int, + max_num_splits: int, + fa_version: int, + prefix_scheduler_metadata: torch.Tensor | None = None, + suffix_scheduler_metadata: torch.Tensor | None = None, + q_descale: torch.Tensor | None = None, + k_descale: torch.Tensor | None = None, + v_descale: torch.Tensor | None = None, + s_aux: torch.Tensor | None = None, + sink_len: int | None = None, +) -> torch.Tensor: + assert alibi_slopes is None, "Cascade attention does not support ALiBi." + # TODO: Support sliding window. + assert sliding_window == (-1, -1), ( + "Cascade attention does not support sliding window." + ) + assert sink_len is not None, "sink_len must be provided." + + num_tokens = query.shape[0] + block_size = key_cache.shape[-3] + assert common_prefix_len % block_size == 0 + num_common_kv_blocks = common_prefix_len // block_size + assert num_common_kv_blocks > 0 + descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2]) + + num_sink_blocks = sink_len // block_size + block_table = block_table + num_sink_blocks + block_table[block_table == num_sink_blocks] = 0 + sink_block_table = ( + torch.arange( + num_sink_blocks, device=block_table.device, dtype=block_table.dtype + ) + + 1 + ) + sink_block_table = sink_block_table[None, :].expand(block_table.shape[0], -1) + block_table = torch.cat((sink_block_table, block_table), dim=1) + + # Process shared prefix. 
+ prefix_output, prefix_lse = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=cu_prefix_query_lens, + seqused_k=prefix_kv_lens + sink_len, + max_seqlen_q=num_tokens, + max_seqlen_k=common_prefix_len + sink_len, + softmax_scale=softmax_scale, + causal=False, + window_size=sliding_window, + block_table=block_table[:1], + softcap=logits_soft_cap, + return_softmax_lse=True, + scheduler_metadata=prefix_scheduler_metadata, + fa_version=fa_version, + q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, + k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, + v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, + # s_aux is incorporated into prefix_lse inside the GPU kernel, + # enabling its effect during the final attention merge. + s_aux=s_aux, + num_splits=1 if vllm_is_batch_invariant() else max_num_splits, + ) + + descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2]) + + # Process suffix per query. + suffix_output, suffix_lse = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=cu_query_lens, + seqused_k=suffix_kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len - common_prefix_len, + softmax_scale=softmax_scale, + causal=True, + window_size=sliding_window, + block_table=block_table[:, num_sink_blocks + num_common_kv_blocks :], + softcap=logits_soft_cap, + return_softmax_lse=True, + scheduler_metadata=suffix_scheduler_metadata, + fa_version=fa_version, + q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, + k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, + v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, + num_splits=1 if vllm_is_batch_invariant() else max_num_splits, + ) + + # Merge prefix and suffix outputs, and store the result in output. 
+ merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c640c40a455d0..041393d097b9a 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -184,6 +184,11 @@ class Scheduler(SchedulerInterface): enable_kv_cache_events=self.enable_kv_cache_events, dcp_world_size=self.dcp_world_size, ) + sink_len = getattr(vllm_config.model_config.hf_config, "param_sink_number", 0) + if sink_len > 0: + assert sink_len % self.block_size == 0 + num_sink_block = sink_len // self.block_size + self.kv_cache_manager.block_pool.free_block_queue.popleft_n(num_sink_block) self.use_pp = self.parallel_config.pipeline_parallel_size > 1 def schedule(self) -> SchedulerOutput: diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 14ac83028ee44..5aba227fa0176 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -12,6 +12,7 @@ from vllm.v1.kv_cache_interface import ( ChunkedLocalAttentionSpec, CrossAttentionSpec, FullAttentionSpec, + FullSinkAttentionSpec, KVCacheSpec, MambaSpec, MLAAttentionSpec, @@ -305,7 +306,8 @@ class FullAttentionManager(SingleTypeKVCacheManager): dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance( - kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec) + kv_cache_spec, + (FullAttentionSpec, FullSinkAttentionSpec, ChunkedLocalAttentionSpec), ), ( "FullAttentionManager can only be used for full attention " "and chunked local attention groups" @@ -720,6 +722,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager): spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, + FullSinkAttentionSpec: FullAttentionManager, MLAAttentionSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager, ChunkedLocalAttentionSpec: 
ChunkedLocalAttentionManager, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 7f33eb7e699c7..5ff99acac52a2 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -157,6 +157,98 @@ class FullAttentionSpec(AttentionSpec): return merged_spec +@dataclass(frozen=True) +class FullSinkAttentionSpec(AttentionSpec): + head_size_v: int + sliding_window: int | None = None + attention_chunk_size: int | None = None + + """ + When hybrid allocator is disabled and the model contains both full + attention layers and sliding window attention layers, sliding + window attention are regarded as full attention in KV cache manager + (blocks are allocated for all tokens), while computed as sliding window + attention in model runner. + In this case, we use FullSinkAttentionSpec and record the sliding window size. + Default to None for not using sliding window attention. + """ + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + max_model_len = vllm_config.model_config.max_model_len + dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size + # Note(hc): each dcp rank only need save + # (max_model_len//dcp_world_size) tokens locally. + if dcp_world_size > 1: + max_model_len = cdiv(max_model_len, dcp_world_size) + return cdiv(max_model_len, self.block_size) * self.page_size_bytes + + @classmethod + def merge_window_sizes(cls, window_sizes: set[int]) -> int | None: + if len(window_sizes) == 0: + return None + elif len(window_sizes) == 1: + return window_sizes.pop() + else: + raise ValueError( + "All attention layers in the same KV cache group must have the " + "same window size." + ) + + @classmethod + def merge(cls, specs: list[Self]) -> Self: + """ + Merge a list of FullSinkAttentionSpec objects into a single + FullSinkAttentionSpec object. 
+ """ + assert all(isinstance(spec, FullSinkAttentionSpec) for spec in specs), ( + "All attention layers in the same KV cache group must be " + "FullSinkAttentionSpec." + ) + + sliding_window = set( + spec.sliding_window for spec in specs if spec.sliding_window is not None + ) + attention_chunk_size = set( + spec.attention_chunk_size + for spec in specs + if spec.attention_chunk_size is not None + ) + assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), ( + "MLAAttentionSpec should be merged in MLAAttentionSpec.merge" + ) + merged_spec = cls( + block_size=specs[0].block_size, + num_kv_heads=specs[0].num_kv_heads, + head_size=specs[0].head_size, + head_size_v=specs[0].head_size_v, + dtype=specs[0].dtype, + sliding_window=cls.merge_window_sizes(sliding_window), + attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), + ) + for spec in specs: + for f in fields(AttentionSpec): + assert getattr(spec, f.name) == getattr(merged_spec, f.name), ( + "All attention layers in the same KV cache group must have " + "the same attention spec." + ) + assert (merged_spec.sliding_window is not None) + ( + merged_spec.attention_chunk_size is not None + ) <= 1, ( + "Model with both sliding window layers and chunked local attention " + "layers is not supported." 
+ ) + return merged_spec + + @property + def page_size_bytes(self) -> int: + return ( + self.block_size + * self.num_kv_heads + * (self.head_size + self.head_size_v) + * get_dtype_size(self.dtype) + ) + + @dataclass(frozen=True) class MLAAttentionSpec(FullAttentionSpec): # TODO(Lucas/Chen): less hacky way to do this From b0e880632ae9352923dfd832ab31aa7acaf08762 Mon Sep 17 00:00:00 2001 From: yuantao <2422264527@qq.com> Date: Sat, 15 Nov 2025 15:36:53 +0800 Subject: [PATCH 02/16] Bugfix for param_sink_key initialization, block_table for cascade and refactor forward in unified_attention_with_output Signed-off-by: yuantao <2422264527@qq.com> --- vllm/attention/layer.py | 43 ++++++++----------- vllm/model_executor/models/openpangu.py | 16 +++---- vllm/v1/attention/backends/flash_sink_attn.py | 17 ++++---- 3 files changed, 32 insertions(+), 44 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index deabf7a0b0770..6962810bdd09f 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -945,36 +945,27 @@ def unified_attention_with_output( output_block_scale: torch.Tensor | None = None, ) -> None: attn_metadata, self, kv_cache = get_attention_context(layer_name) - if sink_key is None and sink_value is None: - self.impl.forward( - self, - query, - key, - value, - kv_cache, - attn_metadata, - output=output, - output_scale=output_scale, - output_block_scale=output_block_scale, - ) - else: + kwargs = {} + if sink_key is not None or sink_value is not None: assert sink_key is not None and sink_value is not None, ( "Currently, it is only supported when " "sink_key and sink_value are both not None" ) - self.impl.forward( - self, - query, - key, - value, - kv_cache, - attn_metadata, - output=output, - output_scale=output_scale, - output_block_scale=output_block_scale, - sink_key=sink_key, - sink_value=sink_value, - ) + kwargs["sink_key"] = sink_key + kwargs["sink_value"] = sink_value + + self.impl.forward( + self, + query, + key, + value, + 
kv_cache, + attn_metadata, + output=output, + output_scale=output_scale, + output_block_scale=output_block_scale, + **kwargs, + ) def unified_attention_with_output_fake( diff --git a/vllm/model_executor/models/openpangu.py b/vllm/model_executor/models/openpangu.py index f7249475b49de..f46fd3c7f319d 100644 --- a/vllm/model_executor/models/openpangu.py +++ b/vllm/model_executor/models/openpangu.py @@ -854,15 +854,13 @@ class OpenPanguSinkAttention(nn.Module): ) else: self.param_sink_value = torch.zeros( - torch.empty( - ( - self.param_sink_number, - self.num_kv_heads, - self.v_channels, - ), - device=torch.cuda.current_device(), - dtype=config.torch_dtype, - ) + ( + self.param_sink_number, + self.num_kv_heads, + self.v_channels, + ), + device=torch.cuda.current_device(), + dtype=config.torch_dtype, ) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): diff --git a/vllm/v1/attention/backends/flash_sink_attn.py b/vllm/v1/attention/backends/flash_sink_attn.py index 580f0f544102c..e532a69dcb4a9 100644 --- a/vllm/v1/attention/backends/flash_sink_attn.py +++ b/vllm/v1/attention/backends/flash_sink_attn.py @@ -614,7 +614,8 @@ class FlashSinkAttentionImpl(AttentionImpl): assert sink_len % block_size == 0 num_sink_blocks = sink_len // block_size sink_kv_slot_mapping = torch.arange( - sink_len, + block_size, + sink_len + block_size, device=attn_metadata.slot_mapping.device, dtype=attn_metadata.slot_mapping.dtype, ) @@ -654,7 +655,10 @@ class FlashSinkAttentionImpl(AttentionImpl): block_table = attn_metadata.block_table scheduler_metadata = attn_metadata.scheduler_metadata sink_block_table = torch.arange( - num_sink_blocks, device=block_table.device, dtype=block_table.dtype + 1, + num_sink_blocks + 1, + device=block_table.device, + dtype=block_table.dtype, ) sink_block_table = sink_block_table[None, :].expand( block_table.shape[0], -1 @@ -939,13 +943,8 @@ def cascade_attention( descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2]) 
num_sink_blocks = sink_len // block_size - block_table = block_table + num_sink_blocks - block_table[block_table == num_sink_blocks] = 0 - sink_block_table = ( - torch.arange( - num_sink_blocks, device=block_table.device, dtype=block_table.dtype - ) - + 1 + sink_block_table = torch.arange( + 1, num_sink_blocks + 1, device=block_table.device, dtype=block_table.dtype ) sink_block_table = sink_block_table[None, :].expand(block_table.shape[0], -1) block_table = torch.cat((sink_block_table, block_table), dim=1) From e38739aefc2ef97a697d3bfe5c4bbd4e14787231 Mon Sep 17 00:00:00 2001 From: yuantao <2422264527@qq.com> Date: Sat, 15 Nov 2025 18:37:53 +0800 Subject: [PATCH 03/16] Add FLASH_SINK_ATTN to AttentionBackendEnum Signed-off-by: yuantao <2422264527@qq.com> --- vllm/attention/backends/registry.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index f07a6059be377..0602607966720 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -42,6 +42,9 @@ class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta): """ FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + FLASH_SINK_ATTN = ( + "vllm.v1.attention.backends.flash_sink_attn.FlashSinkAttentionBackend" + ) TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" XFORMERS = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" From 315e3f654a49831ad401c90555cf8ffab2f5e489 Mon Sep 17 00:00:00 2001 From: yuantao <2422264527@qq.com> Date: Wed, 26 Nov 2025 11:34:23 +0800 Subject: [PATCH 04/16] Refactor code, make attn backend focus on diffkv and move sink logic to GPUModelRunner Signed-off-by: yuantao <2422264527@qq.com> --- vllm/attention/backends/registry.py | 4 +- vllm/attention/layer.py | 16 +- vllm/model_executor/models/openpangu.py | 81 +++-- 
...lash_sink_attn.py => flash_diffkv_attn.py} | 293 ++++++++++-------- vllm/v1/core/single_type_kv_cache_manager.py | 6 +- vllm/v1/kv_cache_interface.py | 12 +- vllm/v1/worker/gpu_input_batch.py | 3 +- vllm/v1/worker/gpu_model_runner.py | 67 ++++ 8 files changed, 277 insertions(+), 205 deletions(-) rename vllm/v1/attention/backends/{flash_sink_attn.py => flash_diffkv_attn.py} (83%) diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index e69f1b7ce25e0..596622fe95b16 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -42,8 +42,8 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta): """ FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" - FLASH_SINK_ATTN = ( - "vllm.v1.attention.backends.flash_sink_attn.FlashSinkAttentionBackend" + FLASH_DIFFKV_ATTN = ( + "vllm.v1.attention.backends.flash_diffkv_attn.FlashDiffkvAttentionBackend" ) TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 376101e55e285..629f93981af09 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -285,8 +285,7 @@ class Attention(nn.Module, AttentionLayerBase): kv_sharing_target_layer_name, **extra_impl_args, ) - backend_name = self.attn_backend.get_name() - self.backend = AttentionBackendEnum.__members__.get(backend_name) + self.backend = AttentionBackendEnum[self.attn_backend.get_name()] self.dtype = dtype # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how @@ -902,20 +901,10 @@ def unified_attention_with_output( value: torch.Tensor, output: torch.Tensor, layer_name: str, - sink_key: torch.Tensor | None = None, - sink_value: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> None: attn_metadata, self, kv_cache = 
get_attention_context(layer_name) - kwargs = {} - if sink_key is not None or sink_value is not None: - assert sink_key is not None and sink_value is not None, ( - "Currently, it is only supported when " - "sink_key and sink_value are both not None" - ) - kwargs["sink_key"] = sink_key - kwargs["sink_value"] = sink_value self.impl.forward( self, @@ -927,7 +916,6 @@ def unified_attention_with_output( output=output, output_scale=output_scale, output_block_scale=output_block_scale, - **kwargs, ) @@ -937,8 +925,6 @@ def unified_attention_with_output_fake( value: torch.Tensor, output: torch.Tensor, layer_name: str, - sink_key: torch.Tensor | None = None, - sink_value: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> None: diff --git a/vllm/model_executor/models/openpangu.py b/vllm/model_executor/models/openpangu.py index 0486032645ad2..1fe96f71dab64 100644 --- a/vllm/model_executor/models/openpangu.py +++ b/vllm/model_executor/models/openpangu.py @@ -81,12 +81,12 @@ from vllm.model_executor.models.utils import ( ) from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors -from vllm.v1.attention.backends.flash_sink_attn import FlashSinkAttentionBackend +from vllm.transformers_utils.config import set_default_rope_theta +from vllm.v1.attention.backends.flash_diffkv_attn import FlashDiffkvAttentionBackend from vllm.v1.kv_cache_interface import ( - FullSinkAttentionSpec, + FullDiffkvAttentionSpec, KVCacheSpec, ) -from vllm.transformers_utils.config import set_default_rope_theta def check_ffn_act_fn(act_fn: str): @@ -96,7 +96,7 @@ def check_ffn_act_fn(act_fn: str): ) -class AttentionWithSink(Attention): +class DiffkvAttention(Attention): def __init__( self, num_heads: int, @@ -138,9 +138,6 @@ class AttentionWithSink(Attention): query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - # For attention with sink, we have sink k, v - sink_key: torch.Tensor | 
None = None, - sink_value: torch.Tensor | None = None, output_shape: torch.Size | None = None, ) -> torch.Tensor: """ @@ -194,8 +191,6 @@ class AttentionWithSink(Attention): self_kv_cache, attn_metadata, output=output, - sink_key=sink_key, - sink_value=sink_value, ) else: torch.ops.vllm.unified_attention_with_output( @@ -204,13 +199,11 @@ class AttentionWithSink(Attention): value, output, self.layer_name, - sink_key=sink_key, - sink_value=sink_value, ) return output.view(-1, hidden_size) else: raise ValueError( - "Unsupport Error, currently only flash_sink_attn " + "Unsupport Error, currently only flash_diffkv_attn " "backend with output buffer is supported" ) @@ -221,7 +214,7 @@ class AttentionWithSink(Attention): assert self.attn_type == AttentionType.DECODER # Only support for full attention now. assert self.sliding_window is None - return FullSinkAttentionSpec( + return FullDiffkvAttentionSpec( block_size=block_size, num_kv_heads=self.num_kv_heads, head_size=self.head_size, @@ -682,15 +675,14 @@ class OpenPanguEmbeddedAttention(nn.Module): ) -class OpenPanguSinkAttention(nn.Module): +class OpenPanguDiffkvAttention(nn.Module): def __init__( self, config: PretrainedConfig, hidden_size: int, num_heads: int, num_kv_heads: int, - rope_theta: float = 10000, - rope_scaling: dict[str, Any] | None = None, + rope_parameters: dict[str, Any] | None = None, max_position_embeddings: int = 8192, quant_config: QuantizationConfig | None = None, bias: bool = False, @@ -739,7 +731,6 @@ class OpenPanguSinkAttention(nn.Module): self.k_size = self.num_kv_heads * self.head_dim self.v_size = self.num_kv_heads * self.v_channels self.scaling = self.head_dim**-0.5 - self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings self.param_sink_number = getattr(config, "param_sink_number", 0) @@ -770,7 +761,7 @@ class OpenPanguSinkAttention(nn.Module): self.k_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) self._init_rotary_emb( - config, 
rope_scaling=rope_scaling, quant_config=quant_config + config, rope_parameters=rope_parameters, quant_config=quant_config ) if hasattr(config, "interleaved_sliding_window"): @@ -788,10 +779,8 @@ class OpenPanguSinkAttention(nn.Module): else: sliding_window = None - FlashSinkAttentionBackend.set_cache_head_size_ratio( - (self.head_dim + self.v_channels) / self.head_dim - ) - self.attn = AttentionWithSink( + FlashDiffkvAttentionBackend.set_head_size_v(self.v_channels) + self.attn = DiffkvAttention( self.num_heads, self.head_dim, self.v_channels, @@ -802,7 +791,7 @@ class OpenPanguSinkAttention(nn.Module): per_layer_sliding_window=sliding_window, attn_type=attn_type, prefix=f"{prefix}.attn", - attn_backend=FlashSinkAttentionBackend, + attn_backend=FlashDiffkvAttentionBackend, ) if self.param_sink_number > 0: @@ -904,13 +893,6 @@ class OpenPanguSinkAttention(nn.Module): q = q.view(-1, self.q_size) k = k.view(-1, self.k_size) - param_sink_key = self.param_sink_key - if ( - self.param_sink_number > 0 - and hasattr(self, "k_layernorm") - and self.k_layernorm is not None - ): - param_sink_key = self.k_layernorm(param_sink_key) attn_output = self.attn( q, @@ -919,23 +901,14 @@ class OpenPanguSinkAttention(nn.Module): output_shape=torch.Size( [q.shape[0], q.shape[1] // self.head_dim * self.v_channels] ), - **( - dict( - sink_key=param_sink_key, - sink_value=self.param_sink_value, - ) - if self.param_sink_number > 0 - else {} - ), ) - attn_output = attn_output.reshape(-1, self.num_heads * self.v_channels) output, _ = self.o_proj(attn_output) return output def _init_rotary_emb( self, config: PretrainedConfig, - rope_scaling: dict[str, Any] | None, + rope_parameters: dict[str, Any] | None, quant_config: QuantizationConfig | None, ) -> None: is_neox_style = False @@ -944,11 +917,24 @@ class OpenPanguSinkAttention(nn.Module): self.head_dim, rotary_dim=self.qk_rope_dim, max_position=self.max_position_embeddings, - base=self.rope_theta, - rope_scaling=rope_scaling, + 
rope_parameters=rope_parameters, is_neox_style=is_neox_style, ) + def get_sink_kv(self) -> dict[str, torch.Tensor]: + if self.param_sink_number == 0: + raise ValueError("No sink_key and sink_value when param_sink_number == 0") + + if hasattr(self, "k_layernorm") and self.k_layernorm is not None: + param_sink_key = self.k_layernorm(self.param_sink_key) + else: + param_sink_key = self.param_sink_key + + return { + "sink_key": param_sink_key, + "sink_value": self.param_sink_value, + } + class OpenPanguDecoderLayer(nn.Module): def __init__( @@ -1011,15 +997,20 @@ class OpenPanguDecoderLayer(nn.Module): f"is_causal={config.is_causal} is not support " "for attention with sink" ) - self.self_attn = OpenPanguSinkAttention( + rope_parameters = getattr(config, "rope_scaling", None) + if rope_parameters is None: + rope_parameters = { + "rope_type": "default", + "rope_theta": config.rope_theta, + } + self.self_attn = OpenPanguDiffkvAttention( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=getattr( config, "num_key_value_heads", config.num_attention_heads ), - rope_theta=rope_theta, - rope_scaling=getattr(config, "rope_scaling", None), + rope_parameters=rope_parameters, max_position_embeddings=max_position_embeddings, quant_config=quant_config, bias=attention_bias, diff --git a/vllm/v1/attention/backends/flash_sink_attn.py b/vllm/v1/attention/backends/flash_diffkv_attn.py similarity index 83% rename from vllm/v1/attention/backends/flash_sink_attn.py rename to vllm/v1/attention/backends/flash_diffkv_attn.py index e532a69dcb4a9..acd9cbcb4cabf 100644 --- a/vllm/v1/attention/backends/flash_sink_attn.py +++ b/vllm/v1/attention/backends/flash_diffkv_attn.py @@ -16,6 +16,7 @@ from vllm.attention.backends.abstract import ( is_quantized_kv_cache, ) from vllm.attention.layer import Attention +from vllm.attention.ops.common import cp_lse_ag_out_rs from vllm.attention.ops.merge_attn_states import merge_attn_states from 
vllm.attention.ops.triton_reshape_and_cache_flash import ( triton_reshape_and_cache_flash_diffkv, @@ -32,8 +33,9 @@ if is_flash_attn_varlen_func_available(): flash_attn_varlen_func, get_scheduler_metadata, ) -from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config from vllm.config.cache import CacheDType +from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, @@ -54,24 +56,39 @@ from .flash_attn import FlashAttentionMetadata logger = init_logger(__name__) -class FlashSinkAttentionBackend(AttentionBackend): +class FlashDiffkvAttentionBackend(AttentionBackend): accept_output_buffer: bool = True supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] - # NOTE(tdoublep): while in principle, FA supports - # MultipleOf(16), these are the block sizes that do not - # suffer from the NaN propagation problem described here: - # https://github.com/Dao-AILab/flash-attention/issues/1974 - supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64] # TODO: Remove hard code - cache_head_size_ratio: float = 2.0 + head_size_v: int = 128 + + @staticmethod + def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: + vllm_config = get_current_vllm_config() + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + if ( + model_config + and model_config.is_hybrid + and ( + cache_config.mamba_ssm_cache_dtype == "float32" + or cache_config.mamba_cache_dtype == "float32" + ) + ): + # NOTE(tdoublep): while in principle, FA supports + # MultipleOf(16), these are the block sizes that do not + # suffer from the NaN propagation problem described here: + # https://github.com/Dao-AILab/flash-attention/issues/1974 + return [16, 32, 64] + return [MultipleOf(16)] @staticmethod def get_name() -> str: - 
return "FLASH_SINK_ATTN" + return "FLASH_DIFFKV_ATTN" @classmethod def supports_attn_type(cls, attn_type: str) -> bool: - """FlashSinkAttention supports all attention types.""" + """FlashDiffkvAttention supports all attention types.""" from vllm.attention import AttentionType return attn_type in ( @@ -82,16 +99,16 @@ class FlashSinkAttentionBackend(AttentionBackend): ) @staticmethod - def get_impl_cls() -> type["FlashSinkAttentionImpl"]: - return FlashSinkAttentionImpl + def get_impl_cls() -> type["FlashDiffkvAttentionImpl"]: + return FlashDiffkvAttentionImpl @staticmethod - def get_builder_cls() -> type["FlashSinkAttentionMetadataBuilder"]: - return FlashSinkAttentionMetadataBuilder + def get_builder_cls() -> type["FlashDiffkvAttentionMetadataBuilder"]: + return FlashDiffkvAttentionMetadataBuilder @classmethod - def set_cache_head_size_ratio(cls, ratio: float) -> None: - cls.cache_head_size_ratio = ratio + def set_head_size_v(cls, head_size_v: int) -> None: + cls.head_size_v = head_size_v @staticmethod def get_kv_cache_shape( @@ -107,16 +124,24 @@ class FlashSinkAttentionBackend(AttentionBackend): num_blocks, block_size, num_kv_heads, - int(head_size * FlashSinkAttentionBackend.cache_head_size_ratio), + head_size + FlashDiffkvAttentionBackend.head_size_v, ) @staticmethod - def get_kv_cache_stride_order() -> tuple[int, ...]: + def get_kv_cache_stride_order( + include_num_layers_dimension: bool = False, + ) -> tuple[int, ...]: # `stride_order` indicates the permutation that gets # us from `get_kv_cache_shape` to the actual memory layout we want. 
cache_layout = get_kv_cache_layout() - if cache_layout == "NHD": + if cache_layout == "NHD" and include_num_layers_dimension: + # (num_blocks, num_layers, block_size, num_kv_heads, head_size) + return (0, 1, 2, 3, 4) + elif cache_layout == "NHD": stride_order = (0, 1, 2, 3) + elif cache_layout == "HND" and include_num_layers_dimension: + # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size) + return (2, 3, 0, 1, 4) elif cache_layout == "HND": stride_order = (0, 2, 1, 3) else: @@ -131,8 +156,8 @@ class FlashSinkAttentionBackend(AttentionBackend): raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") @classmethod - def get_supported_head_sizes(cls) -> list[int]: - return [32, 64, 96, 128, 160, 192, 224, 256] + def supports_head_size(cls, head_size: int) -> bool: + return head_size % 8 == 0 and head_size <= 256 @classmethod def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: @@ -176,12 +201,12 @@ def _get_sliding_window_configs( sliding_window_configs: set[tuple[int, int] | None] = set() layers = get_layers_from_vllm_config(vllm_config, Attention) for layer in layers.values(): - assert isinstance(layer.impl, FlashSinkAttentionImpl) + assert isinstance(layer.impl, FlashDiffkvAttentionImpl) sliding_window_configs.add(layer.impl.sliding_window) return sliding_window_configs -class FlashSinkAttentionMetadataBuilder( +class FlashDiffkvAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata] ): # FA3: @@ -242,8 +267,8 @@ class FlashSinkAttentionMetadataBuilder( self.dcp_world_size = 1 self.dcp_rank = 0 - self.dcp_kv_cache_interleave_size = ( - self.parallel_config.dcp_kv_cache_interleave_size + self.cp_kv_cache_interleave_size = ( + self.parallel_config.cp_kv_cache_interleave_size ) self.use_full_cuda_graph = ( @@ -322,7 +347,7 @@ class FlashSinkAttentionMetadataBuilder( ): cache_dtype = self.cache_config.cache_dtype if cache_dtype.startswith("fp8"): - qkv_dtype = 
FlashSinkAttentionBackend.get_fp8_dtype_for_flashattn( + qkv_dtype = FlashDiffkvAttentionBackend.get_fp8_dtype_for_flashattn( cache_dtype ) else: @@ -365,7 +390,7 @@ class FlashSinkAttentionMetadataBuilder( dcp_context_kv_lens_cpu, self.dcp_world_size, self.dcp_rank, - self.dcp_kv_cache_interleave_size, + self.cp_kv_cache_interleave_size, ) dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device) max_dcp_context_kv_len = dcp_context_kv_lens.max().item() @@ -450,7 +475,7 @@ class FlashSinkAttentionMetadataBuilder( return use_cascade_attention(*args, **kwargs) -class FlashSinkAttentionImpl(AttentionImpl): +class FlashDiffkvAttentionImpl(AttentionImpl): can_return_lse_for_decode: bool = True def __init__( @@ -466,11 +491,9 @@ class FlashSinkAttentionImpl(AttentionImpl): attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: str | None = None, sinks: torch.Tensor | None = None, - head_size_v: int | None = None, ) -> None: self.num_heads = num_heads self.head_size = head_size - self.head_size_v = head_size_v self.scale = float(scale) self.num_kv_heads = num_kv_heads if alibi_slopes is not None: @@ -512,7 +535,7 @@ class FlashSinkAttentionImpl(AttentionImpl): ) def supports_quant_query_input(self) -> bool: - return False + return True def forward( self, @@ -525,10 +548,8 @@ class FlashSinkAttentionImpl(AttentionImpl): output: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, - sink_key: torch.Tensor | None = None, - sink_value: torch.Tensor | None = None, ) -> torch.Tensor: - """Forward pass with FlashSinkAttention. + """Forward pass with FlashAttention. Args: query: shape = [num_tokens, num_heads, head_size] @@ -537,8 +558,6 @@ class FlashSinkAttentionImpl(AttentionImpl): kv_cache: shape = [num_blocks, block_size, num_kv_heads, head_size + head_size_v] attn_metadata: Metadata for attention. 
- sink_key: shape = [sink_len, num_kv_heads, head_size] - sink_value: shape = [sink_len, num_kv_heads, head_size_v] Returns: shape = [num_tokens, num_heads * head_size_v] NOTE: FP8 quantization, flash-attn expect the size of @@ -546,14 +565,11 @@ class FlashSinkAttentionImpl(AttentionImpl): We use torch's .expand() to avoid duplicating values """ assert output is not None, "Output tensor must be provided." - assert sink_key is not None and sink_value is not None, ( - "sink_key and sink_value must be provided" - ) if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not supported yet " - "for FlashSinkAttentionImpl" + "fused output quantization is not yet supported for" + "FlashDiffkvAttentionImpl" ) if attn_metadata is None: @@ -572,7 +588,6 @@ class FlashSinkAttentionImpl(AttentionImpl): # performance to make sure it does not introduce any overhead. num_actual_tokens = attn_metadata.num_actual_tokens - sink_len = sink_key.shape[0] # Handle encoder attention differently - no KV cache needed if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): @@ -585,11 +600,11 @@ class FlashSinkAttentionImpl(AttentionImpl): output[:num_actual_tokens], attn_metadata, layer, - sink_key, - sink_value, ) # For decoder and cross-attention, use KV cache as before + key_cache = kv_cache[..., : self.head_size] + value_cache = kv_cache[..., self.head_size :] # key and value may be None in the case of cross attention. They are # calculated once based on the output from the encoder and then cached @@ -606,29 +621,6 @@ class FlashSinkAttentionImpl(AttentionImpl): # and value[:num_actual_tokens] because the reshape_and_cache_flash # op uses the slot_mapping's shape to determine the number of # actual tokens. 
- - # store sink_key and sink_value in head blocks - key_cache = kv_cache[..., : self.head_size] - value_cache = kv_cache[..., self.head_size :] - block_size = key_cache.shape[1] - assert sink_len % block_size == 0 - num_sink_blocks = sink_len // block_size - sink_kv_slot_mapping = torch.arange( - block_size, - sink_len + block_size, - device=attn_metadata.slot_mapping.device, - dtype=attn_metadata.slot_mapping.dtype, - ) - triton_reshape_and_cache_flash_diffkv( - sink_key, - sink_value, - kv_cache, - sink_kv_slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - triton_reshape_and_cache_flash_diffkv( key, value, @@ -641,7 +633,7 @@ class FlashSinkAttentionImpl(AttentionImpl): if self.kv_cache_dtype.startswith("fp8"): # queries are quantized in the attention layer - dtype = FlashSinkAttentionBackend.get_fp8_dtype_for_flashattn( + dtype = FlashDiffkvAttentionBackend.get_fp8_dtype_for_flashattn( self.kv_cache_dtype ) key_cache = key_cache.view(dtype) @@ -649,29 +641,28 @@ class FlashSinkAttentionImpl(AttentionImpl): if not attn_metadata.use_cascade: cu_seqlens_q = attn_metadata.query_start_loc - seqused_k = attn_metadata.seq_lens + sink_len + seqused_k = attn_metadata.seq_lens max_seqlen_q = attn_metadata.max_query_len - max_seqlen_k = attn_metadata.max_seq_len + sink_len + max_seqlen_k = attn_metadata.max_seq_len block_table = attn_metadata.block_table scheduler_metadata = attn_metadata.scheduler_metadata - sink_block_table = torch.arange( - 1, - num_sink_blocks + 1, - device=block_table.device, - dtype=block_table.dtype, - ) - sink_block_table = sink_block_table[None, :].expand( - block_table.shape[0], -1 - ) - block_table = torch.cat((sink_block_table, block_table), dim=1) descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads) if self.dcp_world_size > 1: - raise ValueError( - "Decode context parallel is not supported yet " - f"for dcp_world_size = {self.dcp_world_size}" + self._forward_with_dcp( + query[:num_actual_tokens], + 
key[:num_actual_tokens], + value[:num_actual_tokens], + key_cache, + value_cache, + output[:num_actual_tokens], + attn_metadata, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), ) + return output else: flash_attn_varlen_func( q=query[:num_actual_tokens], @@ -724,10 +715,89 @@ class FlashSinkAttentionImpl(AttentionImpl): k_descale=layer._k_scale, v_descale=layer._v_scale, s_aux=self.sinks, - sink_len=sink_len, ) return output + def _forward_with_dcp( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + output: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + q_descale: torch.Tensor | None = None, + k_descale: torch.Tensor | None = None, + v_descale: torch.Tensor | None = None, + ) -> torch.Tensor: + cu_seqlens_q = attn_metadata.query_start_loc + max_seqlen_q = attn_metadata.max_query_len + block_table = attn_metadata.block_table + + query = query.contiguous() + query_across_dcp = get_dcp_group().all_gather(query, dim=1) + context_attn_out, context_lse = flash_attn_varlen_func( + q=query_across_dcp, + k=key_cache, + v=value_cache, + out=None, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=attn_metadata.dcp_context_kv_lens, + max_seqlen_k=attn_metadata.max_dcp_context_kv_len, + softmax_scale=self.scale, + causal=False, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + softcap=self.logits_soft_cap, + return_softmax_lse=True, + scheduler_metadata=attn_metadata.scheduler_metadata, + fa_version=self.vllm_flash_attn_version, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + ) + # FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ] + context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs( + context_attn_out, + context_lse.transpose(0, 1), + get_dcp_group(), + return_lse=True, + ) 
+ context_lse_cor = context_lse_cor.transpose(0, 1).contiguous() + + query_attn_out, query_lse = flash_attn_varlen_func( + q=query, + k=key, + v=value, + out=None, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + cu_seqlens_k=cu_seqlens_q, + max_seqlen_k=max_seqlen_q, + softmax_scale=self.scale, + causal=attn_metadata.causal, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + softcap=self.logits_soft_cap, + return_softmax_lse=True, + fa_version=self.vllm_flash_attn_version, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + ) + assert context_attn_out_cor.shape == query_attn_out.shape + assert context_lse_cor.shape == query_lse.shape + merge_attn_states( + output, + context_attn_out_cor, + context_lse_cor, + query_attn_out, + query_lse, + ) + def _forward_encoder_attention( self, query: torch.Tensor, @@ -736,20 +806,16 @@ class FlashSinkAttentionImpl(AttentionImpl): output: torch.Tensor, attn_metadata: FlashAttentionMetadata, layer: torch.nn.Module, - sink_key: torch.Tensor, - sink_value: torch.Tensor, ) -> torch.Tensor: """Forward pass for encoder attention without KV cache. 
Args: query: shape = [num_encoder_tokens, num_heads, head_size] key: shape = [num_encoder_tokens, num_kv_heads, head_size] - value: shape = [num_encoder_tokens, num_kv_heads, head_size_v] + value: shape = [num_encoder_tokens, num_kv_heads, head_size] output: shape = [num_encoder_tokens, num_heads, head_size] attn_metadata: Encoder attention metadata layer: The attention layer - sink_key: shape = [sink_len, num_kv_heads, head_size] - sink_value: shape = [sink_len, num_kv_heads, head_size_v] """ # For encoder attention, process FP8 quantization if needed if self.kv_cache_dtype.startswith("fp8"): @@ -758,40 +824,10 @@ class FlashSinkAttentionImpl(AttentionImpl): ) # Use encoder-specific metadata for sequence information - sink_len = sink_key.shape[0] - key_list = [] - value_list = [] - for seq_id in range(attn_metadata.block_table.shape[0]): - seq_start = attn_metadata.query_start_loc[seq_id] - seq_end = attn_metadata.query_start_loc[seq_id + 1] - key_list.append( - torch.cat( - [ - sink_key, - key[seq_start:seq_end], - ], - dim=0, - ) - ) - value_list.append( - torch.cat( - [ - sink_value, - value[seq_start:seq_end], - ], - dim=0, - ) - ) - key = torch.cat(key_list, dim=0).contiguous() - value = torch.cat(value_list, dim=0).contiguous() - cu_seqlens_q = attn_metadata.query_start_loc - cu_seqlens_k = attn_metadata.seq_lens + sink_len - cu_seqlens_k = torch.nn.functional.pad( - torch.cumsum(cu_seqlens_k, dim=-1), [1, 0], value=0 - ).to(torch.int32) + cu_seqlens_k = attn_metadata.query_start_loc max_seqlen_q = attn_metadata.max_query_len - max_seqlen_k = attn_metadata.max_seq_len + sink_len + max_seqlen_k = attn_metadata.max_query_len descale_shape = ( cu_seqlens_q.shape[0] - 1, # type: ignore[union-attr] @@ -926,14 +962,12 @@ def cascade_attention( k_descale: torch.Tensor | None = None, v_descale: torch.Tensor | None = None, s_aux: torch.Tensor | None = None, - sink_len: int | None = None, ) -> torch.Tensor: assert alibi_slopes is None, "Cascade attention does not 
support ALiBi." # TODO: Support sliding window. assert sliding_window == (-1, -1), ( "Cascade attention does not support sliding window." ) - assert sink_len is not None, "sink_len must be provided." num_tokens = query.shape[0] block_size = key_cache.shape[-3] @@ -942,22 +976,15 @@ def cascade_attention( assert num_common_kv_blocks > 0 descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2]) - num_sink_blocks = sink_len // block_size - sink_block_table = torch.arange( - 1, num_sink_blocks + 1, device=block_table.device, dtype=block_table.dtype - ) - sink_block_table = sink_block_table[None, :].expand(block_table.shape[0], -1) - block_table = torch.cat((sink_block_table, block_table), dim=1) - # Process shared prefix. prefix_output, prefix_lse = flash_attn_varlen_func( q=query, k=key_cache, v=value_cache, cu_seqlens_q=cu_prefix_query_lens, - seqused_k=prefix_kv_lens + sink_len, + seqused_k=prefix_kv_lens, max_seqlen_q=num_tokens, - max_seqlen_k=common_prefix_len + sink_len, + max_seqlen_k=common_prefix_len, softmax_scale=softmax_scale, causal=False, window_size=sliding_window, @@ -989,7 +1016,7 @@ def cascade_attention( softmax_scale=softmax_scale, causal=True, window_size=sliding_window, - block_table=block_table[:, num_sink_blocks + num_common_kv_blocks :], + block_table=block_table[:, num_common_kv_blocks:], softcap=logits_soft_cap, return_softmax_lse=True, scheduler_metadata=suffix_scheduler_metadata, diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index ee5ae21d02843..6267ac0e71f7f 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -12,7 +12,7 @@ from vllm.v1.kv_cache_interface import ( ChunkedLocalAttentionSpec, CrossAttentionSpec, FullAttentionSpec, - FullSinkAttentionSpec, + FullDiffkvAttentionSpec, KVCacheSpec, MambaSpec, MLAAttentionSpec, @@ -311,7 +311,7 @@ class FullAttentionManager(SingleTypeKVCacheManager): ) -> 
tuple[list[KVCacheBlock], ...]: assert isinstance( kv_cache_spec, - (FullAttentionSpec, FullSinkAttentionSpec, ChunkedLocalAttentionSpec), + (FullAttentionSpec, FullDiffkvAttentionSpec, ChunkedLocalAttentionSpec), ), ( "FullAttentionManager can only be used for full attention " "and chunked local attention groups" @@ -733,7 +733,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager): spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, - FullSinkAttentionSpec: FullAttentionManager, + FullDiffkvAttentionSpec: FullAttentionManager, MLAAttentionSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager, ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index aa3ca82a5d4a3..1b130300b2218 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -159,7 +159,7 @@ class FullAttentionSpec(AttentionSpec): @dataclass(frozen=True) -class FullSinkAttentionSpec(AttentionSpec): +class FullDiffkvAttentionSpec(AttentionSpec): head_size_v: int sliding_window: int | None = None attention_chunk_size: int | None = None @@ -170,7 +170,7 @@ class FullSinkAttentionSpec(AttentionSpec): window attention are regarded as full attention in KV cache manager (blocks are allocated for all tokens), while computed as sliding window attention in model runner. - In this case, we use FullSinkAttentionSpec and record the sliding window size. + In this case, we use FullDiffkvAttentionSpec and record the sliding window size. Default to None for not using sliding window attention. """ @@ -198,12 +198,12 @@ class FullSinkAttentionSpec(AttentionSpec): @classmethod def merge(cls, specs: list[Self]) -> Self: """ - Merge a list of FullSinkAttentionSpec objects into a single - FullSinkAttentionSpec object. + Merge a list of FullDiffkvAttentionSpec objects into a single + FullDiffkvAttentionSpec object. 
""" - assert all(isinstance(spec, FullSinkAttentionSpec) for spec in specs), ( + assert all(isinstance(spec, FullDiffkvAttentionSpec) for spec in specs), ( "All attention layers in the same KV cache group must be " - "FullSinkAttentionSpec." + "FullDiffkvAttentionSpec." ) sliding_window = set( diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index e7991baeaa1b8..5ec918654677c 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -90,6 +90,7 @@ class InputBatch: is_pooling_model: bool = False, num_speculative_tokens: int = 0, cp_kv_cache_interleave_size: int = 1, + sink_len: int = 0, ): self.is_pooling_model = is_pooling_model self.is_spec_decode = is_spec_decode @@ -136,7 +137,7 @@ class InputBatch: # Block table. self.block_table = MultiGroupBlockTable( max_num_reqs=max_num_reqs, - max_model_len=max_model_len, + max_model_len=max_model_len + sink_len, max_num_batched_tokens=max_num_batched_tokens, pin_memory=pin_memory, device=device, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0ce6c4a3204b0..c24e0561215ad 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -25,6 +25,9 @@ from vllm.attention.backends.abstract import ( AttentionMetadata, MultipleOf, ) +from vllm.attention.ops.triton_reshape_and_cache_flash import ( + triton_reshape_and_cache_flash_diffkv, +) from vllm.compilation.counter import compilation_counter from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.compilation.monitor import set_cudagraph_capturing_enabled @@ -321,6 +324,10 @@ class GPUModelRunner( self.num_query_heads = model_config.get_num_attention_heads(parallel_config) self.hidden_size = model_config.get_hidden_size() self.attention_chunk_size = model_config.attention_chunk_size + self.sink_len = getattr( + self.vllm_config.model_config.hf_config, "param_sink_number", 0 + ) + assert self.sink_len % 
self.cache_config.block_size == 0 # Only relevant for models using ALiBi (e.g, MPT) self.use_alibi = model_config.uses_alibi @@ -443,6 +450,7 @@ class GPUModelRunner( logitsprocs_need_output_token_ids=bool(custom_logitsprocs), is_pooling_model=self.is_pooling_model, cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size, + sink_len=self.sink_len, ) self.use_async_scheduling = self.scheduler_config.async_scheduling @@ -1590,6 +1598,28 @@ class GPUModelRunner( # graph mode. blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1) + # Modify the blk_table_tensor and seq_lens in-place so that attention will + # know there are sink_key and sink_value in kv_caches + if self.sink_len > 0: + seq_lens[:] = seq_lens + self.sink_len + seq_lens_cpu[:] = seq_lens_cpu + self.sink_len + max_seq_len = max_seq_len + self.sink_len + sink_block_table = torch.arange( + 1, + self.sink_len // self.cache_config.block_size + 1, + device=blk_table_tensor.device, + dtype=blk_table_tensor.dtype, + ) + sink_block_table = sink_block_table[None, :].expand( + blk_table_tensor.shape[0], -1 + ) + num_sink_blocks = sink_block_table.shape[1] + blk_table_tensor_clone = blk_table_tensor.clone() + blk_table_tensor[:, num_sink_blocks:] = blk_table_tensor_clone[ + :, :-num_sink_blocks + ] + blk_table_tensor[:, :num_sink_blocks] = sink_block_table + common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, query_start_loc_cpu=query_start_loc_cpu, @@ -1624,6 +1654,8 @@ class GPUModelRunner( if cascade_attn_prefix_lens else 0 ) + if self.sink_len > 0: + cascade_attn_prefix_len = cascade_attn_prefix_len + self.sink_len builder = attn_group.get_metadata_builder() extra_attn_metadata_args = {} @@ -4838,6 +4870,7 @@ class GPUModelRunner( logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids, is_pooling_model=self.is_pooling_model, num_speculative_tokens=self.num_spec_tokens, + sink_len=self.sink_len, ) def 
_allocate_kv_cache_tensors( @@ -5165,6 +5198,7 @@ class GPUModelRunner( kv_caches = self.initialize_kv_cache_tensors( kv_cache_config, kernel_block_sizes ) + self.prepare_sink_kv_cache(kv_caches) if self.speculative_config and self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) @@ -5269,3 +5303,36 @@ class GPUModelRunner( self.transfer_event.record() self.transfer_event.synchronize() return pinned.tolist() + + def prepare_sink_kv_cache(self, kv_caches) -> None: + if self.sink_len == 0: + return + + def find_module_by_name(model, target_name: str): + for name, module in model.named_modules(): + if name == target_name: + return module + raise KeyError(f"Module '{target_name}' not found") + + for layer_name, kv_cache in kv_caches.item(): + layer_prefix = layer_name.rsplit(".", 1)[0] + self_attn_module = find_module_by_name(self.model, layer_prefix) + if not hasattr(self_attn_module, "get_sink_kv"): + continue + else: + sink_kv = self_attn_module.get_sink_kv() + sink_kv_slot_mapping = torch.arange( + self.vllm_config.cache_config.block_size, + self.sink_len + self.vllm_config.cache_config.block_size, + device=torch.cuda.current_device(), + dtype=torch.long, + ) + triton_reshape_and_cache_flash_diffkv( + sink_kv["sink_key"], + sink_kv["sink_value"], + kv_cache, + sink_kv_slot_mapping, + self_attn_module.attn.kv_cache_dtype, + self_attn_module.attn._k_scale, + self_attn_module.attn._v_scale, + ) From ca7eaa045caa5f35aac0b0c51b9118edec182e68 Mon Sep 17 00:00:00 2001 From: yuantao <2422264527@qq.com> Date: Wed, 26 Nov 2025 11:57:11 +0800 Subject: [PATCH 05/16] fix pre-commit Signed-off-by: yuantao <2422264527@qq.com> --- vllm/v1/core/single_type_kv_cache_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 38da13a83f3ac..8444ee5ef425f 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ 
b/vllm/v1/core/single_type_kv_cache_manager.py @@ -318,7 +318,7 @@ class FullAttentionManager(SingleTypeKVCacheManager): ) -> tuple[list[KVCacheBlock], ...]: assert isinstance( kv_cache_spec, - FullAttentionSpec | ChunkedLocalAttentionSpec | FullDiffkvAttentionSpec + FullAttentionSpec | ChunkedLocalAttentionSpec | FullDiffkvAttentionSpec, ), ( "FullAttentionManager can only be used for full attention " "and chunked local attention groups" From 0abdaad39b9e84182225a55e99f9e4ba88abfca5 Mon Sep 17 00:00:00 2001 From: yuantao <2422264527@qq.com> Date: Wed, 26 Nov 2025 14:40:57 +0800 Subject: [PATCH 06/16] fix typo Signed-off-by: yuantao <2422264527@qq.com> --- vllm/v1/worker/gpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 704dad7068ea6..e02793c5a8d25 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5319,7 +5319,7 @@ class GPUModelRunner( return module raise KeyError(f"Module '{target_name}' not found") - for layer_name, kv_cache in kv_caches.item(): + for layer_name, kv_cache in kv_caches.items(): layer_prefix = layer_name.rsplit(".", 1)[0] self_attn_module = find_module_by_name(self.model, layer_prefix) if not hasattr(self_attn_module, "get_sink_kv"): From b565203d926267a48266fd1a60bac80553bb80ff Mon Sep 17 00:00:00 2001 From: yuantao <2422264527@qq.com> Date: Sat, 13 Dec 2025 15:47:33 +0800 Subject: [PATCH 07/16] Refacotr code. 
Extent FLASH_ATTN to support different KV size and create a new StaticSinkAttention for sink token logics Signed-off-by: yuantao <2422264527@qq.com> --- vllm/attention/backends/registry.py | 3 - vllm/attention/layer.py | 11 +- .../attention/layers/static_sink_attention.py | 225 ++++ .../ops/triton_reshape_and_cache_flash.py | 90 +- vllm/model_executor/models/openpangu.py | 169 +-- vllm/v1/attention/backends/flash_attn.py | 81 +- .../attention/backends/flash_diffkv_attn.py | 1031 ----------------- vllm/v1/core/sched/scheduler.py | 5 - vllm/v1/core/single_type_kv_cache_manager.py | 28 +- vllm/v1/kv_cache_interface.py | 137 +-- vllm/v1/worker/block_table.py | 3 +- vllm/v1/worker/gpu/attn_utils.py | 11 +- vllm/v1/worker/gpu_input_batch.py | 3 +- vllm/v1/worker/gpu_model_runner.py | 48 +- .../worker/kv_connector_model_runner_mixin.py | 20 +- 15 files changed, 503 insertions(+), 1362 deletions(-) create mode 100644 vllm/attention/layers/static_sink_attention.py delete mode 100644 vllm/v1/attention/backends/flash_diffkv_attn.py diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index 630858fc2193a..eaa0fa1d5db39 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -42,9 +42,6 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta): """ FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" - FLASH_DIFFKV_ATTN = ( - "vllm.v1.attention.backends.flash_diffkv_attn.FlashDiffkvAttentionBackend" - ) TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 5c43dae35b812..6c4cc0085432c 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -191,6 +191,7 @@ class Attention(nn.Module, AttentionLayerBase): attn_type: 
str = AttentionType.DECODER, kv_sharing_target_layer_name: str | None = None, attn_backend: type[AttentionBackend] | None = None, + head_size_v: int | None = None, **extra_impl_args, ) -> None: """ @@ -232,6 +233,7 @@ class Attention(nn.Module, AttentionLayerBase): self.num_heads = num_heads self.head_size = head_size + self.head_size_v = self.head_size if head_size_v is None else head_size_v self.num_kv_heads = num_kv_heads self.sliding_window = sliding_window self.has_sink = extra_impl_args.get("sinks") is not None @@ -370,6 +372,10 @@ class Attention(nn.Module, AttentionLayerBase): query, _ = self.query_quant(query, self._q_scale) if self.use_output: + if output_shape is None: + output_shape = torch.Size( + (*query.shape[:-1], self.num_heads * self.head_size_v) + ) output_shape = output_shape if output_shape is not None else query.shape output = torch.empty(output_shape, dtype=output_dtype, device=query.device) hidden_size = output_shape[-1] @@ -377,11 +383,11 @@ class Attention(nn.Module, AttentionLayerBase): # NOTE(woosuk): We do this outside the custom op to minimize the # CPU overheads from the non-CUDA-graph regions. 
query = query.view(-1, self.num_heads, self.head_size) - output = output.view(-1, self.num_heads, self.head_size) + output = output.view(-1, self.num_heads, self.head_size_v) if key is not None: key = key.view(-1, self.num_kv_heads, self.head_size) if value is not None: - value = value.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size_v) if self.use_direct_call: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -456,6 +462,7 @@ class Attention(nn.Module, AttentionLayerBase): block_size=block_size, num_kv_heads=self.num_kv_heads, head_size=self.head_size, + head_size_v=self.head_size_v, dtype=self.kv_cache_torch_dtype, ) diff --git a/vllm/attention/layers/static_sink_attention.py b/vllm/attention/layers/static_sink_attention.py new file mode 100644 index 0000000000000..7687651ee682b --- /dev/null +++ b/vllm/attention/layers/static_sink_attention.py @@ -0,0 +1,225 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools + +import torch + +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionMetadata, + AttentionType, +) +from vllm.attention.layer import Attention +from vllm.attention.ops.triton_reshape_and_cache_flash import ( + triton_reshape_and_cache_flash_diffkv, +) +from vllm.attention.selector import get_attn_backend +from vllm.config import CacheConfig, VllmConfig +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.logger import init_logger +from vllm.utils.torch_utils import direct_register_custom_op +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, + subclass_attention_backend, +) +from vllm.v1.kv_cache_interface import ( + AttentionSpec, + KVCacheSpec, + SinkFullAttentionSpec, +) + +logger = init_logger(__name__) + + +@functools.lru_cache +def create_static_sink_attention_backend( + 
underlying_attn_backend: type[AttentionBackend], + sink_len: int = 0, +) -> type[AttentionBackend]: + prefix = "StaticSink_" + underlying_builder = underlying_attn_backend.get_builder_cls() + + class StaticSinkAttentionBuilder(underlying_builder): # type: ignore + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.sink_len = sink_len + self.num_sink_blocks = self.sink_len // vllm_config.cache_config.block_size + self.sink_block_table = torch.arange( + 1, + self.num_sink_blocks + 1, + device=device, + dtype=torch.int32, + ) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> AttentionMetadata: + common_attn_metadata.seq_lens[:] = ( + common_attn_metadata.seq_lens + self.sink_len + ) + common_attn_metadata.seq_lens_cpu = ( + common_attn_metadata.seq_lens_cpu + self.sink_len + ) + common_attn_metadata.max_seq_len = ( + common_attn_metadata.max_seq_len + self.sink_len + ) + + blk_table_tensor = common_attn_metadata.block_table_tensor + sink_block_table = self.sink_block_table[None, :].expand( + blk_table_tensor.shape[0], -1 + ) + blk_table_tensor_clone = blk_table_tensor.clone() + blk_table_tensor[:, self.num_sink_blocks :] = blk_table_tensor_clone[ + :, : -self.num_sink_blocks + ] + blk_table_tensor[:, : self.num_sink_blocks] = sink_block_table + + return super().build(common_prefix_len, common_attn_metadata, fast_build) + + attn_backend = subclass_attention_backend( + name_prefix=prefix, + attention_backend_cls=underlying_attn_backend, + builder_cls=StaticSinkAttentionBuilder, + ) + + return attn_backend + + +class StaticSinkAttention(Attention): + """ + Attention with static sink tokens + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + sink_len: int, + attn_backend: type[AttentionBackend] | 
None = None,
+        cache_config: CacheConfig | None = None,
+        **kwargs,
+    ):
+        dtype = torch.get_default_dtype()
+
+        if cache_config is not None:
+            kv_cache_dtype = cache_config.cache_dtype
+            block_size = cache_config.block_size
+        else:
+            kv_cache_dtype = "auto"
+            block_size = 16
+
+        if attn_backend is not None:
+            underlying_attn_backend = attn_backend
+        else:
+            underlying_attn_backend = get_attn_backend(
+                head_size, dtype, kv_cache_dtype, block_size
+            )
+        attn_backend = create_static_sink_attention_backend(
+            underlying_attn_backend,
+            sink_len=sink_len,
+        )
+        super().__init__(
+            num_heads=num_heads,
+            head_size=head_size,
+            scale=scale,
+            cache_config=cache_config,
+            attn_backend=attn_backend,
+            **kwargs,
+        )
+
+        self.sink_len = sink_len
+        self.block_size = block_size
+        self.sink_populated = False
+        self.sink_key = None
+        self.sink_value = None
+
+    def update_sink_kv(self, sink_key, sink_value) -> None:
+        self.sink_key = sink_key
+        self.sink_value = sink_value
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        output_shape: torch.Size | None = None,
+    ) -> torch.Tensor:
+        assert self.sink_key is not None and self.sink_value is not None, (
+            "sink_key and sink_value have not been prepared"
+        )
+        if not self.sink_populated:
+            forward_context: ForwardContext = get_forward_context()
+            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+            torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)
+
+        return super().forward(query, key, value, output_shape)
+
+    def populate_sink_kv(self, self_kv_cache):
+        sink_kv_slot_mapping = torch.arange(
+            self.block_size,
+            self.sink_len + self.block_size,
+            device=torch.cuda.current_device(),
+            dtype=torch.long,
+        )
+        triton_reshape_and_cache_flash_diffkv(
+            self.sink_key,
+            self.sink_value,
+            self_kv_cache,
+            sink_kv_slot_mapping,
+            self.kv_cache_dtype,
+            self._k_scale,
+            self._v_scale,
+        )
+        # We only populate the sink_key and sink_value once
+        self.sink_populated = True
+ + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: + # Block size may get updated after model loading, refresh it + block_size = vllm_config.cache_config.block_size + # Should not be called for enc-dec or encoder-only attention. + assert self.attn_type == AttentionType.DECODER + + return SinkFullAttentionSpec( + block_size=block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + head_size_v=self.head_size_v, + sink_len=self.sink_len, + dtype=self.kv_cache_torch_dtype, + ) + + +def maybe_populate_sink( + self_kv_cache: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + if self.sink_populated or self_kv_cache.numel() == 0: + return + self.populate_sink_kv(self_kv_cache) + + +def maybe_populate_sink_fake( + self_kv_cache: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="maybe_populate_sink", + op_func=maybe_populate_sink, + mutates_args=["self_kv_cache"], + fake_impl=maybe_populate_sink_fake, +) diff --git a/vllm/attention/ops/triton_reshape_and_cache_flash.py b/vllm/attention/ops/triton_reshape_and_cache_flash.py index d79e209e303b0..c119033896ec6 100644 --- a/vllm/attention/ops/triton_reshape_and_cache_flash.py +++ b/vllm/attention/ops/triton_reshape_and_cache_flash.py @@ -186,17 +186,20 @@ def triton_reshape_and_cache_flash( @triton.jit def reshape_and_cache_kernel_flash_diffkv( - kv_ptr, # [num_tokens, num_heads, head_size + head_size_v] + key_ptr, # [num_tokens, num_heads, head_size] + value_ptr, # [num_tokens, num_heads, head_size_v] kv_cache_ptr, # [num_blocks, block_size, num_heads, head_size + head_size_v] slot_mapping_ptr, # [num_tokens] k_scale, # float32 v_scale, # float32 # strides - kv_stride: tl.int64, + key_stride: tl.int64, + value_stride: tl.int64, block_stride: tl.int64, page_stride: tl.int64, num_heads: tl.constexpr, - head_size_kv: tl.constexpr, + 
head_size_k: tl.constexpr,
+    head_size_v: tl.constexpr,
     block_size: tl.constexpr,
     # FP8 flags
     FP8_KV_CACHE: tl.constexpr,
@@ -211,24 +214,51 @@ def reshape_and_cache_kernel_flash_diffkv(
     tile_i = tl.program_id(axis=1)
     tile_offs = tl.arange(0, TILE_SIZE)
-    tile_pos = tile_i * TILE_SIZE + tile_offs
     block_idx = slot_idx // block_size
     block_offset = slot_idx % block_size
-    src_kv_idx = token_idx * kv_stride
+    src_key_idx = token_idx * key_stride + tile_i * head_size_k
+    src_value_idx = token_idx * value_stride + tile_i * head_size_v
-    tgt_idx = block_idx * block_stride + block_offset * page_stride
-
-    # [TILE_SIZE]
-    kv_tile = tl.load(
-        kv_ptr + src_kv_idx + tile_pos, mask=tile_pos < (num_heads * head_size_kv)
+    tgt_idx = (
+        block_idx * block_stride
+        + block_offset * page_stride
+        + tile_i * (head_size_k + head_size_v)
     )
+    # [TILE_SIZE]
+    key_load = tl.load(key_ptr + src_key_idx + tile_offs, mask=tile_offs < head_size_k)
+    if FP8_KV_CACHE:
+        # tl.store will do the correct implicit cast to fp8,
+        # based on the key_cache_ptr.dtype.element_ty
+        key_tile = key_load if key_load.dtype.is_fp8() else key_load / tl.load(k_scale)
+    else:
+        key_tile = key_load
+
+    # [TILE_SIZE]
+    value_load = tl.load(
+        value_ptr + src_value_idx + tile_offs, mask=tile_offs < head_size_v
+    )
+    if FP8_KV_CACHE:
+        if value_load.dtype.is_fp8():
+            value_tile = value_load
+        else:
+            # tl.store will do the correct implicit cast to fp8,
+            # based on the value_cache_ptr.dtype.element_ty
+            value_tile = value_load / tl.load(v_scale)
+    else:
+        value_tile = value_load
+
     tl.store(
-        kv_cache_ptr + tgt_idx + tile_pos,
-        kv_tile,
-        mask=tile_pos < (num_heads * head_size_kv),
+        kv_cache_ptr + tgt_idx + tile_offs,
+        key_tile,
+        mask=tile_offs < head_size_k,
+    )
+    tl.store(
+        kv_cache_ptr + tgt_idx + head_size_k + tile_offs,
+        value_tile,
+        mask=tile_offs < head_size_v,
     )
     return
@@ -243,13 +273,13 @@ def triton_reshape_and_cache_flash_diffkv(
     k_scale: torch.Tensor,  # float32
     v_scale: torch.Tensor,  # float32
 ):
-    
kv = torch.cat([key, value], dim=-1).contiguous() - num_heads = kv.shape[1] - head_size_kv = kv.shape[2] + num_heads = key.shape[1] + head_size_k = key.shape[2] + head_size_v = value.shape[2] block_size = kv_cache.shape[1] - n = num_heads * head_size_kv - kv_stride = kv.stride()[0] + k_stride = key.stride()[0] + v_stride = value.stride()[0] block_stride = kv_cache.stride()[0] page_stride = kv_cache.stride()[1] @@ -272,13 +302,20 @@ def triton_reshape_and_cache_flash_diffkv( ) FP8_KV_CACHE = kv_cache_dtype.startswith("fp8") - assert not FP8_KV_CACHE, ( + assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [ + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.uint8, + torch.float8_e4m3fnuz, + ], ( "unsupported dtype of KV cache tensor, got " - "{kv_cache_torch_dtype}. Supported kv cache dtypes: bfloat16, float16, float32." + "{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, " + "fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz." ) # heuristics instead of autotuning - TILE_SIZE = min(2048, triton.next_power_of_2(n)) + TILE_SIZE = max(head_size_k, head_size_v) + TILE_SIZE = triton.next_power_of_2(TILE_SIZE) if current_platform.is_rocm() or current_platform.is_xpu(): num_stages = 4 num_warps = 8 @@ -292,21 +329,24 @@ def triton_reshape_and_cache_flash_diffkv( # using cudagraphs grid = lambda meta: ( slot_mapping.shape[0], - triton.cdiv(n, meta["TILE_SIZE"]), + num_heads, ) reshape_and_cache_kernel_flash_diffkv[grid]( - kv_ptr=kv, + key_ptr=key, + value_ptr=value, kv_cache_ptr=kv_cache, slot_mapping_ptr=slot_mapping, k_scale=k_scale, v_scale=v_scale, # strides - kv_stride=kv_stride, + key_stride=k_stride, + value_stride=v_stride, block_stride=block_stride, page_stride=page_stride, num_heads=num_heads, - head_size_kv=head_size_kv, + head_size_k=head_size_k, + head_size_v=head_size_v, block_size=block_size, # FP8 flags FP8_KV_CACHE=FP8_KV_CACHE, diff --git a/vllm/model_executor/models/openpangu.py b/vllm/model_executor/models/openpangu.py index 
619981eeccd7c..8e4bb62e137a8 100644 --- a/vllm/model_executor/models/openpangu.py +++ b/vllm/model_executor/models/openpangu.py @@ -29,8 +29,8 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention.backends.abstract import AttentionBackend, AttentionType -from vllm.attention.layer import Attention +from vllm.attention.layer import Attention, AttentionType +from vllm.attention.layers.static_sink_attention import StaticSinkAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ParallelConfig, VllmConfig from vllm.distributed import ( @@ -41,7 +41,6 @@ from vllm.distributed import ( get_tp_group, tensor_model_parallel_all_gather, ) -from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -82,11 +81,7 @@ from vllm.model_executor.models.utils import ( from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import set_default_rope_theta -from vllm.v1.attention.backends.flash_diffkv_attn import FlashDiffkvAttentionBackend -from vllm.v1.kv_cache_interface import ( - FullDiffkvAttentionSpec, - KVCacheSpec, -) +from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend def check_ffn_act_fn(act_fn: str): @@ -96,133 +91,6 @@ def check_ffn_act_fn(act_fn: str): ) -class DiffkvAttention(Attention): - def __init__( - self, - num_heads: int, - head_size: int, - head_size_v: int, - scale: float, - num_kv_heads: int | None = None, - alibi_slopes: list[float] | None = None, - cache_config: CacheConfig | None = None, - quant_config: QuantizationConfig | None = None, - logits_soft_cap: float | None = None, - per_layer_sliding_window: int | None = None, - prefix: str = "", - attn_type: str = 
AttentionType.DECODER, - kv_sharing_target_layer_name: str | None = None, - attn_backend: type[AttentionBackend] | None = None, - **extra_impl_args, - ) -> None: - super().__init__( - num_heads, - head_size, - scale, - num_kv_heads, - alibi_slopes, - cache_config, - quant_config, - logits_soft_cap, - per_layer_sliding_window, - prefix, - attn_type, - kv_sharing_target_layer_name, - attn_backend, - **extra_impl_args, - ) - self.head_size_v = head_size_v - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - output_shape: torch.Size | None = None, - ) -> torch.Tensor: - """ - The KV cache is stored inside this class and is accessed via - `self.kv_cache`. - - Attention metadata (`attn_metadata`) is set using a context manager in - the model runner's `execute_model` method. It is accessed via forward - context using - `vllm.forward_context.get_forward_context().attn_metadata`. - """ - if self.calculate_kv_scales: - torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name) - output_dtype = query.dtype - if self.query_quant is not None: - # quantizing with a simple torch operation enables - # torch.compile to fuse this into previous ops - # which reduces overheads during decoding. - # Otherwise queries are quantized using custom ops - # which causes decoding overheads - assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"} - - # check if query quantization is supported - if self.impl.supports_quant_query_input(): - query, _ = self.query_quant(query, self._q_scale) - - if self.use_output: - output_shape = output_shape if output_shape is not None else query.shape - output = torch.empty(output_shape, dtype=output_dtype, device=query.device) - hidden_size = output_shape[-1] - # Reshape the query, key, and value tensors. - # NOTE(woosuk): We do this outside the custom op to minimize the - # CPU overheads from the non-CUDA-graph regions. 
- query = query.view(-1, self.num_heads, self.head_size) - output = output.view(-1, self.num_heads, self.head_size_v) - if key is not None: - key = key.view(-1, self.num_kv_heads, self.head_size) - if value is not None: - value = value.view(-1, self.num_kv_heads, self.head_size_v) - if self.use_direct_call: - forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[self.layer_name] - self_kv_cache = self.kv_cache[forward_context.virtual_engine] - self.impl.forward( - self, - query, - key, - value, - self_kv_cache, - attn_metadata, - output=output, - ) - else: - torch.ops.vllm.unified_attention_with_output( - query, - key, - value, - output, - self.layer_name, - ) - return output.view(-1, hidden_size) - else: - raise ValueError( - "Unsupport Error, currently only flash_diffkv_attn " - "backend with output buffer is supported" - ) - - def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: - # Block size may get updated after model loading, refresh it - block_size = vllm_config.cache_config.block_size - # Should not be called for enc-dec or encoder-only attention. - assert self.attn_type == AttentionType.DECODER - # Only support for full attention now. 
- assert self.sliding_window is None - return FullDiffkvAttentionSpec( - block_size=block_size, - num_kv_heads=self.num_kv_heads, - head_size=self.head_size, - head_size_v=self.head_size_v, - dtype=self.kv_cache_torch_dtype, - ) - - class OpenPanguMLP(nn.Module): def __init__( self, @@ -673,7 +541,7 @@ class OpenPanguEmbeddedAttention(nn.Module): ) -class OpenPanguDiffkvAttention(nn.Module): +class OpenPanguSinkAttention(nn.Module): def __init__( self, config: PretrainedConfig, @@ -777,19 +645,19 @@ class OpenPanguDiffkvAttention(nn.Module): else: sliding_window = None - FlashDiffkvAttentionBackend.set_head_size_v(self.v_channels) - self.attn = DiffkvAttention( + self.attn = StaticSinkAttention( self.num_heads, self.head_dim, - self.v_channels, self.scaling, + sink_len=self.param_sink_number, num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, per_layer_sliding_window=sliding_window, attn_type=attn_type, prefix=f"{prefix}.attn", - attn_backend=FlashDiffkvAttentionBackend, + attn_backend=FlashAttentionBackend, + head_size_v=self.v_channels, ) if self.param_sink_number > 0: @@ -919,19 +787,13 @@ class OpenPanguDiffkvAttention(nn.Module): is_neox_style=is_neox_style, ) - def get_sink_kv(self) -> dict[str, torch.Tensor]: - if self.param_sink_number == 0: - raise ValueError("No sink_key and sink_value when param_sink_number == 0") - + def post_weight_load(self) -> None: if hasattr(self, "k_layernorm") and self.k_layernorm is not None: param_sink_key = self.k_layernorm(self.param_sink_key) else: param_sink_key = self.param_sink_key - return { - "sink_key": param_sink_key, - "sink_value": self.param_sink_value, - } + self.attn.update_sink_kv(param_sink_key, self.param_sink_value) class OpenPanguDecoderLayer(nn.Module): @@ -1001,7 +863,7 @@ class OpenPanguDecoderLayer(nn.Module): "rope_type": "default", "rope_theta": config.rope_theta, } - self.self_attn = OpenPanguDiffkvAttention( + self.self_attn = OpenPanguSinkAttention( config=config, 
hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -1359,8 +1221,17 @@ class OpenPanguModel(nn.Module): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) + + self.post_weight_load() return loaded_params + def post_weight_load(self) -> None: + for name, module in self.named_modules(): + if module is self: + continue + if hasattr(module, "post_weight_load"): + module.post_weight_load() + class OpenPanguModelBase(nn.Module, SupportsPP, SupportsLoRA): packed_modules_mapping = { diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index f5ad98cf2125c..8b030a04b438d 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -18,6 +18,9 @@ from vllm.attention.backends.abstract import ( from vllm.attention.layer import Attention from vllm.attention.ops.common import cp_lse_ag_out_rs from vllm.attention.ops.merge_attn_states import merge_attn_states +from vllm.attention.ops.triton_reshape_and_cache_flash import ( + triton_reshape_and_cache_flash_diffkv, +) from vllm.attention.utils.fa_utils import ( flash_attn_supports_fp8, get_flash_attn_version, @@ -105,28 +108,48 @@ class FlashAttentionBackend(AttentionBackend): num_kv_heads: int, head_size: int, cache_dtype_str: str = "auto", + head_size_v: int | None = None, ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") - return (2, num_blocks, block_size, num_kv_heads, head_size) + if head_size_v is None or head_size == head_size_v: + return (2, num_blocks, block_size, num_kv_heads, head_size) + else: + return ( + num_blocks, + block_size, + num_kv_heads, + head_size + head_size_v, + ) @staticmethod def get_kv_cache_stride_order( include_num_layers_dimension: bool = False, + diff_kv: bool = False, ) -> tuple[int, ...]: # `stride_order` indicates the permutation that gets # us from 
`get_kv_cache_shape` to the actual memory layout we want. cache_layout = get_kv_cache_layout() if cache_layout == "NHD" and include_num_layers_dimension: - # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size) - return (2, 0, 1, 3, 4, 5) + if not diff_kv: + # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size) + return (2, 0, 1, 3, 4, 5) + else: + # (num_blocks, num_layers, block_size, + # num_kv_heads, head_size + head_size_v) + return (0, 1, 2, 3, 4) elif cache_layout == "NHD": - stride_order = (0, 1, 2, 3, 4) + stride_order = (0, 1, 2, 3, 4) if not diff_kv else (0, 1, 2, 3) elif cache_layout == "HND" and include_num_layers_dimension: - # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size) - return (2, 4, 0, 1, 3, 5) + if not diff_kv: + # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size) + return (2, 4, 0, 1, 3, 5) + else: + # (num_blocks, num_kv_heads, num_layers, + # block_size, head_size + head_size_v) + return (2, 3, 0, 1, 4) elif cache_layout == "HND": - stride_order = (0, 1, 3, 2, 4) + stride_order = (0, 1, 3, 2, 4) if not diff_kv else (0, 2, 1, 3) else: raise ValueError(f"Unknown cache layout format {cache_layout}.") return stride_order @@ -576,11 +599,14 @@ class FlashAttentionImpl(AttentionImpl): query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] + or [num_tokens, num_kv_heads, head_size_v] kv_cache: shape = [2, num_blocks, block_size, num_kv_heads, head_size] + or [num_blocks, block_size, num_kv_heads, head_size + head_size_v] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] + or [num_tokens, num_heads * head_size_v] NOTE: FP8 quantization, flash-attn expect the size of {q,k,v}_descale to be (num_sequences, num_kv_heads). 
We use torch's .expand() to avoid duplicating values @@ -623,7 +649,13 @@ class FlashAttentionImpl(AttentionImpl): ) # For decoder and cross-attention, use KV cache as before - key_cache, value_cache = kv_cache.unbind(0) + if self.head_size == kv_cache.shape[-1]: + # Same head_size for K and V + key_cache, value_cache = kv_cache.unbind(0) + else: + # Different head_size for K and V + key_cache = kv_cache[..., : self.head_size] + value_cache = kv_cache[..., self.head_size :] # key and value may be None in the case of cross attention. They are # calculated once based on the output from the encoder and then cached @@ -640,16 +672,29 @@ class FlashAttentionImpl(AttentionImpl): # and value[:num_actual_tokens] because the reshape_and_cache_flash # op uses the slot_mapping's shape to determine the number of # actual tokens. - reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) + if self.head_size == kv_cache.shape[-1]: + # kv_cache update for same head_size K and V + reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + else: + # kv_cache update for different head_size K and V + triton_reshape_and_cache_flash_diffkv( + key, + value, + kv_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) if self.kv_cache_dtype.startswith("fp8"): # queries are quantized in the attention layer diff --git a/vllm/v1/attention/backends/flash_diffkv_attn.py b/vllm/v1/attention/backends/flash_diffkv_attn.py deleted file mode 100644 index acd9cbcb4cabf..0000000000000 --- a/vllm/v1/attention/backends/flash_diffkv_attn.py +++ /dev/null @@ -1,1031 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Attention layer with FlashAttention.""" - -from typing import ClassVar 
- -import numpy as np -import torch - -from vllm import envs -from vllm.attention.backends.abstract import ( - AttentionBackend, - AttentionImpl, - AttentionType, - MultipleOf, - is_quantized_kv_cache, -) -from vllm.attention.layer import Attention -from vllm.attention.ops.common import cp_lse_ag_out_rs -from vllm.attention.ops.merge_attn_states import merge_attn_states -from vllm.attention.ops.triton_reshape_and_cache_flash import ( - triton_reshape_and_cache_flash_diffkv, -) -from vllm.attention.utils.fa_utils import ( - flash_attn_supports_fp8, - get_flash_attn_version, - is_flash_attn_varlen_func_available, -) - -if is_flash_attn_varlen_func_available(): - from vllm.attention.utils.fa_utils import ( - flash_attn_supports_sinks, - flash_attn_varlen_func, - get_scheduler_metadata, - ) -from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config -from vllm.config.cache import CacheDType -from vllm.distributed.parallel_state import get_dcp_group -from vllm.logger import init_logger -from vllm.model_executor.layers.batch_invariant import ( - vllm_is_batch_invariant, -) -from vllm.platforms.interface import DeviceCapability -from vllm.utils.math_utils import cdiv -from vllm.v1.attention.backends.utils import ( - AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, - get_dcp_local_seq_lens, - get_kv_cache_layout, -) -from vllm.v1.kv_cache_interface import AttentionSpec - -from .flash_attn import FlashAttentionMetadata - -logger = init_logger(__name__) - - -class FlashDiffkvAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True - supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] - # TODO: Remove hard code - head_size_v: int = 128 - - @staticmethod - def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: - vllm_config = get_current_vllm_config() - model_config = vllm_config.model_config - cache_config = vllm_config.cache_config - if ( - model_config - and 
model_config.is_hybrid - and ( - cache_config.mamba_ssm_cache_dtype == "float32" - or cache_config.mamba_cache_dtype == "float32" - ) - ): - # NOTE(tdoublep): while in principle, FA supports - # MultipleOf(16), these are the block sizes that do not - # suffer from the NaN propagation problem described here: - # https://github.com/Dao-AILab/flash-attention/issues/1974 - return [16, 32, 64] - return [MultipleOf(16)] - - @staticmethod - def get_name() -> str: - return "FLASH_DIFFKV_ATTN" - - @classmethod - def supports_attn_type(cls, attn_type: str) -> bool: - """FlashDiffkvAttention supports all attention types.""" - from vllm.attention import AttentionType - - return attn_type in ( - AttentionType.DECODER, - AttentionType.ENCODER, - AttentionType.ENCODER_ONLY, - AttentionType.ENCODER_DECODER, - ) - - @staticmethod - def get_impl_cls() -> type["FlashDiffkvAttentionImpl"]: - return FlashDiffkvAttentionImpl - - @staticmethod - def get_builder_cls() -> type["FlashDiffkvAttentionMetadataBuilder"]: - return FlashDiffkvAttentionMetadataBuilder - - @classmethod - def set_head_size_v(cls, head_size_v: int) -> None: - cls.head_size_v = head_size_v - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - if block_size % 16 != 0: - raise ValueError("Block size must be a multiple of 16.") - return ( - num_blocks, - block_size, - num_kv_heads, - head_size + FlashDiffkvAttentionBackend.head_size_v, - ) - - @staticmethod - def get_kv_cache_stride_order( - include_num_layers_dimension: bool = False, - ) -> tuple[int, ...]: - # `stride_order` indicates the permutation that gets - # us from `get_kv_cache_shape` to the actual memory layout we want. 
- cache_layout = get_kv_cache_layout() - if cache_layout == "NHD" and include_num_layers_dimension: - # (num_blocks, num_layers, block_size, num_kv_heads, head_size) - return (0, 1, 2, 3, 4) - elif cache_layout == "NHD": - stride_order = (0, 1, 2, 3) - elif cache_layout == "HND" and include_num_layers_dimension: - # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size) - return (2, 3, 0, 1, 4) - elif cache_layout == "HND": - stride_order = (0, 2, 1, 3) - else: - raise ValueError(f"Unknown cache layout format {cache_layout}.") - return stride_order - - @staticmethod - def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype: - if kv_cache_dtype in ("fp8", "fp8_e4m3"): - return torch.float8_e4m3fn - else: - raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") - - @classmethod - def supports_head_size(cls, head_size: int) -> bool: - return head_size % 8 == 0 and head_size <= 256 - - @classmethod - def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: - if kv_cache_dtype is None: - return True - if kv_cache_dtype.startswith("fp8"): - return flash_attn_supports_fp8() - return kv_cache_dtype in ["auto"] - - @classmethod - def supports_sink(cls) -> bool: - if not is_flash_attn_varlen_func_available(): - return False - return flash_attn_supports_sinks() - - @classmethod - def supports_compute_capability(cls, capability: DeviceCapability) -> bool: - return capability >= DeviceCapability(8, 0) - - @classmethod - def supports_combination( - cls, - head_size: int, - dtype: torch.dtype, - kv_cache_dtype: CacheDType | None, - block_size: int, - use_mla: bool, - has_sink: bool, - use_sparse: bool, - device_capability: DeviceCapability, - ) -> str | None: - if has_sink and device_capability < DeviceCapability(9, 0): - return "sink not supported on compute capability < 9.0" - return None - - -def _get_sliding_window_configs( - vllm_config: VllmConfig, -) -> set[tuple[int, int] | None]: - """Get the set of all sliding window 
configs used in the model.""" - sliding_window_configs: set[tuple[int, int] | None] = set() - layers = get_layers_from_vllm_config(vllm_config, Attention) - for layer in layers.values(): - assert isinstance(layer.impl, FlashDiffkvAttentionImpl) - sliding_window_configs.add(layer.impl.sliding_window) - return sliding_window_configs - - -class FlashDiffkvAttentionMetadataBuilder( - AttentionMetadataBuilder[FlashAttentionMetadata] -): - # FA3: - # Supports full cudagraphs for all cases. - # - # FA2: - # For FA2, a graph is captured with max_query_len=1, (which is what we - # capture by default for num_tokens <= max_num_seqs when there is no - # spec-decode) then these graphs will not work for mixed prefill-decode - # (unlike FA3). This is due to special max_query_len=1 packed-GQA handling - # in FA2. - # In summary if we are running with spec decodes the graphs would - # work for mixed prefill-decode and uniform-decode. But for non-spec decodes - # the graphs would not work for mixed prefill-decode; sorta the inverse - # of UNIFORM_SINGLE_TOKEN_DECODE. - # There's probably a better way to describe this using `AttentionCGSupport` - # but for now just set it to `UNIFORM_BATCH` to get use to drop down - # to FULL_AND_PIECEWISE. 
- # TODO(luka, lucas): audit FA2 as part of: - # https://github.com/vllm-project/vllm/issues/22945 - _cudagraph_support = ( - AttentionCGSupport.ALWAYS - if get_flash_attn_version() == 3 - else AttentionCGSupport.UNIFORM_BATCH - ) - - def __init__( - self, - kv_cache_spec: AttentionSpec, - layer_names: list[str], - vllm_config: VllmConfig, - device: torch.device, - ): - super().__init__(kv_cache_spec, layer_names, vllm_config, device) - self.model_config = vllm_config.model_config - self.parallel_config = vllm_config.parallel_config - self.cache_config = vllm_config.cache_config - self.compilation_config = vllm_config.compilation_config - - self.num_heads_q = self.model_config.get_num_attention_heads( - self.parallel_config - ) - self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config) - self.kv_cache_dtype = kv_cache_spec.dtype - self.headdim = self.model_config.get_head_size() - self.block_size = kv_cache_spec.block_size - - self.max_num_splits = 0 # No upper bound on the number of splits. - self.aot_schedule = get_flash_attn_version() == 3 - - try: - from vllm.distributed.parallel_state import get_dcp_group - - self.dcp_world_size = get_dcp_group().world_size - self.dcp_rank = get_dcp_group().rank_in_group - except AssertionError: - # DCP might not be initialized in testing - self.dcp_world_size = 1 - self.dcp_rank = 0 - - self.cp_kv_cache_interleave_size = ( - self.parallel_config.cp_kv_cache_interleave_size - ) - - self.use_full_cuda_graph = ( - self.compilation_config.cudagraph_mode.has_full_cudagraphs() - ) - self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size - - if self.use_full_cuda_graph and self.aot_schedule: - self.scheduler_metadata = torch.zeros( - vllm_config.scheduler_config.max_num_seqs + 1, - dtype=torch.int32, - device=self.device, - ) - # When using cuda graph, we need to set the upper bound of the - # number of splits so that large enough intermediate buffers are - # pre-allocated during capture. 
- self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH - - # Sliding window size to be used with the AOT scheduler will be - # populated on first build() call. - self.aot_sliding_window: tuple[int, int] | None = None - - def build( - self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False, - ) -> FlashAttentionMetadata: - """ - fast_build disables AOT scheduling, used when there will be few - iterations i.e. spec-decode - """ - num_reqs = common_attn_metadata.num_reqs - num_actual_tokens = common_attn_metadata.num_actual_tokens - max_query_len = common_attn_metadata.max_query_len - max_seq_len = common_attn_metadata.max_seq_len - query_start_loc = common_attn_metadata.query_start_loc - seq_lens = common_attn_metadata.seq_lens - seq_lens_cpu = common_attn_metadata.seq_lens_cpu - block_table_tensor = common_attn_metadata.block_table_tensor - slot_mapping = common_attn_metadata.slot_mapping - causal = common_attn_metadata.causal - - # the overhead of the aot schedule is not worth it for spec-decode - aot_schedule = self.aot_schedule and not fast_build - - if self.aot_sliding_window is None: - self.aot_sliding_window = (-1, -1) - # For the AOT scheduler we need the sliding window value to be - # constant for all layers to. We have to populate this on the first - # build() call so the layers are constructed (cannot populate) - # in __init__. 
- if aot_schedule: - sliding_window_configs = _get_sliding_window_configs(self.vllm_config) - if len(sliding_window_configs) == 1: - sliding_window_config = sliding_window_configs.pop() - if sliding_window_config is not None: - self.aot_sliding_window = sliding_window_config - elif len(sliding_window_configs) > 1: - self.aot_schedule = False - aot_schedule = False - - max_num_splits = 0 # 0 means use FA3's heuristics, not CG compatible - if self.use_full_cuda_graph and num_actual_tokens <= self.max_cudagraph_size: - # NOTE(woosuk): Setting num_splits > 1 may increase the memory - # usage, because the intermediate buffers of size [num_splits, - # num_heads, num_tokens, head_size] are allocated. Therefore, - # we only set num_splits when using cuda graphs. - max_num_splits = self.max_num_splits - - if vllm_is_batch_invariant(): - max_num_splits = 1 - - def schedule( - batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal - ): - cache_dtype = self.cache_config.cache_dtype - if cache_dtype.startswith("fp8"): - qkv_dtype = FlashDiffkvAttentionBackend.get_fp8_dtype_for_flashattn( - cache_dtype - ) - else: - qkv_dtype = self.kv_cache_dtype - if aot_schedule: - return get_scheduler_metadata( - batch_size=batch_size, - max_seqlen_q=max_query_len, - max_seqlen_k=max_seq_len, - num_heads_q=self.num_heads_q * self.dcp_world_size, - num_heads_kv=self.num_heads_kv, - headdim=self.headdim, - cache_seqlens=seqlens, - qkv_dtype=qkv_dtype, - cu_seqlens_q=cu_query_lens, - page_size=self.block_size, - causal=causal, - window_size=self.aot_sliding_window, - num_splits=max_num_splits, - ) - return None - - use_cascade = common_prefix_len > 0 - max_dcp_context_kv_len = 0 - dcp_context_kv_lens = None - - cu_prefix_query_lens = None - prefix_kv_lens = None - suffix_kv_lens = None - prefix_scheduler_metadata = None - - if self.dcp_world_size > 1: - query_kv_lens_cpu = ( - common_attn_metadata.query_start_loc_cpu[1:] - - common_attn_metadata.query_start_loc_cpu[:-1] - ) - 
dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu - - dcp_context_kv_lens_cpu = get_dcp_local_seq_lens( - dcp_context_kv_lens_cpu, - self.dcp_world_size, - self.dcp_rank, - self.cp_kv_cache_interleave_size, - ) - dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device) - max_dcp_context_kv_len = dcp_context_kv_lens.max().item() - - scheduler_metadata = schedule( - batch_size=num_reqs, - cu_query_lens=query_start_loc, - max_query_len=max_query_len, - seqlens=dcp_context_kv_lens, - max_seq_len=max_dcp_context_kv_len, - causal=False, - ) - elif use_cascade: - cu_prefix_query_lens = torch.tensor( - [0, num_actual_tokens], dtype=torch.int32, device=self.device - ) - prefix_kv_lens = torch.tensor( - [common_prefix_len], dtype=torch.int32, device=self.device - ) - suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to( - self.device, non_blocking=True - ) - prefix_scheduler_metadata = schedule( - batch_size=1, - cu_query_lens=cu_prefix_query_lens, - max_query_len=num_actual_tokens, - seqlens=prefix_kv_lens, - max_seq_len=common_prefix_len, - causal=False, - ) - scheduler_metadata = schedule( - batch_size=num_reqs, - cu_query_lens=query_start_loc, - max_query_len=max_query_len, - seqlens=suffix_kv_lens, - max_seq_len=max_seq_len - common_prefix_len, - causal=True, - ) - else: - scheduler_metadata = schedule( - batch_size=num_reqs, - cu_query_lens=query_start_loc, - max_query_len=max_query_len, - seqlens=seq_lens, - max_seq_len=max_seq_len, - causal=causal, - ) - # For FA3 + full cudagraph - if self.use_full_cuda_graph and scheduler_metadata is not None: - n = scheduler_metadata.shape[0] - self.scheduler_metadata[:n] = scheduler_metadata - # NOTE(woosuk): We should zero out the rest of the scheduler - # metadata to guarantee the correctness. Otherwise, some thread - # blocks may use the invalid scheduler metadata and overwrite the - # output buffer. 
- self.scheduler_metadata[n:] = 0 - scheduler_metadata = self.scheduler_metadata[:n] - - attn_metadata = FlashAttentionMetadata( - num_actual_tokens=num_actual_tokens, - max_query_len=max_query_len, - query_start_loc=query_start_loc, - max_seq_len=max_seq_len, - seq_lens=seq_lens, - block_table=block_table_tensor, - slot_mapping=slot_mapping, - max_dcp_context_kv_len=max_dcp_context_kv_len, - dcp_context_kv_lens=dcp_context_kv_lens, - use_cascade=use_cascade, - common_prefix_len=common_prefix_len, - scheduler_metadata=scheduler_metadata, - cu_prefix_query_lens=cu_prefix_query_lens, - prefix_kv_lens=prefix_kv_lens, - suffix_kv_lens=suffix_kv_lens, - prefix_scheduler_metadata=prefix_scheduler_metadata, - max_num_splits=max_num_splits, - causal=causal, - ) - return attn_metadata - - def use_cascade_attention(self, *args, **kwargs) -> bool: - return use_cascade_attention(*args, **kwargs) - - -class FlashDiffkvAttentionImpl(AttentionImpl): - can_return_lse_for_decode: bool = True - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: list[float] | None, - sliding_window: int | None, - kv_cache_dtype: str, - logits_soft_cap: float | None = None, - attn_type: AttentionType = AttentionType.DECODER, - kv_sharing_target_layer_name: str | None = None, - sinks: torch.Tensor | None = None, - ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - if sliding_window is None: - self.sliding_window = (-1, -1) - elif attn_type == AttentionType.ENCODER_ONLY: - self.sliding_window = (sliding_window - 1, sliding_window - 1) - else: - self.sliding_window = (sliding_window - 1, 0) - self.kv_cache_dtype = kv_cache_dtype - if logits_soft_cap is None: - # In flash-attn, setting logits_soft_cap as 0 means no soft cap. 
- logits_soft_cap = 0 - self.logits_soft_cap = logits_soft_cap - self.kv_sharing_target_layer_name = kv_sharing_target_layer_name - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - self.attn_type = attn_type - self.vllm_flash_attn_version = get_flash_attn_version() - # Cache the batch invariant result for use in forward passes - self.batch_invariant_enabled = vllm_is_batch_invariant() - - if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8(): - raise NotImplementedError( - "FlashAttention does not support fp8 kv-cache on this device." - ) - - self.sinks = sinks - if self.sinks is not None: - assert flash_attn_supports_sinks(), ( - "Sinks are only supported in FlashAttention 3" - ) - assert self.sinks.shape[0] == num_heads, ( - "Sinks must have the same number of heads as the number of " - "heads in the layer" - ) - - def supports_quant_query_input(self) -> bool: - return True - - def forward( - self, - layer: torch.nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: FlashAttentionMetadata, - output: torch.Tensor | None = None, - output_scale: torch.Tensor | None = None, - output_block_scale: torch.Tensor | None = None, - ) -> torch.Tensor: - """Forward pass with FlashAttention. - - Args: - query: shape = [num_tokens, num_heads, head_size] - key: shape = [num_tokens, num_kv_heads, head_size] - value: shape = [num_tokens, num_kv_heads, head_size_v] - kv_cache: shape = - [num_blocks, block_size, num_kv_heads, head_size + head_size_v] - attn_metadata: Metadata for attention. - Returns: - shape = [num_tokens, num_heads * head_size_v] - NOTE: FP8 quantization, flash-attn expect the size of - {q,k,v}_descale to be (num_sequences, num_kv_heads). - We use torch's .expand() to avoid duplicating values - """ - assert output is not None, "Output tensor must be provided." 
- - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported for" - "FlashDiffkvAttentionImpl" - ) - - if attn_metadata is None: - # Profiling run. - return output.fill_(0) - - attn_type = self.attn_type - - # IMPORTANT! - # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in - # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead - # in this method. For example, `view` and `slice` (or `[:n]`) operations - # are surprisingly slow even in the case they do not invoke any GPU ops. - # Minimize the PyTorch ops in this method as much as possible. - # Whenever making a change in this method, please benchmark the - # performance to make sure it does not introduce any overhead. - - num_actual_tokens = attn_metadata.num_actual_tokens - - # Handle encoder attention differently - no KV cache needed - if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): - # For encoder attention, - # we use direct Q, K, V tensors without caching - return self._forward_encoder_attention( - query[:num_actual_tokens], - key[:num_actual_tokens], - value[:num_actual_tokens], - output[:num_actual_tokens], - attn_metadata, - layer, - ) - - # For decoder and cross-attention, use KV cache as before - key_cache = kv_cache[..., : self.head_size] - value_cache = kv_cache[..., self.head_size :] - - # key and value may be None in the case of cross attention. They are - # calculated once based on the output from the encoder and then cached - # in KV cache. - if ( - self.kv_sharing_target_layer_name is None - and key is not None - and value is not None - ): - # Reshape the input keys and values and store them in the cache. - # Skip this if sharing KV cache with an earlier attention layer. - # NOTE(woosuk): Here, key and value are padded while slot_mapping is - # not padded. 
However, we don't need to do key[:num_actual_tokens] - # and value[:num_actual_tokens] because the reshape_and_cache_flash - # op uses the slot_mapping's shape to determine the number of - # actual tokens. - triton_reshape_and_cache_flash_diffkv( - key, - value, - kv_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - - if self.kv_cache_dtype.startswith("fp8"): - # queries are quantized in the attention layer - dtype = FlashDiffkvAttentionBackend.get_fp8_dtype_for_flashattn( - self.kv_cache_dtype - ) - key_cache = key_cache.view(dtype) - value_cache = value_cache.view(dtype) - - if not attn_metadata.use_cascade: - cu_seqlens_q = attn_metadata.query_start_loc - seqused_k = attn_metadata.seq_lens - max_seqlen_q = attn_metadata.max_query_len - max_seqlen_k = attn_metadata.max_seq_len - block_table = attn_metadata.block_table - scheduler_metadata = attn_metadata.scheduler_metadata - - descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads) - - if self.dcp_world_size > 1: - self._forward_with_dcp( - query[:num_actual_tokens], - key[:num_actual_tokens], - value[:num_actual_tokens], - key_cache, - value_cache, - output[:num_actual_tokens], - attn_metadata, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - ) - return output - else: - flash_attn_varlen_func( - q=query[:num_actual_tokens], - k=key_cache, - v=value_cache, - out=output[:num_actual_tokens], - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - seqused_k=seqused_k, - max_seqlen_k=max_seqlen_k, - softmax_scale=self.scale, - causal=attn_metadata.causal, - alibi_slopes=self.alibi_slopes, - window_size=self.sliding_window, - block_table=block_table, - softcap=self.logits_soft_cap, - scheduler_metadata=scheduler_metadata, - fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - 
k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - num_splits=attn_metadata.max_num_splits, - s_aux=self.sinks, - ) - return output - - # Cascade attention (rare case). - cascade_attention( - output[:num_actual_tokens], - query[:num_actual_tokens], - key_cache, - value_cache, - cu_query_lens=attn_metadata.query_start_loc, - max_query_len=attn_metadata.max_query_len, - cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens, - prefix_kv_lens=attn_metadata.prefix_kv_lens, - suffix_kv_lens=attn_metadata.suffix_kv_lens, - max_kv_len=attn_metadata.max_seq_len, - softmax_scale=self.scale, - alibi_slopes=self.alibi_slopes, - sliding_window=self.sliding_window, - logits_soft_cap=self.logits_soft_cap, - block_table=attn_metadata.block_table, - common_prefix_len=attn_metadata.common_prefix_len, - max_num_splits=attn_metadata.max_num_splits, - fa_version=self.vllm_flash_attn_version, - prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata, - suffix_scheduler_metadata=attn_metadata.scheduler_metadata, - q_descale=layer._q_scale, - k_descale=layer._k_scale, - v_descale=layer._v_scale, - s_aux=self.sinks, - ) - return output - - def _forward_with_dcp( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - output: torch.Tensor, - attn_metadata: FlashAttentionMetadata, - q_descale: torch.Tensor | None = None, - k_descale: torch.Tensor | None = None, - v_descale: torch.Tensor | None = None, - ) -> torch.Tensor: - cu_seqlens_q = attn_metadata.query_start_loc - max_seqlen_q = attn_metadata.max_query_len - block_table = attn_metadata.block_table - - query = query.contiguous() - query_across_dcp = get_dcp_group().all_gather(query, dim=1) - context_attn_out, context_lse = flash_attn_varlen_func( - q=query_across_dcp, - k=key_cache, - v=value_cache, - out=None, - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - 
seqused_k=attn_metadata.dcp_context_kv_lens, - max_seqlen_k=attn_metadata.max_dcp_context_kv_len, - softmax_scale=self.scale, - causal=False, - alibi_slopes=self.alibi_slopes, - window_size=self.sliding_window, - block_table=block_table, - softcap=self.logits_soft_cap, - return_softmax_lse=True, - scheduler_metadata=attn_metadata.scheduler_metadata, - fa_version=self.vllm_flash_attn_version, - q_descale=q_descale, - k_descale=k_descale, - v_descale=v_descale, - ) - # FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ] - context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs( - context_attn_out, - context_lse.transpose(0, 1), - get_dcp_group(), - return_lse=True, - ) - context_lse_cor = context_lse_cor.transpose(0, 1).contiguous() - - query_attn_out, query_lse = flash_attn_varlen_func( - q=query, - k=key, - v=value, - out=None, - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - cu_seqlens_k=cu_seqlens_q, - max_seqlen_k=max_seqlen_q, - softmax_scale=self.scale, - causal=attn_metadata.causal, - alibi_slopes=self.alibi_slopes, - window_size=self.sliding_window, - softcap=self.logits_soft_cap, - return_softmax_lse=True, - fa_version=self.vllm_flash_attn_version, - q_descale=q_descale, - k_descale=k_descale, - v_descale=v_descale, - ) - assert context_attn_out_cor.shape == query_attn_out.shape - assert context_lse_cor.shape == query_lse.shape - merge_attn_states( - output, - context_attn_out_cor, - context_lse_cor, - query_attn_out, - query_lse, - ) - - def _forward_encoder_attention( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - output: torch.Tensor, - attn_metadata: FlashAttentionMetadata, - layer: torch.nn.Module, - ) -> torch.Tensor: - """Forward pass for encoder attention without KV cache. 
- - Args: - query: shape = [num_encoder_tokens, num_heads, head_size] - key: shape = [num_encoder_tokens, num_kv_heads, head_size] - value: shape = [num_encoder_tokens, num_kv_heads, head_size] - output: shape = [num_encoder_tokens, num_heads, head_size] - attn_metadata: Encoder attention metadata - layer: The attention layer - """ - # For encoder attention, process FP8 quantization if needed - if self.kv_cache_dtype.startswith("fp8"): - raise NotImplementedError( - "quantization is not supported for encoder attention" - ) - - # Use encoder-specific metadata for sequence information - cu_seqlens_q = attn_metadata.query_start_loc - cu_seqlens_k = attn_metadata.query_start_loc - max_seqlen_q = attn_metadata.max_query_len - max_seqlen_k = attn_metadata.max_query_len - - descale_shape = ( - cu_seqlens_q.shape[0] - 1, # type: ignore[union-attr] - self.num_kv_heads, - ) - - # Call flash attention directly on Q, K, V tensors - flash_attn_varlen_func( - q=query, - k=key, - v=value, - out=output, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - softmax_scale=self.scale, - causal=False, # Encoder attention is bidirectional - alibi_slopes=self.alibi_slopes, - window_size=self.sliding_window, - softcap=self.logits_soft_cap, - fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - num_splits=1 if self.batch_invariant_enabled else 0, - ) - - return output - - -def use_cascade_attention( - common_prefix_len: int, - query_lens: np.ndarray, - num_query_heads: int, - num_kv_heads: int, - use_alibi: bool, - use_sliding_window: bool, - use_local_attention: bool, - num_sms: int, - dcp_world_size: int, -) -> bool: - """Decide whether to use cascade attention. 
- - This function 1) checks whether cascade attention is supported with the - given configuration, and 2) heuristically decides whether using cascade - attention can improve performance. - """ - # Too short common prefix. Probably not worth using cascade attention. - # We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold. - # NOTE(woosuk): This is the common case. We should return False as soon as - # possible to avoid any unnecessary computation. - if common_prefix_len < 256: - return False - # Cascade attention is currently not supported with these variants. - if use_alibi or use_sliding_window or use_local_attention: - return False - # Too few queries. Probably not worth using cascade attention. - # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold. - num_reqs = len(query_lens) - if num_reqs < 8: - return False - # disable cascade attention for DCP - if dcp_world_size > 1: - return False - - # Heuristics to decide whether using cascade attention is beneficial. - # 1. When FlashDecoding is not used for normal attention, cascade attention - # is likely to be faster since it saves memory bandwidth. - num_queries_per_kv = num_query_heads // num_kv_heads - # The criteria for using FlashDecoding can be found in the following link: - # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535 - use_flash_decoding = ( - num_queries_per_kv > 1 - and not use_sliding_window - and not use_alibi - and np.all(query_lens == 1) - ) - if not use_flash_decoding: - # Use cascade attention. - return True - - # 2. When FlashDecoding is used for normal attention, it is not clear - # whether cascade attention is beneficial, because FlashDecoding can - # launch more CTAs than cascade attention. - # We use a simple performance model to compare the two methods. - # NOTE(woosuk): The performance model is very rough and may not be - # accurate. 
- num_tokens = num_reqs - # NOTE(woosuk): These are default tile sizes. flash-attn might use - # different tile sizes (e.g., 64 or 256) depending on the configuration. - q_tile_size = 128 - kv_tile_size = 128 - num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size) - - cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size) - cascade_waves = cdiv(cascade_ctas, num_sms) - cascade_time = cascade_waves * num_prefix_tiles - - flash_decoding_ctas = ( - num_reqs * num_kv_heads * cdiv(num_queries_per_kv, q_tile_size) - ) - flash_decoding_ctas *= num_prefix_tiles - flash_decoding_time = cdiv(flash_decoding_ctas, num_sms) - - # Use cascade attention if it is faster than FlashDecoding. - return cascade_time < flash_decoding_time - - -def cascade_attention( - output: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - cu_query_lens: torch.Tensor, - max_query_len: int, - cu_prefix_query_lens: torch.Tensor, - prefix_kv_lens: torch.Tensor, - suffix_kv_lens: torch.Tensor, - max_kv_len: int, - softmax_scale: float, - alibi_slopes: torch.Tensor | None, - sliding_window: tuple[int, int], - logits_soft_cap: float, - block_table: torch.Tensor, - common_prefix_len: int, - max_num_splits: int, - fa_version: int, - prefix_scheduler_metadata: torch.Tensor | None = None, - suffix_scheduler_metadata: torch.Tensor | None = None, - q_descale: torch.Tensor | None = None, - k_descale: torch.Tensor | None = None, - v_descale: torch.Tensor | None = None, - s_aux: torch.Tensor | None = None, -) -> torch.Tensor: - assert alibi_slopes is None, "Cascade attention does not support ALiBi." - # TODO: Support sliding window. - assert sliding_window == (-1, -1), ( - "Cascade attention does not support sliding window." 
- ) - - num_tokens = query.shape[0] - block_size = key_cache.shape[-3] - assert common_prefix_len % block_size == 0 - num_common_kv_blocks = common_prefix_len // block_size - assert num_common_kv_blocks > 0 - descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2]) - - # Process shared prefix. - prefix_output, prefix_lse = flash_attn_varlen_func( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=cu_prefix_query_lens, - seqused_k=prefix_kv_lens, - max_seqlen_q=num_tokens, - max_seqlen_k=common_prefix_len, - softmax_scale=softmax_scale, - causal=False, - window_size=sliding_window, - block_table=block_table[:1], - softcap=logits_soft_cap, - return_softmax_lse=True, - scheduler_metadata=prefix_scheduler_metadata, - fa_version=fa_version, - q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, - k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, - v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, - # s_aux is incorporated into prefix_lse inside the GPU kernel, - # enabling its effect during the final attention merge. - s_aux=s_aux, - num_splits=1 if vllm_is_batch_invariant() else max_num_splits, - ) - - descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2]) - - # Process suffix per query. 
- suffix_output, suffix_lse = flash_attn_varlen_func( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=cu_query_lens, - seqused_k=suffix_kv_lens, - max_seqlen_q=max_query_len, - max_seqlen_k=max_kv_len - common_prefix_len, - softmax_scale=softmax_scale, - causal=True, - window_size=sliding_window, - block_table=block_table[:, num_common_kv_blocks:], - softcap=logits_soft_cap, - return_softmax_lse=True, - scheduler_metadata=suffix_scheduler_metadata, - fa_version=fa_version, - q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, - k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, - v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, - num_splits=1 if vllm_is_batch_invariant() else max_num_splits, - ) - - # Merge prefix and suffix outputs, and store the result in output. - merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index f23b5564743c9..a9ce6e63cc775 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -210,11 +210,6 @@ class Scheduler(SchedulerInterface): hash_block_size=self.block_size, metrics_collector=self.kv_metrics_collector, ) - sink_len = getattr(vllm_config.model_config.hf_config, "param_sink_number", 0) - if sink_len > 0: - assert sink_len % self.block_size == 0 - num_sink_block = sink_len // self.block_size - self.kv_cache_manager.block_pool.free_block_queue.popleft_n(num_sink_block) self.use_pp = self.parallel_config.pipeline_parallel_size > 1 self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 8444ee5ef425f..e6f65da36e413 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -12,10 +12,10 @@ from vllm.v1.kv_cache_interface import ( 
ChunkedLocalAttentionSpec, CrossAttentionSpec, FullAttentionSpec, - FullDiffkvAttentionSpec, KVCacheSpec, MambaSpec, MLAAttentionSpec, + SinkFullAttentionSpec, SlidingWindowSpec, ) from vllm.v1.request import Request @@ -317,8 +317,7 @@ class FullAttentionManager(SingleTypeKVCacheManager): pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance( - kv_cache_spec, - FullAttentionSpec | ChunkedLocalAttentionSpec | FullDiffkvAttentionSpec, + kv_cache_spec, FullAttentionSpec | ChunkedLocalAttentionSpec ), ( "FullAttentionManager can only be used for full attention " "and chunked local attention groups" @@ -785,14 +784,35 @@ class CrossAttentionManager(SingleTypeKVCacheManager): raise NotImplementedError("CrossAttentionManager does not support caching") +class SinkFullAttentionManager(FullAttentionManager): + def __init__( + self, + kv_cache_spec: KVCacheSpec, + block_pool: BlockPool, + kv_cache_group_id: int, + dcp_world_size: int = 1, + pcp_world_size: int = 1, + ): + super().__init__( + kv_cache_spec, block_pool, kv_cache_group_id, dcp_world_size, pcp_world_size + ) + sink_len = kv_cache_spec.sink_len + if sink_len > 0: + assert sink_len % self.block_size == 0 + num_sink_block = sink_len // self.block_size + self.sink_blocks = self.block_pool.free_block_queue.popleft_n( + num_sink_block + ) + + spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, - FullDiffkvAttentionSpec: FullAttentionManager, MLAAttentionSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager, ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, MambaSpec: MambaManager, CrossAttentionSpec: CrossAttentionManager, + SinkFullAttentionSpec: SinkFullAttentionManager, } diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 1b130300b2218..656f5e7b81f55 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -80,6 +80,7 @@ class 
AttentionSpec(KVCacheSpec): @dataclass(frozen=True) class FullAttentionSpec(AttentionSpec): + head_size_v: int | None = None sliding_window: int | None = None attention_chunk_size: int | None = None """ @@ -92,6 +93,10 @@ class FullAttentionSpec(AttentionSpec): Default to None for not using sliding window attention. """ + def __post_init__(self): + if self.head_size_v is None: + object.__setattr__(self, "head_size_v", self.head_size) + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size @@ -124,88 +129,6 @@ class FullAttentionSpec(AttentionSpec): "All attention layers in the same KV cache group must be FullAttentionSpec." ) - sliding_window = set( - spec.sliding_window for spec in specs if spec.sliding_window is not None - ) - attention_chunk_size = set( - spec.attention_chunk_size - for spec in specs - if spec.attention_chunk_size is not None - ) - assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), ( - "MLAAttentionSpec should be merged in MLAAttentionSpec.merge" - ) - merged_spec = cls( - block_size=specs[0].block_size, - num_kv_heads=specs[0].num_kv_heads, - head_size=specs[0].head_size, - dtype=specs[0].dtype, - sliding_window=cls.merge_window_sizes(sliding_window), - attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), - ) - for spec in specs: - for f in fields(AttentionSpec): - assert getattr(spec, f.name) == getattr(merged_spec, f.name), ( - "All attention layers in the same KV cache group must have " - "the same attention spec." - ) - assert (merged_spec.sliding_window is not None) + ( - merged_spec.attention_chunk_size is not None - ) <= 1, ( - "Model with both sliding window layers and chunked local attention " - "layers is not supported." 
- ) - return merged_spec - - -@dataclass(frozen=True) -class FullDiffkvAttentionSpec(AttentionSpec): - head_size_v: int - sliding_window: int | None = None - attention_chunk_size: int | None = None - - """ - When hybrid allocator is disabled and the model contains both full - attention layers and sliding window attention layers, sliding - window attention are regarded as full attention in KV cache manager - (blocks are allocated for all tokens), while computed as sliding window - attention in model runner. - In this case, we use FullDiffkvAttentionSpec and record the sliding window size. - Default to None for not using sliding window attention. - """ - - def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: - max_model_len = vllm_config.model_config.max_model_len - dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size - # Note(hc): each dcp rank only need save - # (max_model_len//dcp_world_size) tokens locally. - if dcp_world_size > 1: - max_model_len = cdiv(max_model_len, dcp_world_size) - return cdiv(max_model_len, self.block_size) * self.page_size_bytes - - @classmethod - def merge_window_sizes(cls, window_sizes: set[int]) -> int | None: - if len(window_sizes) == 0: - return None - elif len(window_sizes) == 1: - return window_sizes.pop() - else: - raise ValueError( - "All attention layers in the same KV cache group must have the " - "same window size." - ) - - @classmethod - def merge(cls, specs: list[Self]) -> Self: - """ - Merge a list of FullDiffkvAttentionSpec objects into a single - FullDiffkvAttentionSpec object. - """ - assert all(isinstance(spec, FullDiffkvAttentionSpec) for spec in specs), ( - "All attention layers in the same KV cache group must be " - "FullDiffkvAttentionSpec." 
- ) - sliding_window = set( spec.sliding_window for spec in specs if spec.sliding_window is not None ) @@ -376,6 +299,56 @@ class CrossAttentionSpec(AttentionSpec): return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes +@dataclass(forzen=True) +class SinkFullAttentionSpec(FullAttentionSpec): + sink_len: int | None = None + + @classmethod + def merge(cls, specs: list[Self]) -> Self: + """ + Merge a list of FullAttentionSpec objects into a single + FullAttentionSpec object. + """ + assert all(isinstance(spec, FullAttentionSpec) for spec in specs), ( + "All attention layers in the same KV cache group must be FullAttentionSpec." + ) + + sliding_window = set( + spec.sliding_window for spec in specs if spec.sliding_window is not None + ) + attention_chunk_size = set( + spec.attention_chunk_size + for spec in specs + if spec.attention_chunk_size is not None + ) + assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), ( + "MLAAttentionSpec should be merged in MLAAttentionSpec.merge" + ) + merged_spec = cls( + block_size=specs[0].block_size, + num_kv_heads=specs[0].num_kv_heads, + head_size=specs[0].head_size, + head_size_v=specs[0].head_size_v, + sink_len=specs[0].sink_len, + dtype=specs[0].dtype, + sliding_window=cls.merge_window_sizes(sliding_window), + attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), + ) + for spec in specs: + for f in fields(AttentionSpec): + assert getattr(spec, f.name) == getattr(merged_spec, f.name), ( + "All attention layers in the same KV cache group must have " + "the same attention spec." + ) + assert (merged_spec.sliding_window is not None) + ( + merged_spec.attention_chunk_size is not None + ) <= 1, ( + "Model with both sliding window layers and chunked local attention " + "layers is not supported." 
+ ) + return merged_spec + + @dataclass(frozen=True) class UniformTypeKVCacheSpecs(KVCacheSpec): """ diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 37ec0fb97e06b..dd61d2150a797 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -263,6 +263,7 @@ class MultiGroupBlockTable: kernel_block_sizes: list[int], num_speculative_tokens: int = 0, cp_kv_cache_interleave_size: int = 1, + sink_len: int = 0, ) -> None: # Note(hc): each dcp rank only store # (max_model_len//dcp_world_size) tokens in kvcache, @@ -292,7 +293,7 @@ class MultiGroupBlockTable: block_size, max_num_reqs, max( - cdiv(max_model_len, block_size * total_cp_world_size), + cdiv(max_model_len + sink_len, block_size * total_cp_world_size), 1 + num_speculative_tokens, ), max_num_batched_tokens, diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py index 6386f1a08b446..09a5bd885309d 100644 --- a/vllm/v1/worker/gpu/attn_utils.py +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -101,16 +101,25 @@ def _reshape_kv_cache( num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes attn_backend = attn_backends[layer_name] + if hasattr(kv_cache_spec, "head_size_v"): + kwargs = {"head_size_v": kv_cache_spec.head_size_v} + stride_kwargs = {"diff_kv": True} + else: + kwargs = {} + stride_kwargs = {} kv_cache_shape = attn_backend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, + **kwargs, ) # FIXME(woosuk): Add kv_cache_stride_order to all attention backends. 
try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( + **stride_kwargs + ) assert len(kv_cache_stride_order) == len(kv_cache_shape) except (AttributeError, NotImplementedError): kv_cache_stride_order = tuple(range(len(kv_cache_shape))) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index c2b4b0dac3033..c567fc7219c3b 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -143,7 +143,7 @@ class InputBatch: # Block table. self.block_table = MultiGroupBlockTable( max_num_reqs=max_num_reqs, - max_model_len=max_model_len + sink_len, + max_model_len=max_model_len, max_num_batched_tokens=max_num_batched_tokens, pin_memory=pin_memory, device=device, @@ -151,6 +151,7 @@ class InputBatch: kernel_block_sizes=kernel_block_sizes, num_speculative_tokens=num_speculative_tokens, cp_kv_cache_interleave_size=cp_kv_cache_interleave_size, + sink_len=sink_len, ) # Sampling-related. 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ebdb9daf7fae9..5f81f9ba23ddc 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -27,9 +27,6 @@ from vllm.attention.backends.abstract import ( MultipleOf, ) from vllm.attention.layer import Attention, MLAAttention -from vllm.attention.ops.triton_reshape_and_cache_flash import ( - triton_reshape_and_cache_flash_diffkv, -) from vllm.compilation.counter import compilation_counter from vllm.compilation.cuda_graph import CUDAGraphStat, CUDAGraphWrapper from vllm.compilation.monitor import set_cudagraph_capturing_enabled @@ -5209,16 +5206,25 @@ class GPUModelRunner( ) kernel_num_blocks = num_blocks * num_blocks_per_kv_block + if hasattr(kv_cache_spec, "head_size_v"): + kwargs = {"head_size_v": kv_cache_spec.head_size_v} + stride_kwargs = {"diff_kv": True} + else: + kwargs = {} + stride_kwargs = {} kv_cache_shape = attn_backend.get_kv_cache_shape( kernel_num_blocks, kernel_block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, cache_dtype_str=self.cache_config.cache_dtype, + **kwargs, ) dtype = kv_cache_spec.dtype try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( + **stride_kwargs + ) assert len(kv_cache_stride_order) == len(kv_cache_shape) except (AttributeError, NotImplementedError): kv_cache_stride_order = tuple(range(len(kv_cache_shape))) @@ -5410,7 +5416,6 @@ class GPUModelRunner( kv_caches = self.initialize_kv_cache_tensors( kv_cache_config, kernel_block_sizes ) - self.prepare_sink_kv_cache(kv_caches) if self.speculative_config and self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) @@ -5501,36 +5506,3 @@ class GPUModelRunner( self.transfer_event.record() self.transfer_event.synchronize() return pinned.tolist() - - def prepare_sink_kv_cache(self, kv_caches) -> None: - if self.sink_len == 0: - return - - def 
find_module_by_name(model, target_name: str): - for name, module in model.named_modules(): - if name == target_name: - return module - raise KeyError(f"Module '{target_name}' not found") - - for layer_name, kv_cache in kv_caches.items(): - layer_prefix = layer_name.rsplit(".", 1)[0] - self_attn_module = find_module_by_name(self.model, layer_prefix) - if not hasattr(self_attn_module, "get_sink_kv"): - continue - else: - sink_kv = self_attn_module.get_sink_kv() - sink_kv_slot_mapping = torch.arange( - self.vllm_config.cache_config.block_size, - self.sink_len + self.vllm_config.cache_config.block_size, - device=torch.cuda.current_device(), - dtype=torch.long, - ) - triton_reshape_and_cache_flash_diffkv( - sink_kv["sink_key"], - sink_kv["sink_value"], - kv_cache, - sink_kv_slot_mapping, - self_attn_module.attn.kv_cache_dtype, - self_attn_module.attn._k_scale, - self_attn_module.attn._v_scale, - ) diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index 2bcc87b63bcdf..70e99db9e9762 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -190,17 +190,25 @@ class KVConnectorModelRunnerMixin: return False attn_backend = attn_group.backend + if hasattr(kv_cache_spec, "head_size_v"): + kwargs = {"head_size_v": kv_cache_spec.head_size_v} + stride_kwargs = {"diff_kv": True} + else: + kwargs = {} + stride_kwargs = {} kv_cache_shape = attn_backend.get_kv_cache_shape( 1234, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, cache_dtype_str=cache_dtype, + **kwargs, ) try: kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( - include_num_layers_dimension=True + include_num_layers_dimension=True, + **stride_kwargs, ) except (AttributeError, NotImplementedError): return False @@ -257,12 +265,19 @@ class KVConnectorModelRunnerMixin: kernel_num_blocks = num_blocks * num_blocks_per_kv_block attn_backend = 
attn_group.backend + if hasattr(kv_cache_spec, "head_size_v"): + kwargs = {"head_size_v": kv_cache_spec.head_size_v} + stride_kwargs = {"diff_kv": True} + else: + kwargs = {} + stride_kwargs = {} kv_cache_shape = attn_backend.get_kv_cache_shape( kernel_num_blocks, kernel_block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, cache_dtype_str=cache_dtype, + **kwargs, ) # prepend a num_layers dimension into the shape @@ -270,7 +285,8 @@ class KVConnectorModelRunnerMixin: try: kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( - include_num_layers_dimension=True + include_num_layers_dimension=True, + **stride_kwargs, ) assert len(kv_cache_stride_order) == len(kv_cache_shape) except (AttributeError, NotImplementedError): From 2470548aaa8b92a1cdd5ac446948e17b0e5a85b5 Mon Sep 17 00:00:00 2001 From: yuantao <2422264527@qq.com> Date: Sat, 13 Dec 2025 16:25:38 +0800 Subject: [PATCH 08/16] Fix typos Signed-off-by: yuantao <2422264527@qq.com> --- vllm/v1/core/single_type_kv_cache_manager.py | 2 +- vllm/v1/kv_cache_interface.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index e6f65da36e413..4eeea533464c9 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -787,7 +787,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager): class SinkFullAttentionManager(FullAttentionManager): def __init__( self, - kv_cache_spec: KVCacheSpec, + kv_cache_spec: SinkFullAttentionSpec, block_pool: BlockPool, kv_cache_group_id: int, dcp_world_size: int = 1, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 656f5e7b81f55..c0ab66f7081f7 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -299,7 +299,7 @@ class CrossAttentionSpec(AttentionSpec): return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes -@dataclass(forzen=True) 
+@dataclass(frozen=True) class SinkFullAttentionSpec(FullAttentionSpec): sink_len: int | None = None From a0563e7368e25663345780b31bb23a60d828ca09 Mon Sep 17 00:00:00 2001 From: yuantao <2422264527@qq.com> Date: Sat, 13 Dec 2025 16:32:56 +0800 Subject: [PATCH 09/16] Add assert in init of SinkFullAttentionManager Signed-off-by: yuantao <2422264527@qq.com> --- vllm/v1/core/single_type_kv_cache_manager.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 4eeea533464c9..14905b36754b4 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -797,12 +797,9 @@ class SinkFullAttentionManager(FullAttentionManager): kv_cache_spec, block_pool, kv_cache_group_id, dcp_world_size, pcp_world_size ) sink_len = kv_cache_spec.sink_len - if sink_len > 0: - assert sink_len % self.block_size == 0 - num_sink_block = sink_len // self.block_size - self.sink_blocks = self.block_pool.free_block_queue.popleft_n( - num_sink_block - ) + assert sink_len is not None and sink_len > 0 and sink_len % self.block_size == 0 + num_sink_block = sink_len // self.block_size + self.sink_blocks = self.block_pool.free_block_queue.popleft_n(num_sink_block) spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { From a7430ab479ff526f198de7008c0819e5b461babe Mon Sep 17 00:00:00 2001 From: yuantao <2422264527@qq.com> Date: Sat, 13 Dec 2025 17:59:28 +0800 Subject: [PATCH 10/16] Fix typos Signed-off-by: yuantao <2422264527@qq.com> --- vllm/attention/layers/static_sink_attention.py | 5 +---- vllm/attention/ops/triton_reshape_and_cache_flash.py | 4 +--- vllm/model_executor/models/openpangu.py | 2 +- 3 files changed, 3 insertions(+), 8 deletions(-) diff --git a/vllm/attention/layers/static_sink_attention.py b/vllm/attention/layers/static_sink_attention.py index 7687651ee682b..beb9add10024b 100644 --- 
a/vllm/attention/layers/static_sink_attention.py +++ b/vllm/attention/layers/static_sink_attention.py @@ -66,9 +66,6 @@ def create_static_sink_attention_backend( common_attn_metadata.seq_lens[:] = ( common_attn_metadata.seq_lens + self.sink_len ) - common_attn_metadata.seq_lens_cpu = ( - common_attn_metadata.seq_lens_cpu + self.sink_len - ) common_attn_metadata.max_seq_len = ( common_attn_metadata.max_seq_len + self.sink_len ) @@ -152,7 +149,7 @@ class StaticSinkAttention(Attention): query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - output_shape: torch.size | None = None, + output_shape: torch.Size | None = None, ) -> torch.Tensor: assert self.sink_key is not None and self.sink_value is not None, ( "sink_key and sink_value have not been prepared" diff --git a/vllm/attention/ops/triton_reshape_and_cache_flash.py b/vllm/attention/ops/triton_reshape_and_cache_flash.py index c119033896ec6..a383de0ac76cc 100644 --- a/vllm/attention/ops/triton_reshape_and_cache_flash.py +++ b/vllm/attention/ops/triton_reshape_and_cache_flash.py @@ -238,7 +238,7 @@ def reshape_and_cache_kernel_flash_diffkv( # [TILE_SIZE] value_load = tl.load( - value_ptr + src_value_idx + tile_offs, mask=tile_offs * head_size_v + value_ptr + src_value_idx + tile_offs, mask=tile_offs < head_size_v ) if FP8_KV_CACHE: if value_load.dtype.is_fp8(): @@ -322,8 +322,6 @@ def triton_reshape_and_cache_flash_diffkv( else: # cuda num_stages = 10 num_warps = 16 - if torch.cuda.get_device_capability(key.device)[0] < 9: - TILE_SIZE = min(512, TILE_SIZE) # TODO(ngl): maybe replace with static launch grid to avoid overhead if # using cudagraphs diff --git a/vllm/model_executor/models/openpangu.py b/vllm/model_executor/models/openpangu.py index 8e4bb62e137a8..43bfa4f8324cc 100644 --- a/vllm/model_executor/models/openpangu.py +++ b/vllm/model_executor/models/openpangu.py @@ -778,10 +778,10 @@ class OpenPanguSinkAttention(nn.Module): quant_config: QuantizationConfig | None, ) -> None: is_neox_style = False + 
rope_parameters = {"partial_rotary_factor": self.qk_rope_dim / self.head_dim} self.rotary_emb = get_rope( self.head_dim, - rotary_dim=self.qk_rope_dim, max_position=self.max_position_embeddings, rope_parameters=rope_parameters, is_neox_style=is_neox_style, From 93a7afcab3b81a84e2b7c82cd16cc775e0828ca09 Mon Sep 17 00:00:00 2001 From: yuantao <2422264527@qq.com> Date: Wed, 17 Dec 2025 10:34:14 +0800 Subject: [PATCH 11/16] Extend SinkFullAttentionManager to handle sink blocks management, avoid modifying blk_table_tensor during the build of attn_metadata Signed-off-by: yuantao <2422264527@qq.com> --- .../attention/layers/static_sink_attention.py | 13 +- vllm/v1/core/single_type_kv_cache_manager.py | 140 ++++++++++++++++++ vllm/v1/worker/block_table.py | 11 +- vllm/v1/worker/gpu/attn_utils.py | 5 +- vllm/v1/worker/gpu_model_runner.py | 5 +- .../worker/kv_connector_model_runner_mixin.py | 10 +- 6 files changed, 167 insertions(+), 17 deletions(-) diff --git a/vllm/attention/layers/static_sink_attention.py b/vllm/attention/layers/static_sink_attention.py index beb9add10024b..2bf95943ee095 100644 --- a/vllm/attention/layers/static_sink_attention.py +++ b/vllm/attention/layers/static_sink_attention.py @@ -66,20 +66,13 @@ def create_static_sink_attention_backend( common_attn_metadata.seq_lens[:] = ( common_attn_metadata.seq_lens + self.sink_len ) + common_attn_metadata.seq_lens[ + common_attn_metadata.seq_lens == self.sink_len + ] = 0 common_attn_metadata.max_seq_len = ( common_attn_metadata.max_seq_len + self.sink_len ) - blk_table_tensor = common_attn_metadata.block_table_tensor - sink_block_table = self.sink_block_table[None, :].expand( - blk_table_tensor.shape[0], -1 - ) - blk_table_tensor_clone = blk_table_tensor.clone() - blk_table_tensor[:, self.num_sink_blocks :] = blk_table_tensor_clone[ - :, : -self.num_sink_blocks - ] - blk_table_tensor[:, : self.num_sink_blocks] = sink_block_table - return super().build(common_prefix_len, common_attn_metadata, fast_build)
attn_backend = subclass_attention_backend( diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 14905b36754b4..fe9e7a9891941 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -801,6 +801,146 @@ class SinkFullAttentionManager(FullAttentionManager): num_sink_block = sink_len // self.block_size self.sink_blocks = self.block_pool.free_block_queue.popleft_n(num_sink_block) + def get_num_blocks_to_allocate( + self, + request_id: str, + num_tokens: int, + new_computed_blocks: Sequence[KVCacheBlock], + ) -> int: + """ + Get the number of blocks needed to be allocated for the request. + + Args: + request_id: The request ID. + num_tokens: The total number of tokens that need a slot (including + tokens that are already allocated). + new_computed_blocks: The new computed blocks just hitting the + prefix caching. + + Returns: + The number of blocks. + """ + + num_required_blocks = cdiv(num_tokens, self.block_size) + num_new_blocks = ( + num_required_blocks + - len(new_computed_blocks) + - len(self.req_to_blocks[request_id]) + ) + # Number of sink blocks is calculated into num_new_blocks + if len(self.req_to_blocks[request_id]) > 0: + num_new_blocks = num_new_blocks + len(self.sink_blocks) + # If a computed block of a request is an eviction candidate (in the + # free queue and ref_cnt == 0), it will be changed from a free block + # to a computed block when the request is allocated, so we also count + # it as needed to be allocated. + num_evictable_computed_blocks = sum( + blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks + ) + return num_new_blocks + num_evictable_computed_blocks + + def save_new_computed_blocks( + self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock] + ) -> None: + """ + Add the new computed blocks to the request. + + Args: + request_id: The request ID. 
+ new_computed_blocks: The new computed blocks just hitting the
+ prefix cache.
+ """
+ if request_id not in self.num_cached_block:
+ # A new request.
+ req_blocks = self.req_to_blocks[request_id]
+ assert len(req_blocks) == 0
+ # Append both sink blocks and hit prefix cache blocks
+ req_blocks.extend(self.sink_blocks + new_computed_blocks)
+ self.num_cached_block[request_id] = len(new_computed_blocks)
+ else:
+ # A running request. Should not have new computed blocks.
+ assert len(new_computed_blocks) == 0
+
+ def allocate_new_blocks(
+ self, request_id: str, num_tokens: int
+ ) -> list[KVCacheBlock]:
+ """
+ Allocate new blocks for the request to give it at least `num_tokens`
+ token slots.
+
+ Args:
+ request_id: The request ID.
+ num_tokens: The total number of tokens that need a slot (including
+ tokens that are already allocated).
+
+ Returns:
+ The new allocated blocks.
+ """
+ req_blocks = self.req_to_blocks[request_id]
+ num_required_blocks = cdiv(num_tokens, self.block_size)
+ num_new_blocks = num_required_blocks - len(req_blocks)
+ # For existing requests, number of sink blocks is calculated into
+ # num_new_blocks
+ if len(req_blocks) > 0:
+ num_new_blocks = num_new_blocks + len(self.sink_blocks)
+ if num_new_blocks <= 0:
+ return []
+ else:
+ new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
+ # For new requests, allocate sink blocks
+ if len(req_blocks) == 0:
+ req_blocks.extend(self.sink_blocks + new_blocks)
+ else:
+ req_blocks.extend(new_blocks)
+ return new_blocks
+
+ def cache_blocks(self, request: Request, num_tokens: int) -> None:
+ """
+ Cache the blocks for the request.
+
+ Args:
+ request: The request.
+ num_tokens: The total number of tokens that need to be cached
+ (including tokens that are already cached).
+ """ + num_cached_blocks = self.num_cached_block.get(request.request_id, 0) + num_full_blocks = num_tokens // self.block_size + + if num_cached_blocks >= num_full_blocks: + return + + self.block_pool.cache_full_blocks( + request=request, + # Do not cache sink blocks + blocks=self.req_to_blocks[request.request_id][len(self.sink_blocks) :], + num_cached_blocks=num_cached_blocks, + num_full_blocks=num_full_blocks, + block_size=self.block_size, + kv_cache_group_id=self.kv_cache_group_id, + ) + + self.num_cached_block[request.request_id] = num_full_blocks + + def free(self, request_id: str) -> None: + """ + Free the blocks for the request. + + Args: + request_id: The request ID. + """ + # Default to [] in case a request is freed (aborted) before alloc. + req_blocks = self.req_to_blocks.pop(request_id, []) + # Do not free sink blocks + if len(req_blocks) > 0: + req_blocks = req_blocks[len(self.sink_blocks) :] + + # Free blocks in reverse order so that the tail blocks are + # freed first. + ordered_blocks = reversed(req_blocks) + + self.block_pool.free_blocks(ordered_blocks) + self.num_cached_block.pop(request_id, None) + spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index dd61d2150a797..5703f80db0754 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -23,6 +23,7 @@ class BlockTable: device: torch.device, kernel_block_size: int, cp_kv_cache_interleave_size: int, + sink_len: int = 0, ): """ Args: @@ -63,6 +64,8 @@ class BlockTable: self.use_hybrid_blocks = True self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block + self.sink_block_len = sink_len // self.block_size + self.max_num_blocks_per_req = self.max_num_blocks_per_req + self.sink_block_len self.block_table = self._make_buffer( self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32 @@ -151,7 +154,7 @@ 
class BlockTable: block_table_indices = ( req_indices * self.max_num_blocks_per_req + positions // virtual_block_size - ) + ) + self.sink_block_len block_numbers = self.block_table.np.ravel()[block_table_indices] # Use virtual_block_size for mask calculation, which marks local @@ -177,9 +180,10 @@ class BlockTable: mask, slot_mapping, -1 ) else: + # When self.sink_block_len > 0, we need to shift block table indices block_table_indices = ( req_indices * self.max_num_blocks_per_req + positions // self.block_size - ) + ) + self.sink_block_len block_numbers = self.block_table.np.ravel()[block_table_indices] block_offsets = positions % self.block_size @@ -293,7 +297,7 @@ class MultiGroupBlockTable: block_size, max_num_reqs, max( - cdiv(max_model_len + sink_len, block_size * total_cp_world_size), + cdiv(max_model_len, block_size * total_cp_world_size), 1 + num_speculative_tokens, ), max_num_batched_tokens, @@ -301,6 +305,7 @@ class MultiGroupBlockTable: device, kernel_block_size, cp_kv_cache_interleave_size, + sink_len=sink_len, ) for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes) ] diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py index 09a5bd885309d..6845652688430 100644 --- a/vllm/v1/worker/gpu/attn_utils.py +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -101,7 +101,10 @@ def _reshape_kv_cache( num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes attn_backend = attn_backends[layer_name] - if hasattr(kv_cache_spec, "head_size_v"): + if ( + getattr(kv_cache_spec, "head_size_v", kv_cache_spec.head_size) + != kv_cache_spec.head_size + ): kwargs = {"head_size_v": kv_cache_spec.head_size_v} stride_kwargs = {"diff_kv": True} else: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5f81f9ba23ddc..0bf773c84596d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5206,7 +5206,10 @@ class GPUModelRunner( ) kernel_num_blocks = num_blocks * 
num_blocks_per_kv_block - if hasattr(kv_cache_spec, "head_size_v"): + if ( + getattr(kv_cache_spec, "head_size_v", kv_cache_spec.head_size) + != kv_cache_spec.head_size + ): kwargs = {"head_size_v": kv_cache_spec.head_size_v} stride_kwargs = {"diff_kv": True} else: diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index 70e99db9e9762..a31013502c377 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -190,7 +190,10 @@ class KVConnectorModelRunnerMixin: return False attn_backend = attn_group.backend - if hasattr(kv_cache_spec, "head_size_v"): + if ( + getattr(kv_cache_spec, "head_size_v", kv_cache_spec.head_size) + != kv_cache_spec.head_size + ): kwargs = {"head_size_v": kv_cache_spec.head_size_v} stride_kwargs = {"diff_kv": True} else: @@ -265,7 +268,10 @@ class KVConnectorModelRunnerMixin: kernel_num_blocks = num_blocks * num_blocks_per_kv_block attn_backend = attn_group.backend - if hasattr(kv_cache_spec, "head_size_v"): + if ( + getattr(kv_cache_spec, "head_size_v", kv_cache_spec.head_size) + != kv_cache_spec.head_size + ): kwargs = {"head_size_v": kv_cache_spec.head_size_v} stride_kwargs = {"diff_kv": True} else: From 378e20833f9a200fc5c78cf0a21fe486ef78ffc9 Mon Sep 17 00:00:00 2001 From: yuantao <2422264527@qq.com> Date: Wed, 17 Dec 2025 11:03:13 +0800 Subject: [PATCH 12/16] Fix pre-commit Signed-off-by: yuantao <2422264527@qq.com> --- vllm/v1/core/single_type_kv_cache_manager.py | 3 ++- vllm/v1/worker/gpu/attn_utils.py | 4 ++-- vllm/v1/worker/gpu_model_runner.py | 4 ++-- vllm/v1/worker/kv_connector_model_runner_mixin.py | 8 ++++---- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index fe9e7a9891941..e6c7150500cc3 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py 
@@ -855,7 +855,8 @@ class SinkFullAttentionManager(FullAttentionManager): req_blocks = self.req_to_blocks[request_id] assert len(req_blocks) == 0 # Append both sink blocks and hitted prefix cache blocks - req_blocks.extend(self.sink_blocks + new_computed_blocks) + req_blocks.extend(self.sink_blocks) + req_blocks.extend(new_computed_blocks) self.num_cached_block[request_id] = len(new_computed_blocks) else: # A running request. Should not have new computed blocks. diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py index 6845652688430..604edf3d7ae29 100644 --- a/vllm/v1/worker/gpu/attn_utils.py +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -102,8 +102,8 @@ def _reshape_kv_cache( attn_backend = attn_backends[layer_name] if ( - getattr(kv_cache_spec, "head_size_v", kv_cache_spec.head_size) - != kv_cache_spec.head_size + hasattr(kv_cache_spec, "head_size_v") + and kv_cache_spec.head_size_v != kv_cache_spec.head_size ): kwargs = {"head_size_v": kv_cache_spec.head_size_v} stride_kwargs = {"diff_kv": True} diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0bf773c84596d..e02e7ee2d4c26 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5207,8 +5207,8 @@ class GPUModelRunner( kernel_num_blocks = num_blocks * num_blocks_per_kv_block if ( - getattr(kv_cache_spec, "head_size_v", kv_cache_spec.head_size) - != kv_cache_spec.head_size + hasattr(kv_cache_spec, "head_size_v") + and kv_cache_spec.head_size_v != kv_cache_spec.head_size ): kwargs = {"head_size_v": kv_cache_spec.head_size_v} stride_kwargs = {"diff_kv": True} diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index a31013502c377..a45b83dc788f1 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -191,8 +191,8 @@ class KVConnectorModelRunnerMixin: attn_backend = attn_group.backend if ( 
- getattr(kv_cache_spec, "head_size_v", kv_cache_spec.head_size) - != kv_cache_spec.head_size + hasattr(kv_cache_spec, "head_size_v") + and kv_cache_spec.head_size_v != kv_cache_spec.head_size ): kwargs = {"head_size_v": kv_cache_spec.head_size_v} stride_kwargs = {"diff_kv": True} @@ -269,8 +269,8 @@ class KVConnectorModelRunnerMixin: attn_backend = attn_group.backend if ( - getattr(kv_cache_spec, "head_size_v", kv_cache_spec.head_size) - != kv_cache_spec.head_size + hasattr(kv_cache_spec, "head_size_v") + and kv_cache_spec.head_size_v != kv_cache_spec.head_size ): kwargs = {"head_size_v": kv_cache_spec.head_size_v} stride_kwargs = {"diff_kv": True} From 4ad3f758753084db24a30ce3fbdb8b457c8ee708 Mon Sep 17 00:00:00 2001 From: yuantao <2422264527@qq.com> Date: Mon, 22 Dec 2025 22:45:11 +0800 Subject: [PATCH 13/16] Refactor code, add FLASH_ATTN_DIFFKV backend Signed-off-by: yuantao <2422264527@qq.com> --- vllm/attention/backends/registry.py | 3 + vllm/model_executor/models/openpangu.py | 14 +- vllm/v1/attention/backends/flash_attn.py | 81 ++---- .../attention/backends/flash_attn_diffkv.py | 269 ++++++++++++++++++ 4 files changed, 299 insertions(+), 68 deletions(-) create mode 100644 vllm/v1/attention/backends/flash_attn_diffkv.py diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index eaa0fa1d5db39..8461ed61480b7 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -42,6 +42,9 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta): """ FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + FLASH_ATTN_DIFFKV = ( + "vllm.v1.attention.backends.flash_attn_diffkv.FlashAttentionDiffKVBackend" + ) TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" diff --git 
a/vllm/model_executor/models/openpangu.py b/vllm/model_executor/models/openpangu.py index 43bfa4f8324cc..662ecef3ac8f6 100644 --- a/vllm/model_executor/models/openpangu.py +++ b/vllm/model_executor/models/openpangu.py @@ -79,9 +79,10 @@ from vllm.model_executor.models.utils import ( sequence_parallel_chunk, ) from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import set_default_rope_theta -from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend +from vllm.v1.attention.backends.flash_attn_diffkv import FlashAttentionDiffKVBackend def check_ffn_act_fn(act_fn: str): @@ -645,6 +646,7 @@ class OpenPanguSinkAttention(nn.Module): else: sliding_window = None + FlashAttentionDiffKVBackend.set_head_size_v(self.v_channels) self.attn = StaticSinkAttention( self.num_heads, self.head_dim, @@ -656,7 +658,7 @@ class OpenPanguSinkAttention(nn.Module): per_layer_sliding_window=sliding_window, attn_type=attn_type, prefix=f"{prefix}.attn", - attn_backend=FlashAttentionBackend, + attn_backend=FlashAttentionDiffKVBackend, head_size_v=self.v_channels, ) @@ -668,7 +670,7 @@ class OpenPanguSinkAttention(nn.Module): self.num_kv_heads, self.head_dim, ), - device=torch.cuda.current_device(), + device=current_platform.current_device(), dtype=config.torch_dtype, ) ) @@ -688,7 +690,7 @@ class OpenPanguSinkAttention(nn.Module): self.num_kv_heads, self.v_channels, ), - device=torch.cuda.current_device(), + device=current_platform.current_device(), dtype=config.torch_dtype, ) ) @@ -706,9 +708,11 @@ class OpenPanguSinkAttention(nn.Module): self.num_kv_heads, self.v_channels, ), - device=torch.cuda.current_device(), + device=current_platform.current_device(), dtype=config.torch_dtype, ) + # To enable dummy run with out weight + self.post_weight_load() def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): output_dim = getattr(param, 
"output_dim", None) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 8b030a04b438d..f5ad98cf2125c 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -18,9 +18,6 @@ from vllm.attention.backends.abstract import ( from vllm.attention.layer import Attention from vllm.attention.ops.common import cp_lse_ag_out_rs from vllm.attention.ops.merge_attn_states import merge_attn_states -from vllm.attention.ops.triton_reshape_and_cache_flash import ( - triton_reshape_and_cache_flash_diffkv, -) from vllm.attention.utils.fa_utils import ( flash_attn_supports_fp8, get_flash_attn_version, @@ -108,48 +105,28 @@ class FlashAttentionBackend(AttentionBackend): num_kv_heads: int, head_size: int, cache_dtype_str: str = "auto", - head_size_v: int | None = None, ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") - if head_size_v is None or head_size == head_size_v: - return (2, num_blocks, block_size, num_kv_heads, head_size) - else: - return ( - num_blocks, - block_size, - num_kv_heads, - head_size + head_size_v, - ) + return (2, num_blocks, block_size, num_kv_heads, head_size) @staticmethod def get_kv_cache_stride_order( include_num_layers_dimension: bool = False, - diff_kv: bool = False, ) -> tuple[int, ...]: # `stride_order` indicates the permutation that gets # us from `get_kv_cache_shape` to the actual memory layout we want. 
cache_layout = get_kv_cache_layout() if cache_layout == "NHD" and include_num_layers_dimension: - if not diff_kv: - # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size) - return (2, 0, 1, 3, 4, 5) - else: - # (num_blocks, num_layers, block_size, - # num_kv_heads, head_size + head_size_v) - return (0, 1, 2, 3, 4) + # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size) + return (2, 0, 1, 3, 4, 5) elif cache_layout == "NHD": - stride_order = (0, 1, 2, 3, 4) if not diff_kv else (0, 1, 2, 3) + stride_order = (0, 1, 2, 3, 4) elif cache_layout == "HND" and include_num_layers_dimension: - if not diff_kv: - # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size) - return (2, 4, 0, 1, 3, 5) - else: - # (num_blocks, num_kv_heads, num_layers, - # block_size, head_size + head_size_v) - return (2, 3, 0, 1, 4) + # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size) + return (2, 4, 0, 1, 3, 5) elif cache_layout == "HND": - stride_order = (0, 1, 3, 2, 4) if not diff_kv else (0, 2, 1, 3) + stride_order = (0, 1, 3, 2, 4) else: raise ValueError(f"Unknown cache layout format {cache_layout}.") return stride_order @@ -599,14 +576,11 @@ class FlashAttentionImpl(AttentionImpl): query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - or [num_tokens, num_kv_heads, head_size_v] kv_cache: shape = [2, num_blocks, block_size, num_kv_heads, head_size] - or [num_blocks, block_size, num_kv_heads, head_size + head_size_v] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] - or [num_tokens, num_heads * head_size_v] NOTE: FP8 quantization, flash-attn expect the size of {q,k,v}_descale to be (num_sequences, num_kv_heads). 
We use torch's .expand() to avoid duplicating values @@ -649,13 +623,7 @@ class FlashAttentionImpl(AttentionImpl): ) # For decoder and cross-attention, use KV cache as before - if self.head_size == kv_cache.shape[-1]: - # Same head_size for K and V - key_cache, value_cache = kv_cache.unbind(0) - else: - # Different head_size for K and V - key_cache = kv_cache[..., : self.head_size] - value_cache = kv_cache[..., self.head_size :] + key_cache, value_cache = kv_cache.unbind(0) # key and value may be None in the case of cross attention. They are # calculated once based on the output from the encoder and then cached @@ -672,29 +640,16 @@ class FlashAttentionImpl(AttentionImpl): # and value[:num_actual_tokens] because the reshape_and_cache_flash # op uses the slot_mapping's shape to determine the number of # actual tokens. - if self.head_size == kv_cache.shape[-1]: - # kv_cache update for same head_size K and V - reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - else: - # kv_cache update for different head_size K and V - triton_reshape_and_cache_flash_diffkv( - key, - value, - kv_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) + reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) if self.kv_cache_dtype.startswith("fp8"): # queries are quantized in the attention layer diff --git a/vllm/v1/attention/backends/flash_attn_diffkv.py b/vllm/v1/attention/backends/flash_attn_diffkv.py new file mode 100644 index 0000000000000..2e36740bd9e52 --- /dev/null +++ b/vllm/v1/attention/backends/flash_attn_diffkv.py @@ -0,0 +1,269 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention layer with FlashAttention.""" + +import torch + +from 
vllm.attention.backends.abstract import AttentionType +from vllm.attention.ops.triton_reshape_and_cache_flash import ( + triton_reshape_and_cache_flash_diffkv, +) +from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available + +if is_flash_attn_varlen_func_available(): + from vllm.attention.utils.fa_utils import flash_attn_varlen_func +from vllm.logger import init_logger +from vllm.v1.attention.backends.utils import get_kv_cache_layout + +from .flash_attn import ( + FlashAttentionBackend, + FlashAttentionImpl, + FlashAttentionMetadata, + cascade_attention, +) + +logger = init_logger(__name__) + + +class FlashAttentionDiffKVBackend(FlashAttentionBackend): + # Default to 128 for this backend + head_size_v: int = 128 + + @classmethod + def set_head_size_v(cls, head_size_v: int) -> None: + cls.head_size_v = head_size_v + + @staticmethod + def get_name() -> str: + return "FLASH_ATTN_DIFFKV" + + @staticmethod + def get_impl_cls() -> type["FlashAttentionImpl"]: + return FlashAttentionDiffKVImpl + + # Do not modify the interface of get_kv_cache_shape, + # but consider head_size_v when returning result. + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return ( + num_blocks, + block_size, + num_kv_heads, + head_size + FlashAttentionDiffKVBackend.head_size_v, + ) + + @staticmethod + def get_kv_cache_stride_order( + include_num_layers_dimension: bool = False, + ) -> tuple[int, ...]: + # `stride_order` indicates the permutation that gets + # us from `get_kv_cache_shape` to the actual memory layout we want. 
+ cache_layout = get_kv_cache_layout() + if cache_layout == "NHD" and include_num_layers_dimension: + # (num_blocks, num_layers, block_size, + # num_kv_heads, head_size + head_size_v) + return (1, 0, 2, 3, 4) + elif cache_layout == "NHD": + stride_order = (0, 1, 2, 3) + elif cache_layout == "HND" and include_num_layers_dimension: + # (num_blocks, num_kv_heads, num_layers, + # block_size, head_size + head_size_v) + return (1, 3, 0, 2, 4) + elif cache_layout == "HND": + stride_order = (0, 2, 1, 3) + else: + raise ValueError(f"Unknown cache layout format {cache_layout}.") + return stride_order + + +class FlashAttentionDiffKVImpl(FlashAttentionImpl): + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size_v] + kv_cache: shape = + [num_blocks, block_size, num_kv_heads, head_size + head_size_v] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size_v] + NOTE: FP8 quantization, flash-attn expect the size of + {q,k,v}_descale to be (num_sequences, num_kv_heads). + We use torch's .expand() to avoid duplicating values + """ + assert output is not None, "Output tensor must be provided." + + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported for FlashAttentionImpl" + ) + + if attn_metadata is None: + # Profiling run. + return output.fill_(0) + + attn_type = self.attn_type + + # IMPORTANT! 
+ # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in + # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead + # in this method. For example, `view` and `slice` (or `[:n]`) operations + # are surprisingly slow even in the case they do not invoke any GPU ops. + # Minimize the PyTorch ops in this method as much as possible. + # Whenever making a change in this method, please benchmark the + # performance to make sure it does not introduce any overhead. + + num_actual_tokens = attn_metadata.num_actual_tokens + + # Handle encoder attention differently - no KV cache needed + if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): + # For encoder attention, + # we use direct Q, K, V tensors without caching + return self._forward_encoder_attention( + query[:num_actual_tokens], + key[:num_actual_tokens], + value[:num_actual_tokens], + output[:num_actual_tokens], + attn_metadata, + layer, + ) + + # For decoder and cross-attention, use KV cache as before + # Different head_size for K and V + key_cache = kv_cache[..., : self.head_size] + value_cache = kv_cache[..., self.head_size :] + + # key and value may be None in the case of cross attention. They are + # calculated once based on the output from the encoder and then cached + # in KV cache. + if ( + self.kv_sharing_target_layer_name is None + and key is not None + and value is not None + ): + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] + # and value[:num_actual_tokens] because the reshape_and_cache_flash + # op uses the slot_mapping's shape to determine the number of + # actual tokens. 
+ + # kv_cache update for different head_size K and V + triton_reshape_and_cache_flash_diffkv( + key, + value, + kv_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if self.kv_cache_dtype.startswith("fp8"): + # queries are quantized in the attention layer + dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn( + self.kv_cache_dtype + ) + key_cache = key_cache.view(dtype) + value_cache = value_cache.view(dtype) + + if not attn_metadata.use_cascade: + cu_seqlens_q = attn_metadata.query_start_loc + seqused_k = attn_metadata.seq_lens + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_seq_len + block_table = attn_metadata.block_table + scheduler_metadata = attn_metadata.scheduler_metadata + + descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads) + + if self.dcp_world_size > 1: + self._forward_with_dcp( + query[:num_actual_tokens], + key[:num_actual_tokens], + value[:num_actual_tokens], + key_cache, + value_cache, + output[:num_actual_tokens], + attn_metadata, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + return output + else: + flash_attn_varlen_func( + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=seqused_k, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=attn_metadata.causal, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + softcap=self.logits_soft_cap, + scheduler_metadata=scheduler_metadata, + fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + num_splits=attn_metadata.max_num_splits, + s_aux=self.sinks, + ) + return output + + # Cascade 
attention (rare case). + cascade_attention( + output[:num_actual_tokens], + query[:num_actual_tokens], + key_cache, + value_cache, + cu_query_lens=attn_metadata.query_start_loc, + max_query_len=attn_metadata.max_query_len, + cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens, + prefix_kv_lens=attn_metadata.prefix_kv_lens, + suffix_kv_lens=attn_metadata.suffix_kv_lens, + max_kv_len=attn_metadata.max_seq_len, + softmax_scale=self.scale, + alibi_slopes=self.alibi_slopes, + sliding_window=self.sliding_window, + logits_soft_cap=self.logits_soft_cap, + block_table=attn_metadata.block_table, + common_prefix_len=attn_metadata.common_prefix_len, + max_num_splits=attn_metadata.max_num_splits, + fa_version=self.vllm_flash_attn_version, + prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata, + suffix_scheduler_metadata=attn_metadata.scheduler_metadata, + q_descale=layer._q_scale, + k_descale=layer._k_scale, + v_descale=layer._v_scale, + s_aux=self.sinks, + ) + return output From cf58a620990c3bd1ae45cccc90575b77351f3637 Mon Sep 17 00:00:00 2001 From: yuantao <2422264527@qq.com> Date: Mon, 22 Dec 2025 22:48:14 +0800 Subject: [PATCH 14/16] Refactor code, move static sink logics to builder Signed-off-by: yuantao <2422264527@qq.com> --- .../attention/layers/static_sink_attention.py | 47 +++++- vllm/v1/core/single_type_kv_cache_manager.py | 141 ------------------ vllm/v1/worker/block_table.py | 10 +- vllm/v1/worker/gpu/attn_utils.py | 14 +- vllm/v1/worker/gpu_input_batch.py | 2 - vllm/v1/worker/gpu_model_runner.py | 21 +-- .../worker/kv_connector_model_runner_mixin.py | 22 --- 7 files changed, 48 insertions(+), 209 deletions(-) diff --git a/vllm/attention/layers/static_sink_attention.py b/vllm/attention/layers/static_sink_attention.py index 2bf95943ee095..e5ed16ec14932 100644 --- a/vllm/attention/layers/static_sink_attention.py +++ b/vllm/attention/layers/static_sink_attention.py @@ -17,6 +17,8 @@ from vllm.attention.selector import get_attn_backend from 
vllm.config import CacheConfig, VllmConfig from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp +from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, @@ -48,9 +50,23 @@ def create_static_sink_attention_backend( device: torch.device, ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) + model_config = vllm_config.model_config + scheduler_config = vllm_config.scheduler_config self.sink_len = sink_len + self.block_size = vllm_config.cache_config.block_size self.num_sink_blocks = self.sink_len // vllm_config.cache_config.block_size - self.sink_block_table = torch.arange( + self.max_num_blocks = cdiv( + model_config.max_model_len, vllm_config.cache_config.block_size + ) + self.block_table_with_sink = torch.zeros( + ( + scheduler_config.max_num_seqs, + self.max_num_blocks + self.num_sink_blocks, + ), + device=device, + dtype=torch.int32, + ) + self.block_table_with_sink[:, : self.num_sink_blocks] = torch.arange( 1, self.num_sink_blocks + 1, device=device, @@ -72,6 +88,14 @@ def create_static_sink_attention_backend( common_attn_metadata.max_seq_len = ( common_attn_metadata.max_seq_len + self.sink_len ) + max_num_blocks = cdiv(common_attn_metadata.max_seq_len, self.block_size) + num_reqs = common_attn_metadata.num_reqs + self.block_table_with_sink[ + :num_reqs, self.num_sink_blocks : self.num_sink_blocks + max_num_blocks + ] = common_attn_metadata.block_table_tensor[:, :max_num_blocks] + common_attn_metadata.block_table_tensor = self.block_table_with_sink[ + :num_reqs + ] return super().build(common_prefix_len, common_attn_metadata, fast_build) @@ -84,7 +108,8 @@ def create_static_sink_attention_backend( return attn_backend -class StaticSinkAttention(Attention): +@CustomOp.register("static_sink_attention") +class 
StaticSinkAttention(Attention, CustomOp): """ Attention with static sink tokens """ @@ -118,7 +143,8 @@ class StaticSinkAttention(Attention): underlying_attn_backend, sink_len=sink_len, ) - super().__init__( + Attention.__init__( + self=self, num_heads=num_heads, head_size=head_size, scale=scale, @@ -126,6 +152,7 @@ class StaticSinkAttention(Attention): attn_backend=attn_backend, **kwargs, ) + CustomOp.__init__(self) self.sink_len = sink_len self.block_size = block_size @@ -137,7 +164,7 @@ class StaticSinkAttention(Attention): self.sink_key = sink_key self.sink_value = sink_value - def forward( + def forward_native( self, query: torch.Tensor, key: torch.Tensor, @@ -154,6 +181,18 @@ class StaticSinkAttention(Attention): return super().forward(query, key, value, output_shape) + def forward_cuda( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output_shape: torch.Size | None = None, + ) -> torch.Tensor: + return self.forward_native(query, key, value, output_shape) + + def forward(self, *args, **kwargs): + return self._forward_method(*args, **kwargs) + def populate_sink_kv(self, self_kv_cache): sink_kv_slot_mapping = torch.arange( self.block_size, diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index e6c7150500cc3..14905b36754b4 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -801,147 +801,6 @@ class SinkFullAttentionManager(FullAttentionManager): num_sink_block = sink_len // self.block_size self.sink_blocks = self.block_pool.free_block_queue.popleft_n(num_sink_block) - def get_num_blocks_to_allocate( - self, - request_id: str, - num_tokens: int, - new_computed_blocks: Sequence[KVCacheBlock], - ) -> int: - """ - Get the number of blocks needed to be allocated for the request. - - Args: - request_id: The request ID. - num_tokens: The total number of tokens that need a slot (including - tokens that are already allocated). 
- new_computed_blocks: The new computed blocks just hitting the - prefix caching. - - Returns: - The number of blocks. - """ - - num_required_blocks = cdiv(num_tokens, self.block_size) - num_new_blocks = ( - num_required_blocks - - len(new_computed_blocks) - - len(self.req_to_blocks[request_id]) - ) - # Number of sink blocks is calculated into num_new_blocks - if len(self.req_to_blocks[request_id]) > 0: - num_new_blocks = num_new_blocks + len(self.sink_blocks) - # If a computed block of a request is an eviction candidate (in the - # free queue and ref_cnt == 0), it will be changed from a free block - # to a computed block when the request is allocated, so we also count - # it as needed to be allocated. - num_evictable_computed_blocks = sum( - blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks - ) - return num_new_blocks + num_evictable_computed_blocks - - def save_new_computed_blocks( - self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock] - ) -> None: - """ - Add the new computed blocks to the request. - - Args: - request_id: The request ID. - new_computed_blocks: The new computed blocks just hitting the - prefix cache. - """ - if request_id not in self.num_cached_block: - # A new request. - req_blocks = self.req_to_blocks[request_id] - assert len(req_blocks) == 0 - # Append both sink blocks and hitted prefix cache blocks - req_blocks.extend(self.sink_blocks) - req_blocks.extend(new_computed_blocks) - self.num_cached_block[request_id] = len(new_computed_blocks) - else: - # A running request. Should not have new computed blocks. - assert len(new_computed_blocks) == 0 - - def allocate_new_blocks( - self, request_id: str, num_tokens: int - ) -> list[KVCacheBlock]: - """ - Allocate new blocks for the request to give it at least `num_tokens` - token slots. - - Args: - request_id: The request ID. - num_tokens: The total number of tokens that need a slot (including - tokens that are already allocated). 
- - Returns: - The new allocated blocks. - """ - req_blocks = self.req_to_blocks[request_id] - num_required_blocks = cdiv(num_tokens, self.block_size) - num_new_blocks = num_required_blocks - len(req_blocks) - # For existing requests, number of sink blocks is calculated into - # num_new_blocks - if len(req_blocks) > 0: - num_new_blocks = num_new_blocks + len(self.sink_blocks) - if num_new_blocks <= 0: - return [] - else: - new_blocks = self.block_pool.get_new_blocks(num_new_blocks) - # For new requests, allocate sink blocks - if len(req_blocks) == 0: - req_blocks.extend(self.sink_blocks + new_blocks) - else: - req_blocks.extend(new_blocks) - return new_blocks - - def cache_blocks(self, request: Request, num_tokens: int) -> None: - """ - Cache the blocks for the request. - - Args: - request: The request. - num_tokens: The total number of tokens that need to be cached - (including tokens that are already cached). - """ - num_cached_blocks = self.num_cached_block.get(request.request_id, 0) - num_full_blocks = num_tokens // self.block_size - - if num_cached_blocks >= num_full_blocks: - return - - self.block_pool.cache_full_blocks( - request=request, - # Do not cache sink blocks - blocks=self.req_to_blocks[request.request_id][len(self.sink_blocks) :], - num_cached_blocks=num_cached_blocks, - num_full_blocks=num_full_blocks, - block_size=self.block_size, - kv_cache_group_id=self.kv_cache_group_id, - ) - - self.num_cached_block[request.request_id] = num_full_blocks - - def free(self, request_id: str) -> None: - """ - Free the blocks for the request. - - Args: - request_id: The request ID. - """ - # Default to [] in case a request is freed (aborted) before alloc. - req_blocks = self.req_to_blocks.pop(request_id, []) - # Do not free sink blocks - if len(req_blocks) > 0: - req_blocks = req_blocks[len(self.sink_blocks) :] - - # Free blocks in reverse order so that the tail blocks are - # freed first. 
- ordered_blocks = reversed(req_blocks) - - self.block_pool.free_blocks(ordered_blocks) - self.num_cached_block.pop(request_id, None) - spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 5703f80db0754..37ec0fb97e06b 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -23,7 +23,6 @@ class BlockTable: device: torch.device, kernel_block_size: int, cp_kv_cache_interleave_size: int, - sink_len: int = 0, ): """ Args: @@ -64,8 +63,6 @@ class BlockTable: self.use_hybrid_blocks = True self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block - self.sink_block_len = sink_len // self.block_size - self.max_num_blocks_per_req = self.max_num_blocks_per_req + self.sink_block_len self.block_table = self._make_buffer( self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32 @@ -154,7 +151,7 @@ class BlockTable: block_table_indices = ( req_indices * self.max_num_blocks_per_req + positions // virtual_block_size - ) + self.sink_block_len + ) block_numbers = self.block_table.np.ravel()[block_table_indices] # Use virtual_block_size for mask calculation, which marks local @@ -180,10 +177,9 @@ class BlockTable: mask, slot_mapping, -1 ) else: - # When self.sink_block_len > 0, we need to shift block table indices block_table_indices = ( req_indices * self.max_num_blocks_per_req + positions // self.block_size - ) + self.sink_block_len + ) block_numbers = self.block_table.np.ravel()[block_table_indices] block_offsets = positions % self.block_size @@ -267,7 +263,6 @@ class MultiGroupBlockTable: kernel_block_sizes: list[int], num_speculative_tokens: int = 0, cp_kv_cache_interleave_size: int = 1, - sink_len: int = 0, ) -> None: # Note(hc): each dcp rank only store # (max_model_len//dcp_world_size) tokens in kvcache, @@ -305,7 +300,6 @@ class MultiGroupBlockTable: device, kernel_block_size, 
cp_kv_cache_interleave_size, - sink_len=sink_len, ) for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes) ] diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py index 604edf3d7ae29..6386f1a08b446 100644 --- a/vllm/v1/worker/gpu/attn_utils.py +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -101,28 +101,16 @@ def _reshape_kv_cache( num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes attn_backend = attn_backends[layer_name] - if ( - hasattr(kv_cache_spec, "head_size_v") - and kv_cache_spec.head_size_v != kv_cache_spec.head_size - ): - kwargs = {"head_size_v": kv_cache_spec.head_size_v} - stride_kwargs = {"diff_kv": True} - else: - kwargs = {} - stride_kwargs = {} kv_cache_shape = attn_backend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, - **kwargs, ) # FIXME(woosuk): Add kv_cache_stride_order to all attention backends. try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( - **stride_kwargs - ) + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() assert len(kv_cache_stride_order) == len(kv_cache_shape) except (AttributeError, NotImplementedError): kv_cache_stride_order = tuple(range(len(kv_cache_shape))) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index c567fc7219c3b..ead7a3619dea5 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -96,7 +96,6 @@ class InputBatch: is_pooling_model: bool = False, num_speculative_tokens: int = 0, cp_kv_cache_interleave_size: int = 1, - sink_len: int = 0, ): self.is_pooling_model = is_pooling_model self.is_spec_decode = is_spec_decode @@ -151,7 +150,6 @@ class InputBatch: kernel_block_sizes=kernel_block_sizes, num_speculative_tokens=num_speculative_tokens, cp_kv_cache_interleave_size=cp_kv_cache_interleave_size, - sink_len=sink_len, ) # Sampling-related. 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e02e7ee2d4c26..b1e4d04717768 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -332,10 +332,6 @@ class GPUModelRunner( self.num_query_heads = model_config.get_num_attention_heads(parallel_config) self.inputs_embeds_size = model_config.get_inputs_embeds_size() self.attention_chunk_size = model_config.attention_chunk_size - self.sink_len = getattr( - self.vllm_config.model_config.hf_config, "param_sink_number", 0 - ) - assert self.sink_len % self.cache_config.block_size == 0 # Only relevant for models using ALiBi (e.g, MPT) self.use_alibi = model_config.uses_alibi @@ -459,7 +455,6 @@ class GPUModelRunner( logitsprocs_need_output_token_ids=bool(custom_logitsprocs), is_pooling_model=self.is_pooling_model, cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size, - sink_len=self.sink_len, ) self.use_async_scheduling = self.scheduler_config.async_scheduling @@ -5079,7 +5074,7 @@ class GPUModelRunner( logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids, is_pooling_model=self.is_pooling_model, num_speculative_tokens=self.num_spec_tokens, - sink_len=self.sink_len, + # sink_len=self.sink_len, ) def _allocate_kv_cache_tensors( @@ -5206,28 +5201,16 @@ class GPUModelRunner( ) kernel_num_blocks = num_blocks * num_blocks_per_kv_block - if ( - hasattr(kv_cache_spec, "head_size_v") - and kv_cache_spec.head_size_v != kv_cache_spec.head_size - ): - kwargs = {"head_size_v": kv_cache_spec.head_size_v} - stride_kwargs = {"diff_kv": True} - else: - kwargs = {} - stride_kwargs = {} kv_cache_shape = attn_backend.get_kv_cache_shape( kernel_num_blocks, kernel_block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, cache_dtype_str=self.cache_config.cache_dtype, - **kwargs, ) dtype = kv_cache_spec.dtype try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( - **stride_kwargs - ) + 
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() assert len(kv_cache_stride_order) == len(kv_cache_shape) except (AttributeError, NotImplementedError): kv_cache_stride_order = tuple(range(len(kv_cache_shape))) diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index a45b83dc788f1..f266a1386e10d 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -190,28 +190,17 @@ class KVConnectorModelRunnerMixin: return False attn_backend = attn_group.backend - if ( - hasattr(kv_cache_spec, "head_size_v") - and kv_cache_spec.head_size_v != kv_cache_spec.head_size - ): - kwargs = {"head_size_v": kv_cache_spec.head_size_v} - stride_kwargs = {"diff_kv": True} - else: - kwargs = {} - stride_kwargs = {} kv_cache_shape = attn_backend.get_kv_cache_shape( 1234, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, cache_dtype_str=cache_dtype, - **kwargs, ) try: kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( include_num_layers_dimension=True, - **stride_kwargs, ) except (AttributeError, NotImplementedError): return False @@ -268,22 +257,12 @@ class KVConnectorModelRunnerMixin: kernel_num_blocks = num_blocks * num_blocks_per_kv_block attn_backend = attn_group.backend - if ( - hasattr(kv_cache_spec, "head_size_v") - and kv_cache_spec.head_size_v != kv_cache_spec.head_size - ): - kwargs = {"head_size_v": kv_cache_spec.head_size_v} - stride_kwargs = {"diff_kv": True} - else: - kwargs = {} - stride_kwargs = {} kv_cache_shape = attn_backend.get_kv_cache_shape( kernel_num_blocks, kernel_block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, cache_dtype_str=cache_dtype, - **kwargs, ) # prepend a num_layers dimension into the shape @@ -292,7 +271,6 @@ class KVConnectorModelRunnerMixin: try: kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( include_num_layers_dimension=True, - 
**stride_kwargs, ) assert len(kv_cache_stride_order) == len(kv_cache_shape) except (AttributeError, NotImplementedError): From 049d2aad0c765be8c0a1680644187c5f6e4a423b Mon Sep 17 00:00:00 2001 From: yuantao <2422264527@qq.com> Date: Mon, 22 Dec 2025 22:51:34 +0800 Subject: [PATCH 15/16] Fix minor difference Signed-off-by: yuantao <2422264527@qq.com> --- vllm/v1/worker/gpu_model_runner.py | 1 - vllm/v1/worker/kv_connector_model_runner_mixin.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b1e4d04717768..978224faae65e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5074,7 +5074,6 @@ class GPUModelRunner( logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids, is_pooling_model=self.is_pooling_model, num_speculative_tokens=self.num_spec_tokens, - # sink_len=self.sink_len, ) def _allocate_kv_cache_tensors( diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index f266a1386e10d..2bcc87b63bcdf 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -200,7 +200,7 @@ class KVConnectorModelRunnerMixin: try: kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( - include_num_layers_dimension=True, + include_num_layers_dimension=True ) except (AttributeError, NotImplementedError): return False @@ -270,7 +270,7 @@ class KVConnectorModelRunnerMixin: try: kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( - include_num_layers_dimension=True, + include_num_layers_dimension=True ) assert len(kv_cache_stride_order) == len(kv_cache_shape) except (AttributeError, NotImplementedError): From 1aff6fff5664799fb5fcf4e56b97ed032a056d33 Mon Sep 17 00:00:00 2001 From: yt0428 <51468697+yt0428@users.noreply.github.com> Date: Tue, 23 Dec 2025 20:38:02 +0800 Subject: 
[PATCH 16/16] Fix pre-commit Signed-off-by: yt0428 <51468697+yt0428@users.noreply.github.com> --- vllm/v1/kv_cache_interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 359b5c520d44d..7a7bb9036211b 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -88,7 +88,7 @@ class FullAttentionSpec(AttentionSpec): attention in model runner. In this case, we use FullAttentionSpec and record the sliding window size. """ - + head_size_v: int | None = None sliding_window: int | None = None