Refactor code. Extend FLASH_ATTN to support different KV sizes and create a new StaticSinkAttention for the sink-token logic

Signed-off-by: yuantao <2422264527@qq.com>
This commit is contained in:
yuantao 2025-12-13 15:47:33 +08:00
parent de538d3b8f
commit b565203d92
15 changed files with 503 additions and 1362 deletions

View File

@ -42,9 +42,6 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
"""
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
FLASH_DIFFKV_ATTN = (
"vllm.v1.attention.backends.flash_diffkv_attn.FlashDiffkvAttentionBackend"
)
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"

View File

@ -191,6 +191,7 @@ class Attention(nn.Module, AttentionLayerBase):
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None,
attn_backend: type[AttentionBackend] | None = None,
head_size_v: int | None = None,
**extra_impl_args,
) -> None:
"""
@ -232,6 +233,7 @@ class Attention(nn.Module, AttentionLayerBase):
self.num_heads = num_heads
self.head_size = head_size
self.head_size_v = self.head_size if head_size_v is None else head_size_v
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window
self.has_sink = extra_impl_args.get("sinks") is not None
@ -370,6 +372,10 @@ class Attention(nn.Module, AttentionLayerBase):
query, _ = self.query_quant(query, self._q_scale)
if self.use_output:
if output_shape is None:
output_shape = torch.Size(
(*query.shape[:-1], self.num_heads * self.head_size_v)
)
output_shape = output_shape if output_shape is not None else query.shape
output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
hidden_size = output_shape[-1]
@ -377,11 +383,11 @@ class Attention(nn.Module, AttentionLayerBase):
# NOTE(woosuk): We do this outside the custom op to minimize the
# CPU overheads from the non-CUDA-graph regions.
query = query.view(-1, self.num_heads, self.head_size)
output = output.view(-1, self.num_heads, self.head_size)
output = output.view(-1, self.num_heads, self.head_size_v)
if key is not None:
key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size_v)
if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
@ -456,6 +462,7 @@ class Attention(nn.Module, AttentionLayerBase):
block_size=block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
head_size_v=self.head_size_v,
dtype=self.kv_cache_torch_dtype,
)

View File

@ -0,0 +1,225 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import torch
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionMetadata,
AttentionType,
)
from vllm.attention.layer import Attention
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash_diffkv,
)
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
subclass_attention_backend,
)
from vllm.v1.kv_cache_interface import (
AttentionSpec,
KVCacheSpec,
SinkFullAttentionSpec,
)
logger = init_logger(__name__)
@functools.lru_cache
def create_static_sink_attention_backend(
    underlying_attn_backend: type[AttentionBackend],
    sink_len: int = 0,
) -> type[AttentionBackend]:
    """Wrap ``underlying_attn_backend`` so its metadata builder accounts for
    ``sink_len`` static sink tokens stored at the front of the KV cache.

    The returned backend subclass swaps in a metadata builder that
    (a) extends every sequence length by ``sink_len`` and (b) prepends the
    shared sink blocks to every request's block table. ``lru_cache`` ensures
    each ``(backend, sink_len)`` pair is subclassed only once.
    """
    prefix = "StaticSink_"
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class StaticSinkAttentionBuilder(underlying_builder):  # type: ignore
        def __init__(
            self,
            kv_cache_spec: AttentionSpec,
            layer_names: list[str],
            vllm_config: VllmConfig,
            device: torch.device,
        ):
            super().__init__(kv_cache_spec, layer_names, vllm_config, device)
            block_size = vllm_config.cache_config.block_size
            # The sink region must tile exactly into cache blocks; otherwise
            # the block-table shift in build() would mis-align the cache.
            # (SinkFullAttentionManager asserts the same invariant.)
            assert sink_len % block_size == 0, (
                f"sink_len ({sink_len}) must be a multiple of the cache "
                f"block size ({block_size})"
            )
            self.sink_len = sink_len
            self.num_sink_blocks = self.sink_len // block_size
            # Sink KV lives in blocks 1..num_sink_blocks, shared by all
            # requests. NOTE(review): this presumes block 0 is the null
            # block — confirm against SinkFullAttentionManager, which
            # reserves these blocks from the free pool.
            self.sink_block_table = torch.arange(
                1,
                self.num_sink_blocks + 1,
                device=device,
                dtype=torch.int32,
            )

        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            if self.num_sink_blocks > 0:
                # Every sequence is logically preceded by sink_len tokens.
                common_attn_metadata.seq_lens[:] = (
                    common_attn_metadata.seq_lens + self.sink_len
                )
                common_attn_metadata.seq_lens_cpu = (
                    common_attn_metadata.seq_lens_cpu + self.sink_len
                )
                common_attn_metadata.max_seq_len = (
                    common_attn_metadata.max_seq_len + self.sink_len
                )
                # Shift each request's block table right by num_sink_blocks
                # and splice the shared sink blocks in front. The guard above
                # also fixes the sink_len == 0 case, where the original
                # `[:, :-0]` produced an empty slice and a shape-mismatch
                # assignment.
                blk_table_tensor = common_attn_metadata.block_table_tensor
                sink_block_table = self.sink_block_table[None, :].expand(
                    blk_table_tensor.shape[0], -1
                )
                blk_table_tensor_clone = blk_table_tensor.clone()
                blk_table_tensor[:, self.num_sink_blocks :] = (
                    blk_table_tensor_clone[:, : -self.num_sink_blocks]
                )
                blk_table_tensor[:, : self.num_sink_blocks] = sink_block_table
            return super().build(common_prefix_len, common_attn_metadata, fast_build)

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=StaticSinkAttentionBuilder,
    )
    return attn_backend
