From b565203d926267a48266fd1a60bac80553bb80ff Mon Sep 17 00:00:00 2001 From: yuantao <2422264527@qq.com> Date: Sat, 13 Dec 2025 15:47:33 +0800 Subject: [PATCH] Refactor code. Extend FLASH_ATTN to support different KV size and create a new StaticSinkAttention for sink token logic Signed-off-by: yuantao <2422264527@qq.com> --- vllm/attention/backends/registry.py | 3 - vllm/attention/layer.py | 11 +- .../attention/layers/static_sink_attention.py | 225 ++++ .../ops/triton_reshape_and_cache_flash.py | 90 +- vllm/model_executor/models/openpangu.py | 169 +-- vllm/v1/attention/backends/flash_attn.py | 81 +- .../attention/backends/flash_diffkv_attn.py | 1031 ----------------- vllm/v1/core/sched/scheduler.py | 5 - vllm/v1/core/single_type_kv_cache_manager.py | 28 +- vllm/v1/kv_cache_interface.py | 137 +-- vllm/v1/worker/block_table.py | 3 +- vllm/v1/worker/gpu/attn_utils.py | 11 +- vllm/v1/worker/gpu_input_batch.py | 3 +- vllm/v1/worker/gpu_model_runner.py | 48 +- .../worker/kv_connector_model_runner_mixin.py | 20 +- 15 files changed, 503 insertions(+), 1362 deletions(-) create mode 100644 vllm/attention/layers/static_sink_attention.py delete mode 100644 vllm/v1/attention/backends/flash_diffkv_attn.py diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index 630858fc2193a..eaa0fa1d5db39 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -42,9 +42,6 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta): """ FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" - FLASH_DIFFKV_ATTN = ( - "vllm.v1.attention.backends.flash_diffkv_attn.FlashDiffkvAttentionBackend" - ) TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 
5c43dae35b812..6c4cc0085432c 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -191,6 +191,7 @@ class Attention(nn.Module, AttentionLayerBase): attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: str | None = None, attn_backend: type[AttentionBackend] | None = None, + head_size_v: int | None = None, **extra_impl_args, ) -> None: """ @@ -232,6 +233,7 @@ class Attention(nn.Module, AttentionLayerBase): self.num_heads = num_heads self.head_size = head_size + self.head_size_v = self.head_size if head_size_v is None else head_size_v self.num_kv_heads = num_kv_heads self.sliding_window = sliding_window self.has_sink = extra_impl_args.get("sinks") is not None @@ -370,6 +372,10 @@ class Attention(nn.Module, AttentionLayerBase): query, _ = self.query_quant(query, self._q_scale) if self.use_output: + if output_shape is None: + output_shape = torch.Size( + (*query.shape[:-1], self.num_heads * self.head_size_v) + ) output_shape = output_shape if output_shape is not None else query.shape output = torch.empty(output_shape, dtype=output_dtype, device=query.device) hidden_size = output_shape[-1] @@ -377,11 +383,11 @@ class Attention(nn.Module, AttentionLayerBase): # NOTE(woosuk): We do this outside the custom op to minimize the # CPU overheads from the non-CUDA-graph regions. 
query = query.view(-1, self.num_heads, self.head_size) - output = output.view(-1, self.num_heads, self.head_size) + output = output.view(-1, self.num_heads, self.head_size_v) if key is not None: key = key.view(-1, self.num_kv_heads, self.head_size) if value is not None: - value = value.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size_v) if self.use_direct_call: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -456,6 +462,7 @@ class Attention(nn.Module, AttentionLayerBase): block_size=block_size, num_kv_heads=self.num_kv_heads, head_size=self.head_size, + head_size_v=self.head_size_v, dtype=self.kv_cache_torch_dtype, ) diff --git a/vllm/attention/layers/static_sink_attention.py b/vllm/attention/layers/static_sink_attention.py new file mode 100644 index 0000000000000..7687651ee682b --- /dev/null +++ b/vllm/attention/layers/static_sink_attention.py @@ -0,0 +1,225 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools + +import torch + +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionMetadata, + AttentionType, +) +from vllm.attention.layer import Attention +from vllm.attention.ops.triton_reshape_and_cache_flash import ( + triton_reshape_and_cache_flash_diffkv, +) +from vllm.attention.selector import get_attn_backend +from vllm.config import CacheConfig, VllmConfig +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.logger import init_logger +from vllm.utils.torch_utils import direct_register_custom_op +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, + subclass_attention_backend, +) +from vllm.v1.kv_cache_interface import ( + AttentionSpec, + KVCacheSpec, + SinkFullAttentionSpec, +) + +logger = init_logger(__name__) + + +@functools.lru_cache +def create_static_sink_attention_backend( + 
underlying_attn_backend: type[AttentionBackend], + sink_len: int = 0, +) -> type[AttentionBackend]: + prefix = "StaticSink_" + underlying_builder = underlying_attn_backend.get_builder_cls() + + class StaticSinkAttentionBuilder(underlying_builder): # type: ignore + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.sink_len = sink_len + self.num_sink_blocks = self.sink_len // vllm_config.cache_config.block_size + self.sink_block_table = torch.arange( + 1, + self.num_sink_blocks + 1, + device=device, + dtype=torch.int32, + ) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> AttentionMetadata: + common_attn_metadata.seq_lens[:] = ( + common_attn_metadata.seq_lens + self.sink_len + ) + common_attn_metadata.seq_lens_cpu = ( + common_attn_metadata.seq_lens_cpu + self.sink_len + ) + common_attn_metadata.max_seq_len = ( + common_attn_metadata.max_seq_len + self.sink_len + ) + + blk_table_tensor = common_attn_metadata.block_table_tensor + sink_block_table = self.sink_block_table[None, :].expand( + blk_table_tensor.shape[0], -1 + ) + blk_table_tensor_clone = blk_table_tensor.clone() + blk_table_tensor[:, self.num_sink_blocks :] = blk_table_tensor_clone[ + :, : -self.num_sink_blocks + ] + blk_table_tensor[:, : self.num_sink_blocks] = sink_block_table + + return super().build(common_prefix_len, common_attn_metadata, fast_build) + + attn_backend = subclass_attention_backend( + name_prefix=prefix, + attention_backend_cls=underlying_attn_backend, + builder_cls=StaticSinkAttentionBuilder, + ) + + return attn_backend + + +class StaticSinkAttention(Attention): + """ + Attention with static sink tokens + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + sink_len: int, + attn_backend: type[AttentionBackend] | 
None = None, + cache_config: CacheConfig | None = None, + **kwargs, + ): + dtype = torch.get_default_dtype() + + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + else: + kv_cache_dtype = "auto" + block_size = 16 + + if attn_backend is not None: + underlying_attn_backend = attn_backend + else: + underlying_attn_backend = get_attn_backend( + head_size, dtype, kv_cache_dtype, block_size + ) + attn_backend = create_static_sink_attention_backend( + underlying_attn_backend, + sink_len=sink_len, + ) + super().__init__( + num_heads=num_heads, + head_size=head_size, + scale=scale, + cache_config=cache_config, + attn_backend=attn_backend, + **kwargs, + ) + + self.sink_len = sink_len + self.block_size = block_size + self.sink_populated = False + self.sink_key = None + self.sink_value = None + + def update_sink_kv(self, sink_key, sink_value) -> None: + self.sink_key = sink_key + self.sink_value = sink_value + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output_shape: torch.Size | None = None, + ) -> torch.Tensor: + assert self.sink_key is not None and self.sink_value is not None, ( + "sink_key and sink_value have not been prepared" + ) + if not self.sink_populated: + forward_context: ForwardContext = get_forward_context() + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name) + + return super().forward(query, key, value, output_shape) + + def populate_sink_kv(self, self_kv_cache): + sink_kv_slot_mapping = torch.arange( + self.block_size, + self.sink_len + self.block_size, + device=torch.cuda.current_device(), + dtype=torch.long, + ) + triton_reshape_and_cache_flash_diffkv( + self.sink_key, + self.sink_value, + self_kv_cache, + sink_kv_slot_mapping, + self.kv_cache_dtype, + self._k_scale, + self._v_scale, + ) + # We only populate the sink_key and sink_value once + self.sink_populated = True 
+ + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: + # Block size may get updated after model loading, refresh it + block_size = vllm_config.cache_config.block_size + # Should not be called for enc-dec or encoder-only attention. + assert self.attn_type == AttentionType.DECODER + + return SinkFullAttentionSpec( + block_size=block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + head_size_v=self.head_size_v, + sink_len=self.sink_len, + dtype=self.kv_cache_torch_dtype, + ) + + +def maybe_populate_sink( + self_kv_cache: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + if self.sink_populated or self_kv_cache.numel() == 0: + return + self.populate_sink_kv(self_kv_cache) + + +def maybe_populate_sink_fake( + self_kv_cache: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="maybe_populate_sink", + op_func=maybe_populate_sink, + mutates_args=["self_kv_cache"], + fake_impl=maybe_populate_sink_fake, +) diff --git a/vllm/attention/ops/triton_reshape_and_cache_flash.py b/vllm/attention/ops/triton_reshape_and_cache_flash.py index d79e209e303b0..c119033896ec6 100644 --- a/vllm/attention/ops/triton_reshape_and_cache_flash.py +++ b/vllm/attention/ops/triton_reshape_and_cache_flash.py @@ -186,17 +186,20 @@ def triton_reshape_and_cache_flash( @triton.jit def reshape_and_cache_kernel_flash_diffkv( - kv_ptr, # [num_tokens, num_heads, head_size + head_size_v] + key_ptr, # [num_tokens, num_heads, head_size] + value_ptr, # [num_tokens, num_heads, head_size_v] kv_cache_ptr, # [num_blocks, block_size, num_heads, head_size + head_size_v] slot_mapping_ptr, # [num_tokens] k_scale, # float32 v_scale, # float32 # strides - kv_stride: tl.int64, + key_stride: tl.int64, + value_stride: tl.int64, block_stride: tl.int64, page_stride: tl.int64, num_heads: tl.constexpr, - head_size_kv: tl.constexpr, + 
head_size_k: tl.constexpr, + head_size_v: tl.constexpr, block_size: tl.constexpr, # FP8 flags FP8_KV_CACHE: tl.constexpr, @@ -211,24 +214,51 @@ def reshape_and_cache_kernel_flash_diffkv( tile_i = tl.program_id(axis=1) tile_offs = tl.arange(0, TILE_SIZE) - tile_pos = tile_i * TILE_SIZE + tile_offs block_idx = slot_idx // block_size block_offset = slot_idx % block_size - src_kv_idx = token_idx * kv_stride + src_key_idx = token_idx * key_stride + tile_i * head_size_k + src_value_idx = token_idx * value_stride + tile_i * head_size_v - tgt_idx = block_idx * block_stride + block_offset * page_stride - - # [TILE_SIZE] - kv_tile = tl.load( - kv_ptr + src_kv_idx + tile_pos, mask=tile_pos < (num_heads * head_size_kv) + tgt_idx = ( + block_idx * block_stride + + block_offset * page_stride + + tile_i * (head_size_k + head_size_v) ) + # [TILE_SIZE] + key_load = tl.load(key_ptr + src_key_idx + tile_offs, mask=tile_offs < head_size_k) + if FP8_KV_CACHE: + # tl.store will do the correct implicit cast to fp8, + # based on the key_cache_ptr.dtype.element_ty + key_tile = key_load if key_load.dtype.is_fp8() else key_load / tl.load(k_scale) + else: + key_tile = key_load + + # [TILE_SIZE] + value_load = tl.load( + value_ptr + src_value_idx + tile_offs, mask=tile_offs < head_size_v + ) + if FP8_KV_CACHE: + if value_load.dtype.is_fp8(): + value_tile = value_load + else: + # tl.store will do the correct implicit cast to fp8, + # based on the value_cache_ptr.dtype.element_ty + value_tile = value_load / tl.load(v_scale) + else: + value_tile = value_load + tl.store( - kv_cache_ptr + tgt_idx + tile_pos, - kv_tile, - mask=tile_pos < (num_heads * head_size_kv), + kv_cache_ptr + tgt_idx + tile_offs, + key_tile, + mask=tile_offs < head_size_k, + ) + tl.store( + kv_cache_ptr + tgt_idx + head_size_k + tile_offs, + value_tile, + mask=tile_offs < head_size_v, ) return @@ -243,13 +273,13 @@ def triton_reshape_and_cache_flash_diffkv( k_scale: torch.Tensor, # float32 v_scale: torch.Tensor, # float32 ): - 
kv = torch.cat([key, value], dim=-1).contiguous() - num_heads = kv.shape[1] - head_size_kv = kv.shape[2] + num_heads = key.shape[1] + head_size_k = key.shape[2] + head_size_v = value.shape[2] block_size = kv_cache.shape[1] - n = num_heads * head_size_kv - kv_stride = kv.stride()[0] + k_stride = key.stride()[0] + v_stride = value.stride()[0] block_stride = kv_cache.stride()[0] page_stride = kv_cache.stride()[1] @@ -272,13 +302,20 @@ def triton_reshape_and_cache_flash_diffkv( ) FP8_KV_CACHE = kv_cache_dtype.startswith("fp8") - assert not FP8_KV_CACHE, ( + assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [ + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.uint8, + torch.float8_e4m3fnuz, + ], ( "unsupported dtype of KV cache tensor, got " - "{kv_cache_torch_dtype}. Supported kv cache dtypes: bfloat16, float16, float32." + "{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, " + "fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz." ) # heuristics instead of autotuning - TILE_SIZE = min(2048, triton.next_power_of_2(n)) + TILE_SIZE = max(head_size_k, head_size_v) + TILE_SIZE = triton.next_power_of_2(TILE_SIZE) if current_platform.is_rocm() or current_platform.is_xpu(): num_stages = 4 num_warps = 8 @@ -292,21 +329,24 @@ def triton_reshape_and_cache_flash_diffkv( # using cudagraphs grid = lambda meta: ( slot_mapping.shape[0], - triton.cdiv(n, meta["TILE_SIZE"]), + num_heads, ) reshape_and_cache_kernel_flash_diffkv[grid]( - kv_ptr=kv, + key_ptr=key, + value_ptr=value, kv_cache_ptr=kv_cache, slot_mapping_ptr=slot_mapping, k_scale=k_scale, v_scale=v_scale, # strides - kv_stride=kv_stride, + key_stride=k_stride, + value_stride=v_stride, block_stride=block_stride, page_stride=page_stride, num_heads=num_heads, - head_size_kv=head_size_kv, + head_size_k=head_size_k, + head_size_v=head_size_v, block_size=block_size, # FP8 flags FP8_KV_CACHE=FP8_KV_CACHE, diff --git a/vllm/model_executor/models/openpangu.py b/vllm/model_executor/models/openpangu.py index 
619981eeccd7c..8e4bb62e137a8 100644 --- a/vllm/model_executor/models/openpangu.py +++ b/vllm/model_executor/models/openpangu.py @@ -29,8 +29,8 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention.backends.abstract import AttentionBackend, AttentionType -from vllm.attention.layer import Attention +from vllm.attention.layer import Attention, AttentionType +from vllm.attention.layers.static_sink_attention import StaticSinkAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ParallelConfig, VllmConfig from vllm.distributed import ( @@ -41,7 +41,6 @@ from vllm.distributed import ( get_tp_group, tensor_model_parallel_all_gather, ) -from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -82,11 +81,7 @@ from vllm.model_executor.models.utils import ( from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import set_default_rope_theta -from vllm.v1.attention.backends.flash_diffkv_attn import FlashDiffkvAttentionBackend -from vllm.v1.kv_cache_interface import ( - FullDiffkvAttentionSpec, - KVCacheSpec, -) +from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend def check_ffn_act_fn(act_fn: str): @@ -96,133 +91,6 @@ def check_ffn_act_fn(act_fn: str): ) -class DiffkvAttention(Attention): - def __init__( - self, - num_heads: int, - head_size: int, - head_size_v: int, - scale: float, - num_kv_heads: int | None = None, - alibi_slopes: list[float] | None = None, - cache_config: CacheConfig | None = None, - quant_config: QuantizationConfig | None = None, - logits_soft_cap: float | None = None, - per_layer_sliding_window: int | None = None, - prefix: str = "", - attn_type: str = 
AttentionType.DECODER, - kv_sharing_target_layer_name: str | None = None, - attn_backend: type[AttentionBackend] | None = None, - **extra_impl_args, - ) -> None: - super().__init__( - num_heads, - head_size, - scale, - num_kv_heads, - alibi_slopes, - cache_config, - quant_config, - logits_soft_cap, - per_layer_sliding_window, - prefix, - attn_type, - kv_sharing_target_layer_name, - attn_backend, - **extra_impl_args, - ) - self.head_size_v = head_size_v - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - output_shape: torch.Size | None = None, - ) -> torch.Tensor: - """ - The KV cache is stored inside this class and is accessed via - `self.kv_cache`. - - Attention metadata (`attn_metadata`) is set using a context manager in - the model runner's `execute_model` method. It is accessed via forward - context using - `vllm.forward_context.get_forward_context().attn_metadata`. - """ - if self.calculate_kv_scales: - torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name) - output_dtype = query.dtype - if self.query_quant is not None: - # quantizing with a simple torch operation enables - # torch.compile to fuse this into previous ops - # which reduces overheads during decoding. - # Otherwise queries are quantized using custom ops - # which causes decoding overheads - assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"} - - # check if query quantization is supported - if self.impl.supports_quant_query_input(): - query, _ = self.query_quant(query, self._q_scale) - - if self.use_output: - output_shape = output_shape if output_shape is not None else query.shape - output = torch.empty(output_shape, dtype=output_dtype, device=query.device) - hidden_size = output_shape[-1] - # Reshape the query, key, and value tensors. - # NOTE(woosuk): We do this outside the custom op to minimize the - # CPU overheads from the non-CUDA-graph regions. 
- query = query.view(-1, self.num_heads, self.head_size) - output = output.view(-1, self.num_heads, self.head_size_v) - if key is not None: - key = key.view(-1, self.num_kv_heads, self.head_size) - if value is not None: - value = value.view(-1, self.num_kv_heads, self.head_size_v) - if self.use_direct_call: - forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[self.layer_name] - self_kv_cache = self.kv_cache[forward_context.virtual_engine] - self.impl.forward( - self, - query, - key, - value, - self_kv_cache, - attn_metadata, - output=output, - ) - else: - torch.ops.vllm.unified_attention_with_output( - query, - key, - value, - output, - self.layer_name, - ) - return output.view(-1, hidden_size) - else: - raise ValueError( - "Unsupport Error, currently only flash_diffkv_attn " - "backend with output buffer is supported" - ) - - def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: - # Block size may get updated after model loading, refresh it - block_size = vllm_config.cache_config.block_size - # Should not be called for enc-dec or encoder-only attention. - assert self.attn_type == AttentionType.DECODER - # Only support for full attention now. 
- assert self.sliding_window is None - return FullDiffkvAttentionSpec( - block_size=block_size, - num_kv_heads=self.num_kv_heads, - head_size=self.head_size, - head_size_v=self.head_size_v, - dtype=self.kv_cache_torch_dtype, - ) - - class OpenPanguMLP(nn.Module): def __init__( self, @@ -673,7 +541,7 @@ class OpenPanguEmbeddedAttention(nn.Module): ) -class OpenPanguDiffkvAttention(nn.Module): +class OpenPanguSinkAttention(nn.Module): def __init__( self, config: PretrainedConfig, @@ -777,19 +645,19 @@ class OpenPanguDiffkvAttention(nn.Module): else: sliding_window = None - FlashDiffkvAttentionBackend.set_head_size_v(self.v_channels) - self.attn = DiffkvAttention( + self.attn = StaticSinkAttention( self.num_heads, self.head_dim, - self.v_channels, self.scaling, + sink_len=self.param_sink_number, num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, per_layer_sliding_window=sliding_window, attn_type=attn_type, prefix=f"{prefix}.attn", - attn_backend=FlashDiffkvAttentionBackend, + attn_backend=FlashAttentionBackend, + head_size_v=self.v_channels, ) if self.param_sink_number > 0: @@ -919,19 +787,13 @@ class OpenPanguDiffkvAttention(nn.Module): is_neox_style=is_neox_style, ) - def get_sink_kv(self) -> dict[str, torch.Tensor]: - if self.param_sink_number == 0: - raise ValueError("No sink_key and sink_value when param_sink_number == 0") - + def post_weight_load(self) -> None: if hasattr(self, "k_layernorm") and self.k_layernorm is not None: param_sink_key = self.k_layernorm(self.param_sink_key) else: param_sink_key = self.param_sink_key - return { - "sink_key": param_sink_key, - "sink_value": self.param_sink_value, - } + self.attn.update_sink_kv(param_sink_key, self.param_sink_value) class OpenPanguDecoderLayer(nn.Module): @@ -1001,7 +863,7 @@ class OpenPanguDecoderLayer(nn.Module): "rope_type": "default", "rope_theta": config.rope_theta, } - self.self_attn = OpenPanguDiffkvAttention( + self.self_attn = OpenPanguSinkAttention( config=config, 
hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -1359,8 +1221,17 @@ class OpenPanguModel(nn.Module): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) + + self.post_weight_load() return loaded_params + def post_weight_load(self) -> None: + for name, module in self.named_modules(): + if module is self: + continue + if hasattr(module, "post_weight_load"): + module.post_weight_load() + class OpenPanguModelBase(nn.Module, SupportsPP, SupportsLoRA): packed_modules_mapping = { diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index f5ad98cf2125c..8b030a04b438d 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -18,6 +18,9 @@ from vllm.attention.backends.abstract import ( from vllm.attention.layer import Attention from vllm.attention.ops.common import cp_lse_ag_out_rs from vllm.attention.ops.merge_attn_states import merge_attn_states +from vllm.attention.ops.triton_reshape_and_cache_flash import ( + triton_reshape_and_cache_flash_diffkv, +) from vllm.attention.utils.fa_utils import ( flash_attn_supports_fp8, get_flash_attn_version, @@ -105,28 +108,48 @@ class FlashAttentionBackend(AttentionBackend): num_kv_heads: int, head_size: int, cache_dtype_str: str = "auto", + head_size_v: int | None = None, ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") - return (2, num_blocks, block_size, num_kv_heads, head_size) + if head_size_v is None or head_size == head_size_v: + return (2, num_blocks, block_size, num_kv_heads, head_size) + else: + return ( + num_blocks, + block_size, + num_kv_heads, + head_size + head_size_v, + ) @staticmethod def get_kv_cache_stride_order( include_num_layers_dimension: bool = False, + diff_kv: bool = False, ) -> tuple[int, ...]: # `stride_order` indicates the permutation that gets # us from 
`get_kv_cache_shape` to the actual memory layout we want. cache_layout = get_kv_cache_layout() if cache_layout == "NHD" and include_num_layers_dimension: - # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size) - return (2, 0, 1, 3, 4, 5) + if not diff_kv: + # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size) + return (2, 0, 1, 3, 4, 5) + else: + # (num_blocks, num_layers, block_size, + # num_kv_heads, head_size + head_size_v) + return (0, 1, 2, 3, 4) elif cache_layout == "NHD": - stride_order = (0, 1, 2, 3, 4) + stride_order = (0, 1, 2, 3, 4) if not diff_kv else (0, 1, 2, 3) elif cache_layout == "HND" and include_num_layers_dimension: - # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size) - return (2, 4, 0, 1, 3, 5) + if not diff_kv: + # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size) + return (2, 4, 0, 1, 3, 5) + else: + # (num_blocks, num_kv_heads, num_layers, + # block_size, head_size + head_size_v) + return (2, 3, 0, 1, 4) elif cache_layout == "HND": - stride_order = (0, 1, 3, 2, 4) + stride_order = (0, 1, 3, 2, 4) if not diff_kv else (0, 2, 1, 3) else: raise ValueError(f"Unknown cache layout format {cache_layout}.") return stride_order @@ -576,11 +599,14 @@ class FlashAttentionImpl(AttentionImpl): query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] + or [num_tokens, num_kv_heads, head_size_v] kv_cache: shape = [2, num_blocks, block_size, num_kv_heads, head_size] + or [num_blocks, block_size, num_kv_heads, head_size + head_size_v] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] + or [num_tokens, num_heads * head_size_v] NOTE: FP8 quantization, flash-attn expect the size of {q,k,v}_descale to be (num_sequences, num_kv_heads). 
We use torch's .expand() to avoid duplicating values @@ -623,7 +649,13 @@ class FlashAttentionImpl(AttentionImpl): ) # For decoder and cross-attention, use KV cache as before - key_cache, value_cache = kv_cache.unbind(0) + if self.head_size == kv_cache.shape[-1]: + # Same head_size for K and V + key_cache, value_cache = kv_cache.unbind(0) + else: + # Different head_size for K and V + key_cache = kv_cache[..., : self.head_size] + value_cache = kv_cache[..., self.head_size :] # key and value may be None in the case of cross attention. They are # calculated once based on the output from the encoder and then cached @@ -640,16 +672,29 @@ class FlashAttentionImpl(AttentionImpl): # and value[:num_actual_tokens] because the reshape_and_cache_flash # op uses the slot_mapping's shape to determine the number of # actual tokens. - reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) + if self.head_size == kv_cache.shape[-1]: + # kv_cache update for same head_size K and V + reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + else: + # kv_cache update for different head_size K and V + triton_reshape_and_cache_flash_diffkv( + key, + value, + kv_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) if self.kv_cache_dtype.startswith("fp8"): # queries are quantized in the attention layer diff --git a/vllm/v1/attention/backends/flash_diffkv_attn.py b/vllm/v1/attention/backends/flash_diffkv_attn.py deleted file mode 100644 index acd9cbcb4cabf..0000000000000 --- a/vllm/v1/attention/backends/flash_diffkv_attn.py +++ /dev/null @@ -1,1031 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Attention layer with FlashAttention.""" - -from typing import ClassVar 
- -import numpy as np -import torch - -from vllm import envs -from vllm.attention.backends.abstract import ( - AttentionBackend, - AttentionImpl, - AttentionType, - MultipleOf, - is_quantized_kv_cache, -) -from vllm.attention.layer import Attention -from vllm.attention.ops.common import cp_lse_ag_out_rs -from vllm.attention.ops.merge_attn_states import merge_attn_states -from vllm.attention.ops.triton_reshape_and_cache_flash import ( - triton_reshape_and_cache_flash_diffkv, -) -from vllm.attention.utils.fa_utils import ( - flash_attn_supports_fp8, - get_flash_attn_version, - is_flash_attn_varlen_func_available, -) - -if is_flash_attn_varlen_func_available(): - from vllm.attention.utils.fa_utils import ( - flash_attn_supports_sinks, - flash_attn_varlen_func, - get_scheduler_metadata, - ) -from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config -from vllm.config.cache import CacheDType -from vllm.distributed.parallel_state import get_dcp_group -from vllm.logger import init_logger -from vllm.model_executor.layers.batch_invariant import ( - vllm_is_batch_invariant, -) -from vllm.platforms.interface import DeviceCapability -from vllm.utils.math_utils import cdiv -from vllm.v1.attention.backends.utils import ( - AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, - get_dcp_local_seq_lens, - get_kv_cache_layout, -) -from vllm.v1.kv_cache_interface import AttentionSpec - -from .flash_attn import FlashAttentionMetadata - -logger = init_logger(__name__) - - -class FlashDiffkvAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True - supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] - # TODO: Remove hard code - head_size_v: int = 128 - - @staticmethod - def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: - vllm_config = get_current_vllm_config() - model_config = vllm_config.model_config - cache_config = vllm_config.cache_config - if ( - model_config - and 
model_config.is_hybrid - and ( - cache_config.mamba_ssm_cache_dtype == "float32" - or cache_config.mamba_cache_dtype == "float32" - ) - ): - # NOTE(tdoublep): while in principle, FA supports - # MultipleOf(16), these are the block sizes that do not - # suffer from the NaN propagation problem described here: - # https://github.com/Dao-AILab/flash-attention/issues/1974 - return [16, 32, 64] - return [MultipleOf(16)] - - @staticmethod - def get_name() -> str: - return "FLASH_DIFFKV_ATTN" - - @classmethod - def supports_attn_type(cls, attn_type: str) -> bool: - """FlashDiffkvAttention supports all attention types.""" - from vllm.attention import AttentionType - - return attn_type in ( - AttentionType.DECODER, - AttentionType.ENCODER, - AttentionType.ENCODER_ONLY, - AttentionType.ENCODER_DECODER, - ) - - @staticmethod - def get_impl_cls() -> type["FlashDiffkvAttentionImpl"]: - return FlashDiffkvAttentionImpl - - @staticmethod - def get_builder_cls() -> type["FlashDiffkvAttentionMetadataBuilder"]: - return FlashDiffkvAttentionMetadataBuilder - - @classmethod - def set_head_size_v(cls, head_size_v: int) -> None: - cls.head_size_v = head_size_v - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - if block_size % 16 != 0: - raise ValueError("Block size must be a multiple of 16.") - return ( - num_blocks, - block_size, - num_kv_heads, - head_size + FlashDiffkvAttentionBackend.head_size_v, - ) - - @staticmethod - def get_kv_cache_stride_order( - include_num_layers_dimension: bool = False, - ) -> tuple[int, ...]: - # `stride_order` indicates the permutation that gets - # us from `get_kv_cache_shape` to the actual memory layout we want. 
- cache_layout = get_kv_cache_layout() - if cache_layout == "NHD" and include_num_layers_dimension: - # (num_blocks, num_layers, block_size, num_kv_heads, head_size) - return (0, 1, 2, 3, 4) - elif cache_layout == "NHD": - stride_order = (0, 1, 2, 3) - elif cache_layout == "HND" and include_num_layers_dimension: - # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size) - return (2, 3, 0, 1, 4) - elif cache_layout == "HND": - stride_order = (0, 2, 1, 3) - else: - raise ValueError(f"Unknown cache layout format {cache_layout}.") - return stride_order - - @staticmethod - def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype: - if kv_cache_dtype in ("fp8", "fp8_e4m3"): - return torch.float8_e4m3fn - else: - raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") - - @classmethod - def supports_head_size(cls, head_size: int) -> bool: - return head_size % 8 == 0 and head_size <= 256 - - @classmethod - def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: - if kv_cache_dtype is None: - return True - if kv_cache_dtype.startswith("fp8"): - return flash_attn_supports_fp8() - return kv_cache_dtype in ["auto"] - - @classmethod - def supports_sink(cls) -> bool: - if not is_flash_attn_varlen_func_available(): - return False - return flash_attn_supports_sinks() - - @classmethod - def supports_compute_capability(cls, capability: DeviceCapability) -> bool: - return capability >= DeviceCapability(8, 0) - - @classmethod - def supports_combination( - cls, - head_size: int, - dtype: torch.dtype, - kv_cache_dtype: CacheDType | None, - block_size: int, - use_mla: bool, - has_sink: bool, - use_sparse: bool, - device_capability: DeviceCapability, - ) -> str | None: - if has_sink and device_capability < DeviceCapability(9, 0): - return "sink not supported on compute capability < 9.0" - return None - - -def _get_sliding_window_configs( - vllm_config: VllmConfig, -) -> set[tuple[int, int] | None]: - """Get the set of all sliding window 
configs used in the model.""" - sliding_window_configs: set[tuple[int, int] | None] = set() - layers = get_layers_from_vllm_config(vllm_config, Attention) - for layer in layers.values(): - assert isinstance(layer.impl, FlashDiffkvAttentionImpl) - sliding_window_configs.add(layer.impl.sliding_window) - return sliding_window_configs - - -class FlashDiffkvAttentionMetadataBuilder( - AttentionMetadataBuilder[FlashAttentionMetadata] -): - # FA3: - # Supports full cudagraphs for all cases. - # - # FA2: - # For FA2, a graph is captured with max_query_len=1, (which is what we - # capture by default for num_tokens <= max_num_seqs when there is no - # spec-decode) then these graphs will not work for mixed prefill-decode - # (unlike FA3). This is due to special max_query_len=1 packed-GQA handling - # in FA2. - # In summary if we are running with spec decodes the graphs would - # work for mixed prefill-decode and uniform-decode. But for non-spec decodes - # the graphs would not work for mixed prefill-decode; sorta the inverse - # of UNIFORM_SINGLE_TOKEN_DECODE. - # There's probably a better way to describe this using `AttentionCGSupport` - # but for now just set it to `UNIFORM_BATCH` to get use to drop down - # to FULL_AND_PIECEWISE. 
- # TODO(luka, lucas): audit FA2 as part of: - # https://github.com/vllm-project/vllm/issues/22945 - _cudagraph_support = ( - AttentionCGSupport.ALWAYS - if get_flash_attn_version() == 3 - else AttentionCGSupport.UNIFORM_BATCH - ) - - def __init__( - self, - kv_cache_spec: AttentionSpec, - layer_names: list[str], - vllm_config: VllmConfig, - device: torch.device, - ): - super().__init__(kv_cache_spec, layer_names, vllm_config, device) - self.model_config = vllm_config.model_config - self.parallel_config = vllm_config.parallel_config - self.cache_config = vllm_config.cache_config - self.compilation_config = vllm_config.compilation_config - - self.num_heads_q = self.model_config.get_num_attention_heads( - self.parallel_config - ) - self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config) - self.kv_cache_dtype = kv_cache_spec.dtype - self.headdim = self.model_config.get_head_size() - self.block_size = kv_cache_spec.block_size - - self.max_num_splits = 0 # No upper bound on the number of splits. - self.aot_schedule = get_flash_attn_version() == 3 - - try: - from vllm.distributed.parallel_state import get_dcp_group - - self.dcp_world_size = get_dcp_group().world_size - self.dcp_rank = get_dcp_group().rank_in_group - except AssertionError: - # DCP might not be initialized in testing - self.dcp_world_size = 1 - self.dcp_rank = 0 - - self.cp_kv_cache_interleave_size = ( - self.parallel_config.cp_kv_cache_interleave_size - ) - - self.use_full_cuda_graph = ( - self.compilation_config.cudagraph_mode.has_full_cudagraphs() - ) - self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size - - if self.use_full_cuda_graph and self.aot_schedule: - self.scheduler_metadata = torch.zeros( - vllm_config.scheduler_config.max_num_seqs + 1, - dtype=torch.int32, - device=self.device, - ) - # When using cuda graph, we need to set the upper bound of the - # number of splits so that large enough intermediate buffers are - # pre-allocated during capture. 
- self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH - - # Sliding window size to be used with the AOT scheduler will be - # populated on first build() call. - self.aot_sliding_window: tuple[int, int] | None = None - - def build( - self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False, - ) -> FlashAttentionMetadata: - """ - fast_build disables AOT scheduling, used when there will be few - iterations i.e. spec-decode - """ - num_reqs = common_attn_metadata.num_reqs - num_actual_tokens = common_attn_metadata.num_actual_tokens - max_query_len = common_attn_metadata.max_query_len - max_seq_len = common_attn_metadata.max_seq_len - query_start_loc = common_attn_metadata.query_start_loc - seq_lens = common_attn_metadata.seq_lens - seq_lens_cpu = common_attn_metadata.seq_lens_cpu - block_table_tensor = common_attn_metadata.block_table_tensor - slot_mapping = common_attn_metadata.slot_mapping - causal = common_attn_metadata.causal - - # the overhead of the aot schedule is not worth it for spec-decode - aot_schedule = self.aot_schedule and not fast_build - - if self.aot_sliding_window is None: - self.aot_sliding_window = (-1, -1) - # For the AOT scheduler we need the sliding window value to be - # constant for all layers to. We have to populate this on the first - # build() call so the layers are constructed (cannot populate) - # in __init__. 
- if aot_schedule: - sliding_window_configs = _get_sliding_window_configs(self.vllm_config) - if len(sliding_window_configs) == 1: - sliding_window_config = sliding_window_configs.pop() - if sliding_window_config is not None: - self.aot_sliding_window = sliding_window_config - elif len(sliding_window_configs) > 1: - self.aot_schedule = False - aot_schedule = False - - max_num_splits = 0 # 0 means use FA3's heuristics, not CG compatible - if self.use_full_cuda_graph and num_actual_tokens <= self.max_cudagraph_size: - # NOTE(woosuk): Setting num_splits > 1 may increase the memory - # usage, because the intermediate buffers of size [num_splits, - # num_heads, num_tokens, head_size] are allocated. Therefore, - # we only set num_splits when using cuda graphs. - max_num_splits = self.max_num_splits - - if vllm_is_batch_invariant(): - max_num_splits = 1 - - def schedule( - batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal - ): - cache_dtype = self.cache_config.cache_dtype - if cache_dtype.startswith("fp8"): - qkv_dtype = FlashDiffkvAttentionBackend.get_fp8_dtype_for_flashattn( - cache_dtype - ) - else: - qkv_dtype = self.kv_cache_dtype - if aot_schedule: - return get_scheduler_metadata( - batch_size=batch_size, - max_seqlen_q=max_query_len, - max_seqlen_k=max_seq_len, - num_heads_q=self.num_heads_q * self.dcp_world_size, - num_heads_kv=self.num_heads_kv, - headdim=self.headdim, - cache_seqlens=seqlens, - qkv_dtype=qkv_dtype, - cu_seqlens_q=cu_query_lens, - page_size=self.block_size, - causal=causal, - window_size=self.aot_sliding_window, - num_splits=max_num_splits, - ) - return None - - use_cascade = common_prefix_len > 0 - max_dcp_context_kv_len = 0 - dcp_context_kv_lens = None - - cu_prefix_query_lens = None - prefix_kv_lens = None - suffix_kv_lens = None - prefix_scheduler_metadata = None - - if self.dcp_world_size > 1: - query_kv_lens_cpu = ( - common_attn_metadata.query_start_loc_cpu[1:] - - common_attn_metadata.query_start_loc_cpu[:-1] - ) - 
dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu - - dcp_context_kv_lens_cpu = get_dcp_local_seq_lens( - dcp_context_kv_lens_cpu, - self.dcp_world_size, - self.dcp_rank, - self.cp_kv_cache_interleave_size, - ) - dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device) - max_dcp_context_kv_len = dcp_context_kv_lens.max().item() - - scheduler_metadata = schedule( - batch_size=num_reqs, - cu_query_lens=query_start_loc, - max_query_len=max_query_len, - seqlens=dcp_context_kv_lens, - max_seq_len=max_dcp_context_kv_len, - causal=False, - ) - elif use_cascade: - cu_prefix_query_lens = torch.tensor( - [0, num_actual_tokens], dtype=torch.int32, device=self.device - ) - prefix_kv_lens = torch.tensor( - [common_prefix_len], dtype=torch.int32, device=self.device - ) - suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to( - self.device, non_blocking=True - ) - prefix_scheduler_metadata = schedule( - batch_size=1, - cu_query_lens=cu_prefix_query_lens, - max_query_len=num_actual_tokens, - seqlens=prefix_kv_lens, - max_seq_len=common_prefix_len, - causal=False, - ) - scheduler_metadata = schedule( - batch_size=num_reqs, - cu_query_lens=query_start_loc, - max_query_len=max_query_len, - seqlens=suffix_kv_lens, - max_seq_len=max_seq_len - common_prefix_len, - causal=True, - ) - else: - scheduler_metadata = schedule( - batch_size=num_reqs, - cu_query_lens=query_start_loc, - max_query_len=max_query_len, - seqlens=seq_lens, - max_seq_len=max_seq_len, - causal=causal, - ) - # For FA3 + full cudagraph - if self.use_full_cuda_graph and scheduler_metadata is not None: - n = scheduler_metadata.shape[0] - self.scheduler_metadata[:n] = scheduler_metadata - # NOTE(woosuk): We should zero out the rest of the scheduler - # metadata to guarantee the correctness. Otherwise, some thread - # blocks may use the invalid scheduler metadata and overwrite the - # output buffer. 
- self.scheduler_metadata[n:] = 0 - scheduler_metadata = self.scheduler_metadata[:n] - - attn_metadata = FlashAttentionMetadata( - num_actual_tokens=num_actual_tokens, - max_query_len=max_query_len, - query_start_loc=query_start_loc, - max_seq_len=max_seq_len, - seq_lens=seq_lens, - block_table=block_table_tensor, - slot_mapping=slot_mapping, - max_dcp_context_kv_len=max_dcp_context_kv_len, - dcp_context_kv_lens=dcp_context_kv_lens, - use_cascade=use_cascade, - common_prefix_len=common_prefix_len, - scheduler_metadata=scheduler_metadata, - cu_prefix_query_lens=cu_prefix_query_lens, - prefix_kv_lens=prefix_kv_lens, - suffix_kv_lens=suffix_kv_lens, - prefix_scheduler_metadata=prefix_scheduler_metadata, - max_num_splits=max_num_splits, - causal=causal, - ) - return attn_metadata - - def use_cascade_attention(self, *args, **kwargs) -> bool: - return use_cascade_attention(*args, **kwargs) - - -class FlashDiffkvAttentionImpl(AttentionImpl): - can_return_lse_for_decode: bool = True - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: list[float] | None, - sliding_window: int | None, - kv_cache_dtype: str, - logits_soft_cap: float | None = None, - attn_type: AttentionType = AttentionType.DECODER, - kv_sharing_target_layer_name: str | None = None, - sinks: torch.Tensor | None = None, - ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - if sliding_window is None: - self.sliding_window = (-1, -1) - elif attn_type == AttentionType.ENCODER_ONLY: - self.sliding_window = (sliding_window - 1, sliding_window - 1) - else: - self.sliding_window = (sliding_window - 1, 0) - self.kv_cache_dtype = kv_cache_dtype - if logits_soft_cap is None: - # In flash-attn, setting logits_soft_cap as 0 means no soft cap. 
- logits_soft_cap = 0 - self.logits_soft_cap = logits_soft_cap - self.kv_sharing_target_layer_name = kv_sharing_target_layer_name - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - self.attn_type = attn_type - self.vllm_flash_attn_version = get_flash_attn_version() - # Cache the batch invariant result for use in forward passes - self.batch_invariant_enabled = vllm_is_batch_invariant() - - if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8(): - raise NotImplementedError( - "FlashAttention does not support fp8 kv-cache on this device." - ) - - self.sinks = sinks - if self.sinks is not None: - assert flash_attn_supports_sinks(), ( - "Sinks are only supported in FlashAttention 3" - ) - assert self.sinks.shape[0] == num_heads, ( - "Sinks must have the same number of heads as the number of " - "heads in the layer" - ) - - def supports_quant_query_input(self) -> bool: - return True - - def forward( - self, - layer: torch.nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: FlashAttentionMetadata, - output: torch.Tensor | None = None, - output_scale: torch.Tensor | None = None, - output_block_scale: torch.Tensor | None = None, - ) -> torch.Tensor: - """Forward pass with FlashAttention. - - Args: - query: shape = [num_tokens, num_heads, head_size] - key: shape = [num_tokens, num_kv_heads, head_size] - value: shape = [num_tokens, num_kv_heads, head_size_v] - kv_cache: shape = - [num_blocks, block_size, num_kv_heads, head_size + head_size_v] - attn_metadata: Metadata for attention. - Returns: - shape = [num_tokens, num_heads * head_size_v] - NOTE: FP8 quantization, flash-attn expect the size of - {q,k,v}_descale to be (num_sequences, num_kv_heads). - We use torch's .expand() to avoid duplicating values - """ - assert output is not None, "Output tensor must be provided." 
- - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported for" - "FlashDiffkvAttentionImpl" - ) - - if attn_metadata is None: - # Profiling run. - return output.fill_(0) - - attn_type = self.attn_type - - # IMPORTANT! - # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in - # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead - # in this method. For example, `view` and `slice` (or `[:n]`) operations - # are surprisingly slow even in the case they do not invoke any GPU ops. - # Minimize the PyTorch ops in this method as much as possible. - # Whenever making a change in this method, please benchmark the - # performance to make sure it does not introduce any overhead. - - num_actual_tokens = attn_metadata.num_actual_tokens - - # Handle encoder attention differently - no KV cache needed - if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): - # For encoder attention, - # we use direct Q, K, V tensors without caching - return self._forward_encoder_attention( - query[:num_actual_tokens], - key[:num_actual_tokens], - value[:num_actual_tokens], - output[:num_actual_tokens], - attn_metadata, - layer, - ) - - # For decoder and cross-attention, use KV cache as before - key_cache = kv_cache[..., : self.head_size] - value_cache = kv_cache[..., self.head_size :] - - # key and value may be None in the case of cross attention. They are - # calculated once based on the output from the encoder and then cached - # in KV cache. - if ( - self.kv_sharing_target_layer_name is None - and key is not None - and value is not None - ): - # Reshape the input keys and values and store them in the cache. - # Skip this if sharing KV cache with an earlier attention layer. - # NOTE(woosuk): Here, key and value are padded while slot_mapping is - # not padded. 
However, we don't need to do key[:num_actual_tokens] - # and value[:num_actual_tokens] because the reshape_and_cache_flash - # op uses the slot_mapping's shape to determine the number of - # actual tokens. - triton_reshape_and_cache_flash_diffkv( - key, - value, - kv_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - - if self.kv_cache_dtype.startswith("fp8"): - # queries are quantized in the attention layer - dtype = FlashDiffkvAttentionBackend.get_fp8_dtype_for_flashattn( - self.kv_cache_dtype - ) - key_cache = key_cache.view(dtype) - value_cache = value_cache.view(dtype) - - if not attn_metadata.use_cascade: - cu_seqlens_q = attn_metadata.query_start_loc - seqused_k = attn_metadata.seq_lens - max_seqlen_q = attn_metadata.max_query_len - max_seqlen_k = attn_metadata.max_seq_len - block_table = attn_metadata.block_table - scheduler_metadata = attn_metadata.scheduler_metadata - - descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads) - - if self.dcp_world_size > 1: - self._forward_with_dcp( - query[:num_actual_tokens], - key[:num_actual_tokens], - value[:num_actual_tokens], - key_cache, - value_cache, - output[:num_actual_tokens], - attn_metadata, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - ) - return output - else: - flash_attn_varlen_func( - q=query[:num_actual_tokens], - k=key_cache, - v=value_cache, - out=output[:num_actual_tokens], - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - seqused_k=seqused_k, - max_seqlen_k=max_seqlen_k, - softmax_scale=self.scale, - causal=attn_metadata.causal, - alibi_slopes=self.alibi_slopes, - window_size=self.sliding_window, - block_table=block_table, - softcap=self.logits_soft_cap, - scheduler_metadata=scheduler_metadata, - fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - 
k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - num_splits=attn_metadata.max_num_splits, - s_aux=self.sinks, - ) - return output - - # Cascade attention (rare case). - cascade_attention( - output[:num_actual_tokens], - query[:num_actual_tokens], - key_cache, - value_cache, - cu_query_lens=attn_metadata.query_start_loc, - max_query_len=attn_metadata.max_query_len, - cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens, - prefix_kv_lens=attn_metadata.prefix_kv_lens, - suffix_kv_lens=attn_metadata.suffix_kv_lens, - max_kv_len=attn_metadata.max_seq_len, - softmax_scale=self.scale, - alibi_slopes=self.alibi_slopes, - sliding_window=self.sliding_window, - logits_soft_cap=self.logits_soft_cap, - block_table=attn_metadata.block_table, - common_prefix_len=attn_metadata.common_prefix_len, - max_num_splits=attn_metadata.max_num_splits, - fa_version=self.vllm_flash_attn_version, - prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata, - suffix_scheduler_metadata=attn_metadata.scheduler_metadata, - q_descale=layer._q_scale, - k_descale=layer._k_scale, - v_descale=layer._v_scale, - s_aux=self.sinks, - ) - return output - - def _forward_with_dcp( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - output: torch.Tensor, - attn_metadata: FlashAttentionMetadata, - q_descale: torch.Tensor | None = None, - k_descale: torch.Tensor | None = None, - v_descale: torch.Tensor | None = None, - ) -> torch.Tensor: - cu_seqlens_q = attn_metadata.query_start_loc - max_seqlen_q = attn_metadata.max_query_len - block_table = attn_metadata.block_table - - query = query.contiguous() - query_across_dcp = get_dcp_group().all_gather(query, dim=1) - context_attn_out, context_lse = flash_attn_varlen_func( - q=query_across_dcp, - k=key_cache, - v=value_cache, - out=None, - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - 
seqused_k=attn_metadata.dcp_context_kv_lens, - max_seqlen_k=attn_metadata.max_dcp_context_kv_len, - softmax_scale=self.scale, - causal=False, - alibi_slopes=self.alibi_slopes, - window_size=self.sliding_window, - block_table=block_table, - softcap=self.logits_soft_cap, - return_softmax_lse=True, - scheduler_metadata=attn_metadata.scheduler_metadata, - fa_version=self.vllm_flash_attn_version, - q_descale=q_descale, - k_descale=k_descale, - v_descale=v_descale, - ) - # FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ] - context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs( - context_attn_out, - context_lse.transpose(0, 1), - get_dcp_group(), - return_lse=True, - ) - context_lse_cor = context_lse_cor.transpose(0, 1).contiguous() - - query_attn_out, query_lse = flash_attn_varlen_func( - q=query, - k=key, - v=value, - out=None, - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - cu_seqlens_k=cu_seqlens_q, - max_seqlen_k=max_seqlen_q, - softmax_scale=self.scale, - causal=attn_metadata.causal, - alibi_slopes=self.alibi_slopes, - window_size=self.sliding_window, - softcap=self.logits_soft_cap, - return_softmax_lse=True, - fa_version=self.vllm_flash_attn_version, - q_descale=q_descale, - k_descale=k_descale, - v_descale=v_descale, - ) - assert context_attn_out_cor.shape == query_attn_out.shape - assert context_lse_cor.shape == query_lse.shape - merge_attn_states( - output, - context_attn_out_cor, - context_lse_cor, - query_attn_out, - query_lse, - ) - - def _forward_encoder_attention( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - output: torch.Tensor, - attn_metadata: FlashAttentionMetadata, - layer: torch.nn.Module, - ) -> torch.Tensor: - """Forward pass for encoder attention without KV cache. 
- - Args: - query: shape = [num_encoder_tokens, num_heads, head_size] - key: shape = [num_encoder_tokens, num_kv_heads, head_size] - value: shape = [num_encoder_tokens, num_kv_heads, head_size] - output: shape = [num_encoder_tokens, num_heads, head_size] - attn_metadata: Encoder attention metadata - layer: The attention layer - """ - # For encoder attention, process FP8 quantization if needed - if self.kv_cache_dtype.startswith("fp8"): - raise NotImplementedError( - "quantization is not supported for encoder attention" - ) - - # Use encoder-specific metadata for sequence information - cu_seqlens_q = attn_metadata.query_start_loc - cu_seqlens_k = attn_metadata.query_start_loc - max_seqlen_q = attn_metadata.max_query_len - max_seqlen_k = attn_metadata.max_query_len - - descale_shape = ( - cu_seqlens_q.shape[0] - 1, # type: ignore[union-attr] - self.num_kv_heads, - ) - - # Call flash attention directly on Q, K, V tensors - flash_attn_varlen_func( - q=query, - k=key, - v=value, - out=output, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - softmax_scale=self.scale, - causal=False, # Encoder attention is bidirectional - alibi_slopes=self.alibi_slopes, - window_size=self.sliding_window, - softcap=self.logits_soft_cap, - fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - num_splits=1 if self.batch_invariant_enabled else 0, - ) - - return output - - -def use_cascade_attention( - common_prefix_len: int, - query_lens: np.ndarray, - num_query_heads: int, - num_kv_heads: int, - use_alibi: bool, - use_sliding_window: bool, - use_local_attention: bool, - num_sms: int, - dcp_world_size: int, -) -> bool: - """Decide whether to use cascade attention. 
- - This function 1) checks whether cascade attention is supported with the - given configuration, and 2) heuristically decides whether using cascade - attention can improve performance. - """ - # Too short common prefix. Probably not worth using cascade attention. - # We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold. - # NOTE(woosuk): This is the common case. We should return False as soon as - # possible to avoid any unnecessary computation. - if common_prefix_len < 256: - return False - # Cascade attention is currently not supported with these variants. - if use_alibi or use_sliding_window or use_local_attention: - return False - # Too few queries. Probably not worth using cascade attention. - # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold. - num_reqs = len(query_lens) - if num_reqs < 8: - return False - # disable cascade attention for DCP - if dcp_world_size > 1: - return False - - # Heuristics to decide whether using cascade attention is beneficial. - # 1. When FlashDecoding is not used for normal attention, cascade attention - # is likely to be faster since it saves memory bandwidth. - num_queries_per_kv = num_query_heads // num_kv_heads - # The criteria for using FlashDecoding can be found in the following link: - # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535 - use_flash_decoding = ( - num_queries_per_kv > 1 - and not use_sliding_window - and not use_alibi - and np.all(query_lens == 1) - ) - if not use_flash_decoding: - # Use cascade attention. - return True - - # 2. When FlashDecoding is used for normal attention, it is not clear - # whether cascade attention is beneficial, because FlashDecoding can - # launch more CTAs than cascade attention. - # We use a simple performance model to compare the two methods. - # NOTE(woosuk): The performance model is very rough and may not be - # accurate. 
- num_tokens = num_reqs - # NOTE(woosuk): These are default tile sizes. flash-attn might use - # different tile sizes (e.g., 64 or 256) depending on the configuration. - q_tile_size = 128 - kv_tile_size = 128 - num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size) - - cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size) - cascade_waves = cdiv(cascade_ctas, num_sms) - cascade_time = cascade_waves * num_prefix_tiles - - flash_decoding_ctas = ( - num_reqs * num_kv_heads * cdiv(num_queries_per_kv, q_tile_size) - ) - flash_decoding_ctas *= num_prefix_tiles - flash_decoding_time = cdiv(flash_decoding_ctas, num_sms) - - # Use cascade attention if it is faster than FlashDecoding. - return cascade_time < flash_decoding_time - - -def cascade_attention( - output: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - cu_query_lens: torch.Tensor, - max_query_len: int, - cu_prefix_query_lens: torch.Tensor, - prefix_kv_lens: torch.Tensor, - suffix_kv_lens: torch.Tensor, - max_kv_len: int, - softmax_scale: float, - alibi_slopes: torch.Tensor | None, - sliding_window: tuple[int, int], - logits_soft_cap: float, - block_table: torch.Tensor, - common_prefix_len: int, - max_num_splits: int, - fa_version: int, - prefix_scheduler_metadata: torch.Tensor | None = None, - suffix_scheduler_metadata: torch.Tensor | None = None, - q_descale: torch.Tensor | None = None, - k_descale: torch.Tensor | None = None, - v_descale: torch.Tensor | None = None, - s_aux: torch.Tensor | None = None, -) -> torch.Tensor: - assert alibi_slopes is None, "Cascade attention does not support ALiBi." - # TODO: Support sliding window. - assert sliding_window == (-1, -1), ( - "Cascade attention does not support sliding window." 
- ) - - num_tokens = query.shape[0] - block_size = key_cache.shape[-3] - assert common_prefix_len % block_size == 0 - num_common_kv_blocks = common_prefix_len // block_size - assert num_common_kv_blocks > 0 - descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2]) - - # Process shared prefix. - prefix_output, prefix_lse = flash_attn_varlen_func( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=cu_prefix_query_lens, - seqused_k=prefix_kv_lens, - max_seqlen_q=num_tokens, - max_seqlen_k=common_prefix_len, - softmax_scale=softmax_scale, - causal=False, - window_size=sliding_window, - block_table=block_table[:1], - softcap=logits_soft_cap, - return_softmax_lse=True, - scheduler_metadata=prefix_scheduler_metadata, - fa_version=fa_version, - q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, - k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, - v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, - # s_aux is incorporated into prefix_lse inside the GPU kernel, - # enabling its effect during the final attention merge. - s_aux=s_aux, - num_splits=1 if vllm_is_batch_invariant() else max_num_splits, - ) - - descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2]) - - # Process suffix per query. 
- suffix_output, suffix_lse = flash_attn_varlen_func( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=cu_query_lens, - seqused_k=suffix_kv_lens, - max_seqlen_q=max_query_len, - max_seqlen_k=max_kv_len - common_prefix_len, - softmax_scale=softmax_scale, - causal=True, - window_size=sliding_window, - block_table=block_table[:, num_common_kv_blocks:], - softcap=logits_soft_cap, - return_softmax_lse=True, - scheduler_metadata=suffix_scheduler_metadata, - fa_version=fa_version, - q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, - k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, - v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, - num_splits=1 if vllm_is_batch_invariant() else max_num_splits, - ) - - # Merge prefix and suffix outputs, and store the result in output. - merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index f23b5564743c9..a9ce6e63cc775 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -210,11 +210,6 @@ class Scheduler(SchedulerInterface): hash_block_size=self.block_size, metrics_collector=self.kv_metrics_collector, ) - sink_len = getattr(vllm_config.model_config.hf_config, "param_sink_number", 0) - if sink_len > 0: - assert sink_len % self.block_size == 0 - num_sink_block = sink_len // self.block_size - self.kv_cache_manager.block_pool.free_block_queue.popleft_n(num_sink_block) self.use_pp = self.parallel_config.pipeline_parallel_size > 1 self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 8444ee5ef425f..e6f65da36e413 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -12,10 +12,10 @@ from vllm.v1.kv_cache_interface import ( 
ChunkedLocalAttentionSpec, CrossAttentionSpec, FullAttentionSpec, - FullDiffkvAttentionSpec, KVCacheSpec, MambaSpec, MLAAttentionSpec, + SinkFullAttentionSpec, SlidingWindowSpec, ) from vllm.v1.request import Request @@ -317,8 +317,7 @@ class FullAttentionManager(SingleTypeKVCacheManager): pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance( - kv_cache_spec, - FullAttentionSpec | ChunkedLocalAttentionSpec | FullDiffkvAttentionSpec, + kv_cache_spec, FullAttentionSpec | ChunkedLocalAttentionSpec ), ( "FullAttentionManager can only be used for full attention " "and chunked local attention groups" @@ -785,14 +784,35 @@ class CrossAttentionManager(SingleTypeKVCacheManager): raise NotImplementedError("CrossAttentionManager does not support caching") +class SinkFullAttentionManager(FullAttentionManager): + def __init__( + self, + kv_cache_spec: KVCacheSpec, + block_pool: BlockPool, + kv_cache_group_id: int, + dcp_world_size: int = 1, + pcp_world_size: int = 1, + ): + super().__init__( + kv_cache_spec, block_pool, kv_cache_group_id, dcp_world_size, pcp_world_size + ) + sink_len = kv_cache_spec.sink_len + if sink_len > 0: + assert sink_len % self.block_size == 0 + num_sink_block = sink_len // self.block_size + self.sink_blocks = self.block_pool.free_block_queue.popleft_n( + num_sink_block + ) + + spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, - FullDiffkvAttentionSpec: FullAttentionManager, MLAAttentionSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager, ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, MambaSpec: MambaManager, CrossAttentionSpec: CrossAttentionManager, + SinkFullAttentionSpec: SinkFullAttentionManager, } diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 1b130300b2218..656f5e7b81f55 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -80,6 +80,7 @@ class 
AttentionSpec(KVCacheSpec): @dataclass(frozen=True) class FullAttentionSpec(AttentionSpec): + head_size_v: int | None = None sliding_window: int | None = None attention_chunk_size: int | None = None """ @@ -92,6 +93,10 @@ class FullAttentionSpec(AttentionSpec): Default to None for not using sliding window attention. """ + def __post_init__(self): + if self.head_size_v is None: + object.__setattr__(self, "head_size_v", self.head_size) + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size @@ -124,88 +129,6 @@ class FullAttentionSpec(AttentionSpec): "All attention layers in the same KV cache group must be FullAttentionSpec." ) - sliding_window = set( - spec.sliding_window for spec in specs if spec.sliding_window is not None - ) - attention_chunk_size = set( - spec.attention_chunk_size - for spec in specs - if spec.attention_chunk_size is not None - ) - assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), ( - "MLAAttentionSpec should be merged in MLAAttentionSpec.merge" - ) - merged_spec = cls( - block_size=specs[0].block_size, - num_kv_heads=specs[0].num_kv_heads, - head_size=specs[0].head_size, - dtype=specs[0].dtype, - sliding_window=cls.merge_window_sizes(sliding_window), - attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), - ) - for spec in specs: - for f in fields(AttentionSpec): - assert getattr(spec, f.name) == getattr(merged_spec, f.name), ( - "All attention layers in the same KV cache group must have " - "the same attention spec." - ) - assert (merged_spec.sliding_window is not None) + ( - merged_spec.attention_chunk_size is not None - ) <= 1, ( - "Model with both sliding window layers and chunked local attention " - "layers is not supported." 
- ) - return merged_spec - - -@dataclass(frozen=True) -class FullDiffkvAttentionSpec(AttentionSpec): - head_size_v: int - sliding_window: int | None = None - attention_chunk_size: int | None = None - - """ - When hybrid allocator is disabled and the model contains both full - attention layers and sliding window attention layers, sliding - window attention are regarded as full attention in KV cache manager - (blocks are allocated for all tokens), while computed as sliding window - attention in model runner. - In this case, we use FullDiffkvAttentionSpec and record the sliding window size. - Default to None for not using sliding window attention. - """ - - def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: - max_model_len = vllm_config.model_config.max_model_len - dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size - # Note(hc): each dcp rank only need save - # (max_model_len//dcp_world_size) tokens locally. - if dcp_world_size > 1: - max_model_len = cdiv(max_model_len, dcp_world_size) - return cdiv(max_model_len, self.block_size) * self.page_size_bytes - - @classmethod - def merge_window_sizes(cls, window_sizes: set[int]) -> int | None: - if len(window_sizes) == 0: - return None - elif len(window_sizes) == 1: - return window_sizes.pop() - else: - raise ValueError( - "All attention layers in the same KV cache group must have the " - "same window size." - ) - - @classmethod - def merge(cls, specs: list[Self]) -> Self: - """ - Merge a list of FullDiffkvAttentionSpec objects into a single - FullDiffkvAttentionSpec object. - """ - assert all(isinstance(spec, FullDiffkvAttentionSpec) for spec in specs), ( - "All attention layers in the same KV cache group must be " - "FullDiffkvAttentionSpec." 
- ) - sliding_window = set( spec.sliding_window for spec in specs if spec.sliding_window is not None ) @@ -376,6 +299,56 @@ class CrossAttentionSpec(AttentionSpec): return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes +@dataclass(frozen=True) +class SinkFullAttentionSpec(FullAttentionSpec): + sink_len: int = 0 + + @classmethod + def merge(cls, specs: list[Self]) -> Self: + """ + Merge a list of FullAttentionSpec objects into a single + FullAttentionSpec object. + """ + assert all(isinstance(spec, FullAttentionSpec) for spec in specs), ( + "All attention layers in the same KV cache group must be FullAttentionSpec." + ) + + sliding_window = set( + spec.sliding_window for spec in specs if spec.sliding_window is not None + ) + attention_chunk_size = set( + spec.attention_chunk_size + for spec in specs + if spec.attention_chunk_size is not None + ) + assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), ( + "MLAAttentionSpec should be merged in MLAAttentionSpec.merge" + ) + merged_spec = cls( + block_size=specs[0].block_size, + num_kv_heads=specs[0].num_kv_heads, + head_size=specs[0].head_size, + head_size_v=specs[0].head_size_v, + sink_len=specs[0].sink_len, + dtype=specs[0].dtype, + sliding_window=cls.merge_window_sizes(sliding_window), + attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), + ) + for spec in specs: + for f in fields(AttentionSpec): + assert getattr(spec, f.name) == getattr(merged_spec, f.name), ( + "All attention layers in the same KV cache group must have " + "the same attention spec." + ) + assert (merged_spec.sliding_window is not None) + ( + merged_spec.attention_chunk_size is not None + ) <= 1, ( + "Model with both sliding window layers and chunked local attention " + "layers is not supported."
+ ) + return merged_spec + + @dataclass(frozen=True) class UniformTypeKVCacheSpecs(KVCacheSpec): """ diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 37ec0fb97e06b..dd61d2150a797 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -263,6 +263,7 @@ class MultiGroupBlockTable: kernel_block_sizes: list[int], num_speculative_tokens: int = 0, cp_kv_cache_interleave_size: int = 1, + sink_len: int = 0, ) -> None: # Note(hc): each dcp rank only store # (max_model_len//dcp_world_size) tokens in kvcache, @@ -292,7 +293,7 @@ class MultiGroupBlockTable: block_size, max_num_reqs, max( - cdiv(max_model_len, block_size * total_cp_world_size), + cdiv(max_model_len + sink_len, block_size * total_cp_world_size), 1 + num_speculative_tokens, ), max_num_batched_tokens, diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py index 6386f1a08b446..09a5bd885309d 100644 --- a/vllm/v1/worker/gpu/attn_utils.py +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -101,16 +101,25 @@ def _reshape_kv_cache( num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes attn_backend = attn_backends[layer_name] + if hasattr(kv_cache_spec, "head_size_v"): + kwargs = {"head_size_v": kv_cache_spec.head_size_v} + stride_kwargs = {"diff_kv": True} + else: + kwargs = {} + stride_kwargs = {} kv_cache_shape = attn_backend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, + **kwargs, ) # FIXME(woosuk): Add kv_cache_stride_order to all attention backends. 
try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( + **stride_kwargs + ) assert len(kv_cache_stride_order) == len(kv_cache_shape) except (AttributeError, NotImplementedError): kv_cache_stride_order = tuple(range(len(kv_cache_shape))) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index c2b4b0dac3033..c567fc7219c3b 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -143,7 +143,7 @@ class InputBatch: # Block table. self.block_table = MultiGroupBlockTable( max_num_reqs=max_num_reqs, - max_model_len=max_model_len + sink_len, + max_model_len=max_model_len, max_num_batched_tokens=max_num_batched_tokens, pin_memory=pin_memory, device=device, @@ -151,6 +151,7 @@ class InputBatch: kernel_block_sizes=kernel_block_sizes, num_speculative_tokens=num_speculative_tokens, cp_kv_cache_interleave_size=cp_kv_cache_interleave_size, + sink_len=sink_len, ) # Sampling-related. 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ebdb9daf7fae9..5f81f9ba23ddc 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -27,9 +27,6 @@ from vllm.attention.backends.abstract import ( MultipleOf, ) from vllm.attention.layer import Attention, MLAAttention -from vllm.attention.ops.triton_reshape_and_cache_flash import ( - triton_reshape_and_cache_flash_diffkv, -) from vllm.compilation.counter import compilation_counter from vllm.compilation.cuda_graph import CUDAGraphStat, CUDAGraphWrapper from vllm.compilation.monitor import set_cudagraph_capturing_enabled @@ -5209,16 +5206,25 @@ class GPUModelRunner( ) kernel_num_blocks = num_blocks * num_blocks_per_kv_block + if hasattr(kv_cache_spec, "head_size_v"): + kwargs = {"head_size_v": kv_cache_spec.head_size_v} + stride_kwargs = {"diff_kv": True} + else: + kwargs = {} + stride_kwargs = {} kv_cache_shape = attn_backend.get_kv_cache_shape( kernel_num_blocks, kernel_block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, cache_dtype_str=self.cache_config.cache_dtype, + **kwargs, ) dtype = kv_cache_spec.dtype try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( + **stride_kwargs + ) assert len(kv_cache_stride_order) == len(kv_cache_shape) except (AttributeError, NotImplementedError): kv_cache_stride_order = tuple(range(len(kv_cache_shape))) @@ -5410,7 +5416,6 @@ class GPUModelRunner( kv_caches = self.initialize_kv_cache_tensors( kv_cache_config, kernel_block_sizes ) - self.prepare_sink_kv_cache(kv_caches) if self.speculative_config and self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) @@ -5501,36 +5506,3 @@ class GPUModelRunner( self.transfer_event.record() self.transfer_event.synchronize() return pinned.tolist() - - def prepare_sink_kv_cache(self, kv_caches) -> None: - if self.sink_len == 0: - return - - def 
find_module_by_name(model, target_name: str): - for name, module in model.named_modules(): - if name == target_name: - return module - raise KeyError(f"Module '{target_name}' not found") - - for layer_name, kv_cache in kv_caches.items(): - layer_prefix = layer_name.rsplit(".", 1)[0] - self_attn_module = find_module_by_name(self.model, layer_prefix) - if not hasattr(self_attn_module, "get_sink_kv"): - continue - else: - sink_kv = self_attn_module.get_sink_kv() - sink_kv_slot_mapping = torch.arange( - self.vllm_config.cache_config.block_size, - self.sink_len + self.vllm_config.cache_config.block_size, - device=torch.cuda.current_device(), - dtype=torch.long, - ) - triton_reshape_and_cache_flash_diffkv( - sink_kv["sink_key"], - sink_kv["sink_value"], - kv_cache, - sink_kv_slot_mapping, - self_attn_module.attn.kv_cache_dtype, - self_attn_module.attn._k_scale, - self_attn_module.attn._v_scale, - ) diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index 2bcc87b63bcdf..70e99db9e9762 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -190,17 +190,25 @@ class KVConnectorModelRunnerMixin: return False attn_backend = attn_group.backend + if hasattr(kv_cache_spec, "head_size_v"): + kwargs = {"head_size_v": kv_cache_spec.head_size_v} + stride_kwargs = {"diff_kv": True} + else: + kwargs = {} + stride_kwargs = {} kv_cache_shape = attn_backend.get_kv_cache_shape( 1234, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, cache_dtype_str=cache_dtype, + **kwargs, ) try: kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( - include_num_layers_dimension=True + include_num_layers_dimension=True, + **stride_kwargs, ) except (AttributeError, NotImplementedError): return False @@ -257,12 +265,19 @@ class KVConnectorModelRunnerMixin: kernel_num_blocks = num_blocks * num_blocks_per_kv_block attn_backend = 
attn_group.backend + if hasattr(kv_cache_spec, "head_size_v"): + kwargs = {"head_size_v": kv_cache_spec.head_size_v} + stride_kwargs = {"diff_kv": True} + else: + kwargs = {} + stride_kwargs = {} kv_cache_shape = attn_backend.get_kv_cache_shape( kernel_num_blocks, kernel_block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, cache_dtype_str=cache_dtype, + **kwargs, ) # prepend a num_layers dimension into the shape @@ -270,7 +285,8 @@ class KVConnectorModelRunnerMixin: try: kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( - include_num_layers_dimension=True + include_num_layers_dimension=True, + **stride_kwargs, ) assert len(kv_cache_stride_order) == len(kv_cache_shape) except (AttributeError, NotImplementedError):