class StaticSinkAttention(Attention):
    """Attention layer with static sink tokens.

    The first ``sink_len`` KV-cache slots of every sequence are reserved for
    learned ("static") sink key/value tensors. They are written into the KV
    cache once, lazily on the first forward pass, via a custom op, and are
    shared by all requests afterwards.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        sink_len: int,
        attn_backend: type[AttentionBackend] | None = None,
        cache_config: CacheConfig | None = None,
        **kwargs,
    ):
        dtype = torch.get_default_dtype()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16
        if attn_backend is not None:
            underlying_attn_backend = attn_backend
        else:
            underlying_attn_backend = get_attn_backend(
                head_size, dtype, kv_cache_dtype, block_size
            )
        # Wrap the chosen backend so its metadata builder accounts for the
        # sink blocks prepended to every request's block table.
        attn_backend = create_static_sink_attention_backend(
            underlying_attn_backend,
            sink_len=sink_len,
        )
        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            cache_config=cache_config,
            attn_backend=attn_backend,
            **kwargs,
        )
        self.sink_len = sink_len
        self.block_size = block_size
        # Sink KV is written into the cache lazily, once, on first forward.
        self.sink_populated = False
        self.sink_key = None
        self.sink_value = None

    def update_sink_kv(self, sink_key, sink_value) -> None:
        """Register the sink K/V tensors (called after weight loading)."""
        self.sink_key = sink_key
        self.sink_value = sink_value

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # BUG FIX: was annotated `torch.size`, which does not exist and
        # raises AttributeError when the class body is evaluated.
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        assert self.sink_key is not None and self.sink_value is not None, (
            "sink_key and sink_value have not been prepared"
        )
        if not self.sink_populated:
            forward_context: ForwardContext = get_forward_context()
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            # Routed through a custom op so it stays traceable/compilable.
            torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)
        return super().forward(query, key, value, output_shape)

    def populate_sink_kv(self, self_kv_cache):
        """Write the registered sink K/V into the given KV cache tensor.

        Slots start at ``block_size`` (i.e. block 1), matching the sink
        block table built by the StaticSink metadata builder.
        """
        sink_kv_slot_mapping = torch.arange(
            self.block_size,
            self.sink_len + self.block_size,
            device=torch.cuda.current_device(),
            dtype=torch.long,
        )
        triton_reshape_and_cache_flash_diffkv(
            self.sink_key,
            self.sink_value,
            self_kv_cache,
            sink_kv_slot_mapping,
            self.kv_cache_dtype,
            self._k_scale,
            self._v_scale,
        )
        # We only populate the sink_key and sink_value once
        self.sink_populated = True

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Return a sink-aware full-attention KV cache spec for this layer."""
        # Block size may get updated after model loading, refresh it
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER
        return SinkFullAttentionSpec(
            block_size=block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            head_size_v=self.head_size_v,
            sink_len=self.sink_len,
            dtype=self.kv_cache_torch_dtype,
        )
def maybe_populate_sink(
    self_kv_cache: torch.Tensor,
    layer_name: str,
) -> None:
    """Write the layer's sink KV into the cache on first use (idempotent)."""
    ctx: ForwardContext = get_forward_context()
    layer = ctx.no_compile_layers[layer_name]
    # Nothing to do once populated, or before the cache has been allocated.
    if layer.sink_populated or self_kv_cache.numel() == 0:
        return
    layer.populate_sink_kv(self_kv_cache)
def maybe_populate_sink_fake(
    self_kv_cache: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake (meta) implementation for tracing: a no-op with no outputs."""
    return None
# Register the sink-population routine as a custom op so it remains valid
# under torch.compile; `mutates_args` declares the in-place write to the KV
# cache, and the fake impl provides a no-op for tracing.
direct_register_custom_op(
    op_name="maybe_populate_sink",
    op_func=maybe_populate_sink,
    mutates_args=["self_kv_cache"],
    fake_impl=maybe_populate_sink_fake,
)

View File

@ -186,17 +186,20 @@ def triton_reshape_and_cache_flash(
@triton.jit
def reshape_and_cache_kernel_flash_diffkv(
kv_ptr, # [num_tokens, num_heads, head_size + head_size_v]
key_ptr, # [num_tokens, num_heads, head_size]
value_ptr, # [num_tokens, num_heads, head_size_v]
kv_cache_ptr, # [num_blocks, block_size, num_heads, head_size + head_size_v]
slot_mapping_ptr, # [num_tokens]
k_scale, # float32
v_scale, # float32
# strides
kv_stride: tl.int64,
key_stride: tl.int64,
value_stride: tl.int64,
block_stride: tl.int64,
page_stride: tl.int64,
num_heads: tl.constexpr,
head_size_kv: tl.constexpr,
head_size_k: tl.constexpr,
head_size_v: tl.constexpr,
block_size: tl.constexpr,
# FP8 flags
FP8_KV_CACHE: tl.constexpr,
@ -211,24 +214,51 @@ def reshape_and_cache_kernel_flash_diffkv(
tile_i = tl.program_id(axis=1)
tile_offs = tl.arange(0, TILE_SIZE)
tile_pos = tile_i * TILE_SIZE + tile_offs
block_idx = slot_idx // block_size
block_offset = slot_idx % block_size
src_kv_idx = token_idx * kv_stride
src_key_idx = token_idx * key_stride + tile_i * head_size_k
src_value_idx = token_idx * value_stride + tile_i * head_size_v
tgt_idx = block_idx * block_stride + block_offset * page_stride
# [TILE_SIZE]
kv_tile = tl.load(
kv_ptr + src_kv_idx + tile_pos, mask=tile_pos < (num_heads * head_size_kv)
tgt_idx = (
block_idx * block_stride
+ block_offset * page_stride
+ tile_i * (head_size_k + head_size_v)
)
# [TILE_SIZE]
key_load = tl.load(key_ptr + src_key_idx + tile_offs, mask=tile_offs < head_size_k)
if FP8_KV_CACHE:
# tl.store will do the correct implicit cast to fp8,
# based on the key_cache_ptr.dtype.element_ty
key_tile = key_load if key_load.dtype.is_fp8() else key_load / tl.load(k_scale)
else:
key_tile = key_load
# [TILE_SIZE]
value_load = tl.load(
    value_ptr + src_value_idx + tile_offs, mask=tile_offs < head_size_v
)
if FP8_KV_CACHE:
if value_load.dtype.is_fp8():
value_tile = value_load
else:
# tl.store will do the correct implicit cast to fp8,
# based on the value_cache_ptr.dtype.element_ty
value_tile = value_load / tl.load(v_scale)
else:
value_tile = value_load
tl.store(
kv_cache_ptr + tgt_idx + tile_pos,
kv_tile,
mask=tile_pos < (num_heads * head_size_kv),
kv_cache_ptr + tgt_idx + tile_offs,
key_tile,
mask=tile_offs < head_size_k,
)
tl.store(
kv_cache_ptr + tgt_idx + head_size_k + tile_offs,
value_tile,
mask=tile_offs < head_size_v,
)
return
@ -243,13 +273,13 @@ def triton_reshape_and_cache_flash_diffkv(
k_scale: torch.Tensor, # float32
v_scale: torch.Tensor, # float32
):
kv = torch.cat([key, value], dim=-1).contiguous()
num_heads = kv.shape[1]
head_size_kv = kv.shape[2]
num_heads = key.shape[1]
head_size_k = key.shape[2]
head_size_v = value.shape[2]
block_size = kv_cache.shape[1]
n = num_heads * head_size_kv
kv_stride = kv.stride()[0]
k_stride = key.stride()[0]
v_stride = value.stride()[0]
block_stride = kv_cache.stride()[0]
page_stride = kv_cache.stride()[1]
@ -272,13 +302,20 @@ def triton_reshape_and_cache_flash_diffkv(
)
FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
assert not FP8_KV_CACHE, (
assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.uint8,
torch.float8_e4m3fnuz,
], (
"unsupported dtype of KV cache tensor, got "
"{kv_cache_torch_dtype}. Supported kv cache dtypes: bfloat16, float16, float32."
"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, "
"fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz."
)
# heuristics instead of autotuning
TILE_SIZE = min(2048, triton.next_power_of_2(n))
TILE_SIZE = max(head_size_k, head_size_v)
TILE_SIZE = triton.next_power_of_2(TILE_SIZE)
if current_platform.is_rocm() or current_platform.is_xpu():
num_stages = 4
num_warps = 8
@ -292,21 +329,24 @@ def triton_reshape_and_cache_flash_diffkv(
# using cudagraphs
grid = lambda meta: (
slot_mapping.shape[0],
triton.cdiv(n, meta["TILE_SIZE"]),
num_heads,
)
reshape_and_cache_kernel_flash_diffkv[grid](
kv_ptr=kv,
key_ptr=key,
value_ptr=value,
kv_cache_ptr=kv_cache,
slot_mapping_ptr=slot_mapping,
k_scale=k_scale,
v_scale=v_scale,
# strides
kv_stride=kv_stride,
key_stride=k_stride,
value_stride=v_stride,
block_stride=block_stride,
page_stride=page_stride,
num_heads=num_heads,
head_size_kv=head_size_kv,
head_size_k=head_size_k,
head_size_v=head_size_v,
block_size=block_size,
# FP8 flags
FP8_KV_CACHE=FP8_KV_CACHE,

View File

@ -29,8 +29,8 @@ import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.layer import Attention
from vllm.attention.layer import Attention, AttentionType
from vllm.attention.layers.static_sink_attention import StaticSinkAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.distributed import (
@ -41,7 +41,6 @@ from vllm.distributed import (
get_tp_group,
tensor_model_parallel_all_gather,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
@ -82,11 +81,7 @@ from vllm.model_executor.models.utils import (
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import set_default_rope_theta
from vllm.v1.attention.backends.flash_diffkv_attn import FlashDiffkvAttentionBackend
from vllm.v1.kv_cache_interface import (
FullDiffkvAttentionSpec,
KVCacheSpec,
)
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
def check_ffn_act_fn(act_fn: str):
@ -96,133 +91,6 @@ def check_ffn_act_fn(act_fn: str):
)
class DiffkvAttention(Attention):
    """Attention whose key and value heads have different sizes
    (``head_size`` for Q/K, ``head_size_v`` for V)."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        head_size_v: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        attn_backend: type[AttentionBackend] | None = None,
        **extra_impl_args,
    ) -> None:
        # Everything except head_size_v is delegated to the base Attention.
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            cache_config,
            quant_config,
            logits_soft_cap,
            per_layer_sliding_window,
            prefix,
            attn_type,
            kv_sharing_target_layer_name,
            attn_backend,
            **extra_impl_args,
        )
        # V head size; differs from the Q/K head_size kept by the base class.
        self.head_size_v = head_size_v

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.
        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
        output_dtype = query.dtype
        if self.query_quant is not None:
            # quantizing with a simple torch operation enables
            # torch.compile to fuse this into previous ops
            # which reduces overheads during decoding.
            # Otherwise queries are quantized using custom ops
            # which causes decoding overheads
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
            # check if query quantization is supported
            if self.impl.supports_quant_query_input():
                query, _ = self.query_quant(query, self._q_scale)
        if self.use_output:
            output_shape = output_shape if output_shape is not None else query.shape
            output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
            hidden_size = output_shape[-1]
            # Reshape the query, key, and value tensors.
            # NOTE(woosuk): We do this outside the custom op to minimize the
            # CPU overheads from the non-CUDA-graph regions.
            query = query.view(-1, self.num_heads, self.head_size)
            # Output rows are V-sized (num_heads * head_size_v wide).
            output = output.view(-1, self.num_heads, self.head_size_v)
            if key is not None:
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                value = value.view(-1, self.num_kv_heads, self.head_size_v)
            if self.use_direct_call:
                forward_context: ForwardContext = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                self.impl.forward(
                    self,
                    query,
                    key,
                    value,
                    self_kv_cache,
                    attn_metadata,
                    output=output,
                )
            else:
                torch.ops.vllm.unified_attention_with_output(
                    query,
                    key,
                    value,
                    output,
                    self.layer_name,
                )
            return output.view(-1, hidden_size)
        else:
            raise ValueError(
                "Unsupport Error, currently only flash_diffkv_attn "
                "backend with output buffer is supported"
            )

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Return the diff-KV full-attention cache spec for this layer."""
        # Block size may get updated after model loading, refresh it
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER
        # Only support for full attention now.
        assert self.sliding_window is None
        return FullDiffkvAttentionSpec(
            block_size=block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            head_size_v=self.head_size_v,
            dtype=self.kv_cache_torch_dtype,
        )
class OpenPanguMLP(nn.Module):
def __init__(
self,
@ -673,7 +541,7 @@ class OpenPanguEmbeddedAttention(nn.Module):
)
class OpenPanguDiffkvAttention(nn.Module):
class OpenPanguSinkAttention(nn.Module):
def __init__(
self,
config: PretrainedConfig,
@ -777,19 +645,19 @@ class OpenPanguDiffkvAttention(nn.Module):
else:
sliding_window = None
FlashDiffkvAttentionBackend.set_head_size_v(self.v_channels)
self.attn = DiffkvAttention(
self.attn = StaticSinkAttention(
self.num_heads,
self.head_dim,
self.v_channels,
self.scaling,
sink_len=self.param_sink_number,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
per_layer_sliding_window=sliding_window,
attn_type=attn_type,
prefix=f"{prefix}.attn",
attn_backend=FlashDiffkvAttentionBackend,
attn_backend=FlashAttentionBackend,
head_size_v=self.v_channels,
)
if self.param_sink_number > 0:
@ -919,19 +787,13 @@ class OpenPanguDiffkvAttention(nn.Module):
is_neox_style=is_neox_style,
)
def get_sink_kv(self) -> dict[str, torch.Tensor]:
if self.param_sink_number == 0:
raise ValueError("No sink_key and sink_value when param_sink_number == 0")
def post_weight_load(self) -> None:
if hasattr(self, "k_layernorm") and self.k_layernorm is not None:
param_sink_key = self.k_layernorm(self.param_sink_key)
else:
param_sink_key = self.param_sink_key
return {
"sink_key": param_sink_key,
"sink_value": self.param_sink_value,
}
self.attn.update_sink_kv(param_sink_key, self.param_sink_value)
class OpenPanguDecoderLayer(nn.Module):
@ -1001,7 +863,7 @@ class OpenPanguDecoderLayer(nn.Module):
"rope_type": "default",
"rope_theta": config.rope_theta,
}
self.self_attn = OpenPanguDiffkvAttention(
self.self_attn = OpenPanguSinkAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
@ -1359,8 +1221,17 @@ class OpenPanguModel(nn.Module):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
self.post_weight_load()
return loaded_params
def post_weight_load(self) -> None:
    """Dispatch post-weight-load hooks to every submodule that defines one.

    Called once after ``load_weights`` finishes, e.g. so attention layers can
    normalize and register their sink K/V. Skips ``self`` to avoid recursing
    into this very method.
    """
    # The module name is unused; iterate only over the modules themselves.
    for _, module in self.named_modules():
        if module is self:
            continue
        if hasattr(module, "post_weight_load"):
            module.post_weight_load()
class OpenPanguModelBase(nn.Module, SupportsPP, SupportsLoRA):
packed_modules_mapping = {

View File

@ -18,6 +18,9 @@ from vllm.attention.backends.abstract import (
from vllm.attention.layer import Attention
from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash_diffkv,
)
from vllm.attention.utils.fa_utils import (
flash_attn_supports_fp8,
get_flash_attn_version,
@ -105,28 +108,48 @@ class FlashAttentionBackend(AttentionBackend):
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
head_size_v: int | None = None,
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
if head_size_v is None or head_size == head_size_v:
return (2, num_blocks, block_size, num_kv_heads, head_size)
else:
return (
num_blocks,
block_size,
num_kv_heads,
head_size + head_size_v,
)
@staticmethod
def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False,
diff_kv: bool = False,
) -> tuple[int, ...]:
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout = get_kv_cache_layout()
if cache_layout == "NHD" and include_num_layers_dimension:
# (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
return (2, 0, 1, 3, 4, 5)
if not diff_kv:
# (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
return (2, 0, 1, 3, 4, 5)
else:
# (num_blocks, num_layers, block_size,
# num_kv_heads, head_size + head_size_v)
return (0, 1, 2, 3, 4)
elif cache_layout == "NHD":
stride_order = (0, 1, 2, 3, 4)
stride_order = (0, 1, 2, 3, 4) if not diff_kv else (0, 1, 2, 3)
elif cache_layout == "HND" and include_num_layers_dimension:
# (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
return (2, 4, 0, 1, 3, 5)
if not diff_kv:
# (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
return (2, 4, 0, 1, 3, 5)
else:
# (num_blocks, num_kv_heads, num_layers,
# block_size, head_size + head_size_v)
return (2, 3, 0, 1, 4)
elif cache_layout == "HND":
stride_order = (0, 1, 3, 2, 4)
stride_order = (0, 1, 3, 2, 4) if not diff_kv else (0, 2, 1, 3)
else:
raise ValueError(f"Unknown cache layout format {cache_layout}.")
return stride_order
@ -576,11 +599,14 @@ class FlashAttentionImpl(AttentionImpl):
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
or [num_tokens, num_kv_heads, head_size_v]
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
or [num_blocks, block_size, num_kv_heads, head_size + head_size_v]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
or [num_tokens, num_heads * head_size_v]
NOTE: FP8 quantization, flash-attn expect the size of
{q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values
@ -623,7 +649,13 @@ class FlashAttentionImpl(AttentionImpl):
)
# For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(0)
if self.head_size == kv_cache.shape[-1]:
# Same head_size for K and V
key_cache, value_cache = kv_cache.unbind(0)
else:
# Different head_size for K and V
key_cache = kv_cache[..., : self.head_size]
value_cache = kv_cache[..., self.head_size :]
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
@ -640,16 +672,29 @@ class FlashAttentionImpl(AttentionImpl):
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.head_size == kv_cache.shape[-1]:
# kv_cache update for same head_size K and V
reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
else:
# kv_cache update for different head_size K and V
triton_reshape_and_cache_flash_diffkv(
key,
value,
kv_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
# queries are quantized in the attention layer

File diff suppressed because it is too large Load Diff

View File

@ -210,11 +210,6 @@ class Scheduler(SchedulerInterface):
hash_block_size=self.block_size,
metrics_collector=self.kv_metrics_collector,
)
sink_len = getattr(vllm_config.model_config.hf_config, "param_sink_number", 0)
if sink_len > 0:
assert sink_len % self.block_size == 0
num_sink_block = sink_len // self.block_size
self.kv_cache_manager.block_pool.free_block_queue.popleft_n(num_sink_block)
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER

View File

@ -12,10 +12,10 @@ from vllm.v1.kv_cache_interface import (
ChunkedLocalAttentionSpec,
CrossAttentionSpec,
FullAttentionSpec,
FullDiffkvAttentionSpec,
KVCacheSpec,
MambaSpec,
MLAAttentionSpec,
SinkFullAttentionSpec,
SlidingWindowSpec,
)
from vllm.v1.request import Request
@ -317,8 +317,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(
kv_cache_spec,
FullAttentionSpec | ChunkedLocalAttentionSpec | FullDiffkvAttentionSpec,
kv_cache_spec, FullAttentionSpec | ChunkedLocalAttentionSpec
), (
"FullAttentionManager can only be used for full attention "
"and chunked local attention groups"
@ -785,14 +784,35 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
raise NotImplementedError("CrossAttentionManager does not support caching")
class SinkFullAttentionManager(FullAttentionManager):
    """Full-attention KV cache manager that permanently reserves the first
    blocks of the pool for static sink tokens shared across all requests."""

    def __init__(
        self,
        kv_cache_spec: KVCacheSpec,
        block_pool: BlockPool,
        kv_cache_group_id: int,
        dcp_world_size: int = 1,
        pcp_world_size: int = 1,
    ):
        super().__init__(
            kv_cache_spec, block_pool, kv_cache_group_id, dcp_world_size, pcp_world_size
        )
        # Blocks removed from the free pool to hold the sink KV. Always
        # defined (the original left the attribute unset when sink_len <= 0).
        self.sink_blocks: list = []
        sink_len = kv_cache_spec.sink_len or 0
        if sink_len > 0:
            # Sink tokens must fill whole blocks so the reserved region lines
            # up with the block-table shift done by the attention builder.
            assert sink_len % self.block_size == 0
            num_sink_block = sink_len // self.block_size
            self.sink_blocks = self.block_pool.free_block_queue.popleft_n(
                num_sink_block
            )
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager,
FullDiffkvAttentionSpec: FullAttentionManager,
MLAAttentionSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
MambaSpec: MambaManager,
CrossAttentionSpec: CrossAttentionManager,
SinkFullAttentionSpec: SinkFullAttentionManager,
}

View File

@ -80,6 +80,7 @@ class AttentionSpec(KVCacheSpec):
@dataclass(frozen=True)
class FullAttentionSpec(AttentionSpec):
head_size_v: int | None = None
sliding_window: int | None = None
attention_chunk_size: int | None = None
"""
@ -92,6 +93,10 @@ class FullAttentionSpec(AttentionSpec):
Default to None for not using sliding window attention.
"""
def __post_init__(self):
if self.head_size_v is None:
object.__setattr__(self, "head_size_v", self.head_size)
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len
dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
@ -124,88 +129,6 @@ class FullAttentionSpec(AttentionSpec):
"All attention layers in the same KV cache group must be FullAttentionSpec."
)
sliding_window = set(
spec.sliding_window for spec in specs if spec.sliding_window is not None
)
attention_chunk_size = set(
spec.attention_chunk_size
for spec in specs
if spec.attention_chunk_size is not None
)
assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
"MLAAttentionSpec should be merged in MLAAttentionSpec.merge"
)
merged_spec = cls(
block_size=specs[0].block_size,
num_kv_heads=specs[0].num_kv_heads,
head_size=specs[0].head_size,
dtype=specs[0].dtype,
sliding_window=cls.merge_window_sizes(sliding_window),
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
)
for spec in specs:
for f in fields(AttentionSpec):
assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
"All attention layers in the same KV cache group must have "
"the same attention spec."
)
assert (merged_spec.sliding_window is not None) + (
merged_spec.attention_chunk_size is not None
) <= 1, (
"Model with both sliding window layers and chunked local attention "
"layers is not supported."
)
return merged_spec
@dataclass(frozen=True)
class FullDiffkvAttentionSpec(AttentionSpec):
head_size_v: int
sliding_window: int | None = None
attention_chunk_size: int | None = None
"""
When hybrid allocator is disabled and the model contains both full
attention layers and sliding window attention layers, sliding
window attention are regarded as full attention in KV cache manager
(blocks are allocated for all tokens), while computed as sliding window
attention in model runner.
In this case, we use FullDiffkvAttentionSpec and record the sliding window size.
Default to None for not using sliding window attention.
"""
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len
dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
# Note(hc): each dcp rank only need save
# (max_model_len//dcp_world_size) tokens locally.
if dcp_world_size > 1:
max_model_len = cdiv(max_model_len, dcp_world_size)
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@classmethod
def merge_window_sizes(cls, window_sizes: set[int]) -> int | None:
if len(window_sizes) == 0:
return None
elif len(window_sizes) == 1:
return window_sizes.pop()
else:
raise ValueError(
"All attention layers in the same KV cache group must have the "
"same window size."
)
@classmethod
def merge(cls, specs: list[Self]) -> Self:
"""
Merge a list of FullDiffkvAttentionSpec objects into a single
FullDiffkvAttentionSpec object.
"""
assert all(isinstance(spec, FullDiffkvAttentionSpec) for spec in specs), (
"All attention layers in the same KV cache group must be "
"FullDiffkvAttentionSpec."
)
sliding_window = set(
spec.sliding_window for spec in specs if spec.sliding_window is not None
)
@ -376,6 +299,56 @@ class CrossAttentionSpec(AttentionSpec):
return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes
# BUG FIX: decorator argument was misspelled `forzen=True`, which raises
# TypeError the moment this module is imported.
@dataclass(frozen=True)
class SinkFullAttentionSpec(FullAttentionSpec):
    """Full-attention KV cache spec with ``sink_len`` static sink tokens
    reserved at the front of the cache."""

    # Number of sink tokens; None/0 means no sink region.
    sink_len: int | None = None

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of FullAttentionSpec objects into a single
        FullAttentionSpec object.
        """
        # NOTE(review): checks against FullAttentionSpec (the base), mirroring
        # FullAttentionSpec.merge — confirm whether SinkFullAttentionSpec
        # should be required here instead.
        assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be FullAttentionSpec."
        )
        sliding_window = set(
            spec.sliding_window for spec in specs if spec.sliding_window is not None
        )
        attention_chunk_size = set(
            spec.attention_chunk_size
            for spec in specs
            if spec.attention_chunk_size is not None
        )
        assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "MLAAttentionSpec should be merged in MLAAttentionSpec.merge"
        )
        merged_spec = cls(
            block_size=specs[0].block_size,
            num_kv_heads=specs[0].num_kv_heads,
            head_size=specs[0].head_size,
            head_size_v=specs[0].head_size_v,
            sink_len=specs[0].sink_len,
            dtype=specs[0].dtype,
            sliding_window=cls.merge_window_sizes(sliding_window),
            attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
        )
        # Every base AttentionSpec field must agree across the group.
        for spec in specs:
            for f in fields(AttentionSpec):
                assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
                    "All attention layers in the same KV cache group must have "
                    "the same attention spec."
                )
        assert (merged_spec.sliding_window is not None) + (
            merged_spec.attention_chunk_size is not None
        ) <= 1, (
            "Model with both sliding window layers and chunked local attention "
            "layers is not supported."
        )
        return merged_spec
@dataclass(frozen=True)
class UniformTypeKVCacheSpecs(KVCacheSpec):
"""

View File

@ -263,6 +263,7 @@ class MultiGroupBlockTable:
kernel_block_sizes: list[int],
num_speculative_tokens: int = 0,
cp_kv_cache_interleave_size: int = 1,
sink_len: int = 0,
) -> None:
# Note(hc): each dcp rank only store
# (max_model_len//dcp_world_size) tokens in kvcache,
@ -292,7 +293,7 @@ class MultiGroupBlockTable:
block_size,
max_num_reqs,
max(
cdiv(max_model_len, block_size * total_cp_world_size),
cdiv(max_model_len + sink_len, block_size * total_cp_world_size),
1 + num_speculative_tokens,
),
max_num_batched_tokens,

View File

@ -101,16 +101,25 @@ def _reshape_kv_cache(
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
attn_backend = attn_backends[layer_name]
if hasattr(kv_cache_spec, "head_size_v"):
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
stride_kwargs = {"diff_kv": True}
else:
kwargs = {}
stride_kwargs = {}
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
**kwargs,
)
# FIXME(woosuk): Add kv_cache_stride_order to all attention backends.
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
**stride_kwargs
)
assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))

View File

@ -143,7 +143,7 @@ class InputBatch:
# Block table.
self.block_table = MultiGroupBlockTable(
max_num_reqs=max_num_reqs,
max_model_len=max_model_len + sink_len,
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
pin_memory=pin_memory,
device=device,
@ -151,6 +151,7 @@ class InputBatch:
kernel_block_sizes=kernel_block_sizes,
num_speculative_tokens=num_speculative_tokens,
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
sink_len=sink_len,
)
# Sampling-related.

View File

@ -27,9 +27,6 @@ from vllm.attention.backends.abstract import (
MultipleOf,
)
from vllm.attention.layer import Attention, MLAAttention
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash_diffkv,
)
from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphStat, CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
@ -5209,16 +5206,25 @@ class GPUModelRunner(
)
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
if hasattr(kv_cache_spec, "head_size_v"):
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
stride_kwargs = {"diff_kv": True}
else:
kwargs = {}
stride_kwargs = {}
kv_cache_shape = attn_backend.get_kv_cache_shape(
kernel_num_blocks,
kernel_block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype_str=self.cache_config.cache_dtype,
**kwargs,
)
dtype = kv_cache_spec.dtype
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
**stride_kwargs
)
assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
@ -5410,7 +5416,6 @@ class GPUModelRunner(
kv_caches = self.initialize_kv_cache_tensors(
kv_cache_config, kernel_block_sizes
)
self.prepare_sink_kv_cache(kv_caches)
if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
@ -5501,36 +5506,3 @@ class GPUModelRunner(
self.transfer_event.record()
self.transfer_event.synchronize()
return pinned.tolist()
def prepare_sink_kv_cache(self, kv_caches) -> None:
    """Write each attention layer's sink-token KV into its KV cache.

    For every layer whose attention module exposes ``get_sink_kv()``, the
    returned sink key/value tensors are scattered into the layer's cache via
    ``triton_reshape_and_cache_flash_diffkv``, occupying ``self.sink_len``
    slots that start one block past the beginning of the cache.

    Args:
        kv_caches: mapping of layer name -> KV cache tensor for that layer.
    """
    # No sink tokens configured for this model run; nothing to populate.
    if self.sink_len == 0:
        return

    def find_module_by_name(model, target_name: str):
        # Linear scan over all submodules; raises KeyError if absent so a
        # malformed layer name fails loudly instead of silently skipping.
        for name, module in model.named_modules():
            if name == target_name:
                return module
        raise KeyError(f"Module '{target_name}' not found")

    for layer_name, kv_cache in kv_caches.items():
        # layer_name appears to be "<attn-module-path>.<leaf>"; the owning
        # attention module is the parent path — TODO confirm naming scheme.
        layer_prefix = layer_name.rsplit(".", 1)[0]
        self_attn_module = find_module_by_name(self.model, layer_prefix)
        # Only modules that provide sink KV participate; others are skipped.
        if not hasattr(self_attn_module, "get_sink_kv"):
            continue
        else:
            # Expected keys: "sink_key" / "sink_value" (used below).
            sink_kv = self_attn_module.get_sink_kv()
            # Slot indices [block_size, block_size + sink_len): the sink
            # tokens are placed starting at the second cache block —
            # NOTE(review): presumably block 0 is reserved; verify against
            # the block-table layout.
            sink_kv_slot_mapping = torch.arange(
                self.vllm_config.cache_config.block_size,
                self.sink_len + self.vllm_config.cache_config.block_size,
                device=torch.cuda.current_device(),
                dtype=torch.long,
            )
            # Scatter sink K/V into the cache using the layer's cache dtype
            # and quantization scales.
            triton_reshape_and_cache_flash_diffkv(
                sink_kv["sink_key"],
                sink_kv["sink_value"],
                kv_cache,
                sink_kv_slot_mapping,
                self_attn_module.attn.kv_cache_dtype,
                self_attn_module.attn._k_scale,
                self_attn_module.attn._v_scale,
            )

View File

@ -190,17 +190,25 @@ class KVConnectorModelRunnerMixin:
return False
attn_backend = attn_group.backend
if hasattr(kv_cache_spec, "head_size_v"):
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
stride_kwargs = {"diff_kv": True}
else:
kwargs = {}
stride_kwargs = {}
kv_cache_shape = attn_backend.get_kv_cache_shape(
1234,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype_str=cache_dtype,
**kwargs,
)
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=True
include_num_layers_dimension=True,
**stride_kwargs,
)
except (AttributeError, NotImplementedError):
return False
@ -257,12 +265,19 @@ class KVConnectorModelRunnerMixin:
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
attn_backend = attn_group.backend
if hasattr(kv_cache_spec, "head_size_v"):
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
stride_kwargs = {"diff_kv": True}
else:
kwargs = {}
stride_kwargs = {}
kv_cache_shape = attn_backend.get_kv_cache_shape(
kernel_num_blocks,
kernel_block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype_str=cache_dtype,
**kwargs,
)
# prepend a num_layers dimension into the shape
@ -270,7 +285,8 @@ class KVConnectorModelRunnerMixin:
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=True
include_num_layers_dimension=True,
**stride_kwargs,
)
assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):