Refactor code: make the attn backend focus on diffkv and move sink logic to GPUModelRunner

Signed-off-by: yuantao <2422264527@qq.com>
yuantao 2025-11-26 11:34:23 +08:00
parent dff2694aad
commit 315e3f654a
8 changed files with 277 additions and 205 deletions
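
At a high level, the refactor moves sink bookkeeping out of the attention backend and into the model runner. A minimal sketch of the new division of labor (illustrative pseudocode with simplified names, not the literal code below):

# before: every attention forward carried the learned sink K/V explicitly
attn(q, k, v, sink_key=sink_k, sink_value=sink_v)

# after: GPUModelRunner writes the sink K/V into reserved cache blocks once
# (prepare_sink_kv_cache) and prepends those blocks to each request's block
# table, so the backend only has to handle asymmetric K/V head sizes (diffkv)
attn(q, k, v)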

View File

@ -42,8 +42,8 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
"""
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
FLASH_SINK_ATTN = (
"vllm.v1.attention.backends.flash_sink_attn.FlashSinkAttentionBackend"
FLASH_DIFFKV_ATTN = (
"vllm.v1.attention.backends.flash_diffkv_attn.FlashDiffkvAttentionBackend"
)
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"

View File

@ -285,8 +285,7 @@ class Attention(nn.Module, AttentionLayerBase):
kv_sharing_target_layer_name,
**extra_impl_args,
)
backend_name = self.attn_backend.get_name()
self.backend = AttentionBackendEnum.__members__.get(backend_name)
self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
self.dtype = dtype
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
@ -902,20 +901,10 @@ def unified_attention_with_output(
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
sink_key: torch.Tensor | None = None,
sink_value: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> None:
attn_metadata, self, kv_cache = get_attention_context(layer_name)
kwargs = {}
if sink_key is not None or sink_value is not None:
assert sink_key is not None and sink_value is not None, (
"Currently, it is only supported when "
"sink_key and sink_value are both not None"
)
kwargs["sink_key"] = sink_key
kwargs["sink_value"] = sink_value
self.impl.forward(
self,
@ -927,7 +916,6 @@ def unified_attention_with_output(
output=output,
output_scale=output_scale,
output_block_scale=output_block_scale,
**kwargs,
)
@ -937,8 +925,6 @@ def unified_attention_with_output_fake(
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
sink_key: torch.Tensor | None = None,
sink_value: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> None:

View File

@ -81,12 +81,12 @@ from vllm.model_executor.models.utils import (
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backends.flash_sink_attn import FlashSinkAttentionBackend
from vllm.transformers_utils.config import set_default_rope_theta
from vllm.v1.attention.backends.flash_diffkv_attn import FlashDiffkvAttentionBackend
from vllm.v1.kv_cache_interface import (
FullSinkAttentionSpec,
FullDiffkvAttentionSpec,
KVCacheSpec,
)
from vllm.transformers_utils.config import set_default_rope_theta
def check_ffn_act_fn(act_fn: str):
@ -96,7 +96,7 @@ def check_ffn_act_fn(act_fn: str):
)
class AttentionWithSink(Attention):
class DiffkvAttention(Attention):
def __init__(
self,
num_heads: int,
@ -138,9 +138,6 @@ class AttentionWithSink(Attention):
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
# For attention with sink, we have sink k, v
sink_key: torch.Tensor | None = None,
sink_value: torch.Tensor | None = None,
output_shape: torch.Size | None = None,
) -> torch.Tensor:
"""
@ -194,8 +191,6 @@ class AttentionWithSink(Attention):
self_kv_cache,
attn_metadata,
output=output,
sink_key=sink_key,
sink_value=sink_value,
)
else:
torch.ops.vllm.unified_attention_with_output(
@ -204,13 +199,11 @@ class AttentionWithSink(Attention):
value,
output,
self.layer_name,
sink_key=sink_key,
sink_value=sink_value,
)
return output.view(-1, hidden_size)
else:
raise ValueError(
"Unsupport Error, currently only flash_sink_attn "
"Unsupport Error, currently only flash_diffkv_attn "
"backend with output buffer is supported"
)
@ -221,7 +214,7 @@ class AttentionWithSink(Attention):
assert self.attn_type == AttentionType.DECODER
# Only full attention is supported for now.
assert self.sliding_window is None
return FullSinkAttentionSpec(
return FullDiffkvAttentionSpec(
block_size=block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
@ -682,15 +675,14 @@ class OpenPanguEmbeddedAttention(nn.Module):
)
class OpenPanguSinkAttention(nn.Module):
class OpenPanguDiffkvAttention(nn.Module):
def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: dict[str, Any] | None = None,
rope_parameters: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
@ -739,7 +731,6 @@ class OpenPanguSinkAttention(nn.Module):
self.k_size = self.num_kv_heads * self.head_dim
self.v_size = self.num_kv_heads * self.v_channels
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.param_sink_number = getattr(config, "param_sink_number", 0)
@ -770,7 +761,7 @@ class OpenPanguSinkAttention(nn.Module):
self.k_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self._init_rotary_emb(
config, rope_scaling=rope_scaling, quant_config=quant_config
config, rope_parameters=rope_parameters, quant_config=quant_config
)
if hasattr(config, "interleaved_sliding_window"):
@ -788,10 +779,8 @@ class OpenPanguSinkAttention(nn.Module):
else:
sliding_window = None
FlashSinkAttentionBackend.set_cache_head_size_ratio(
(self.head_dim + self.v_channels) / self.head_dim
)
self.attn = AttentionWithSink(
FlashDiffkvAttentionBackend.set_head_size_v(self.v_channels)
self.attn = DiffkvAttention(
self.num_heads,
self.head_dim,
self.v_channels,
@ -802,7 +791,7 @@ class OpenPanguSinkAttention(nn.Module):
per_layer_sliding_window=sliding_window,
attn_type=attn_type,
prefix=f"{prefix}.attn",
attn_backend=FlashSinkAttentionBackend,
attn_backend=FlashDiffkvAttentionBackend,
)
if self.param_sink_number > 0:
@ -904,13 +893,6 @@ class OpenPanguSinkAttention(nn.Module):
q = q.view(-1, self.q_size)
k = k.view(-1, self.k_size)
param_sink_key = self.param_sink_key
if (
self.param_sink_number > 0
and hasattr(self, "k_layernorm")
and self.k_layernorm is not None
):
param_sink_key = self.k_layernorm(param_sink_key)
attn_output = self.attn(
q,
@ -919,23 +901,14 @@ class OpenPanguSinkAttention(nn.Module):
output_shape=torch.Size(
[q.shape[0], q.shape[1] // self.head_dim * self.v_channels]
),
**(
dict(
sink_key=param_sink_key,
sink_value=self.param_sink_value,
)
if self.param_sink_number > 0
else {}
),
)
attn_output = attn_output.reshape(-1, self.num_heads * self.v_channels)
output, _ = self.o_proj(attn_output)
return output
def _init_rotary_emb(
self,
config: PretrainedConfig,
rope_scaling: dict[str, Any] | None,
rope_parameters: dict[str, Any] | None,
quant_config: QuantizationConfig | None,
) -> None:
is_neox_style = False
@ -944,11 +917,24 @@ class OpenPanguSinkAttention(nn.Module):
self.head_dim,
rotary_dim=self.qk_rope_dim,
max_position=self.max_position_embeddings,
base=self.rope_theta,
rope_scaling=rope_scaling,
rope_parameters=rope_parameters,
is_neox_style=is_neox_style,
)
def get_sink_kv(self) -> dict[str, torch.Tensor]:
if self.param_sink_number == 0:
raise ValueError("No sink_key and sink_value when param_sink_number == 0")
if hasattr(self, "k_layernorm") and self.k_layernorm is not None:
param_sink_key = self.k_layernorm(self.param_sink_key)
else:
param_sink_key = self.param_sink_key
return {
"sink_key": param_sink_key,
"sink_value": self.param_sink_value,
}
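
The learned sink K/V now stay out of the per-step forward path; the runner fetches them once through this accessor (see prepare_sink_kv_cache in the GPUModelRunner hunk below). A sketch of the expected shapes, taken from the docstrings this commit removes from the backend:

sink_kv = self_attn.get_sink_kv()
# sink_kv["sink_key"]:   [param_sink_number, num_kv_heads, head_dim]
# sink_kv["sink_value"]: [param_sink_number, num_kv_heads, v_channels]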
class OpenPanguDecoderLayer(nn.Module):
def __init__(
@ -1011,15 +997,20 @@ class OpenPanguDecoderLayer(nn.Module):
f"is_causal={config.is_causal} is not support "
"for attention with sink"
)
self.self_attn = OpenPanguSinkAttention(
rope_parameters = getattr(config, "rope_scaling", None)
if rope_parameters is None:
rope_parameters = {
"rope_type": "default",
"rope_theta": config.rope_theta,
}
self.self_attn = OpenPanguDiffkvAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads
),
rope_theta=rope_theta,
rope_scaling=getattr(config, "rope_scaling", None),
rope_parameters=rope_parameters,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,

View File

@ -16,6 +16,7 @@ from vllm.attention.backends.abstract import (
is_quantized_kv_cache,
)
from vllm.attention.layer import Attention
from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash_diffkv,
@ -32,8 +33,9 @@ if is_flash_attn_varlen_func_available():
flash_attn_varlen_func,
get_scheduler_metadata,
)
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
@ -54,24 +56,39 @@ from .flash_attn import FlashAttentionMetadata
logger = init_logger(__name__)
class FlashSinkAttentionBackend(AttentionBackend):
class FlashDiffkvAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
# NOTE(tdoublep): while in principle, FA supports
# MultipleOf(16), these are the block sizes that do not
# suffer from the NaN propagation problem described here:
# https://github.com/Dao-AILab/flash-attention/issues/1974
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
# TODO: Remove hard code
cache_head_size_ratio: float = 2.0
head_size_v: int = 128
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
if (
model_config
and model_config.is_hybrid
and (
cache_config.mamba_ssm_cache_dtype == "float32"
or cache_config.mamba_cache_dtype == "float32"
)
):
# NOTE(tdoublep): while in principle, FA supports
# MultipleOf(16), these are the block sizes that do not
# suffer from the NaN propagation problem described here:
# https://github.com/Dao-AILab/flash-attention/issues/1974
return [16, 32, 64]
return [MultipleOf(16)]
@staticmethod
def get_name() -> str:
return "FLASH_SINK_ATTN"
return "FLASH_DIFFKV_ATTN"
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""FlashSinkAttention supports all attention types."""
"""FlashDiffkvAttention supports all attention types."""
from vllm.attention import AttentionType
return attn_type in (
@ -82,16 +99,16 @@ class FlashSinkAttentionBackend(AttentionBackend):
)
@staticmethod
def get_impl_cls() -> type["FlashSinkAttentionImpl"]:
return FlashSinkAttentionImpl
def get_impl_cls() -> type["FlashDiffkvAttentionImpl"]:
return FlashDiffkvAttentionImpl
@staticmethod
def get_builder_cls() -> type["FlashSinkAttentionMetadataBuilder"]:
return FlashSinkAttentionMetadataBuilder
def get_builder_cls() -> type["FlashDiffkvAttentionMetadataBuilder"]:
return FlashDiffkvAttentionMetadataBuilder
@classmethod
def set_cache_head_size_ratio(cls, ratio: float) -> None:
cls.cache_head_size_ratio = ratio
def set_head_size_v(cls, head_size_v: int) -> None:
cls.head_size_v = head_size_v
@staticmethod
def get_kv_cache_shape(
@ -107,16 +124,24 @@ class FlashSinkAttentionBackend(AttentionBackend):
num_blocks,
block_size,
num_kv_heads,
int(head_size * FlashSinkAttentionBackend.cache_head_size_ratio),
head_size + FlashDiffkvAttentionBackend.head_size_v,
)
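
DiffKV packs keys and values of different head sizes along the last cache dimension, which is why the shape ends in head_size + head_size_v instead of carrying a separate leading K/V dimension. A short sketch of how the packed cache is split apart again (the same slicing the impl uses below), with assumed sizes head_size=192 and head_size_v=128:

# kv_cache: [num_blocks, block_size, num_kv_heads, 192 + 128]
key_cache = kv_cache[..., :192]    # keys, head size 192
value_cache = kv_cache[..., 192:]  # values, head size 128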
@staticmethod
def get_kv_cache_stride_order() -> tuple[int, ...]:
def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout = get_kv_cache_layout()
if cache_layout == "NHD":
if cache_layout == "NHD" and include_num_layers_dimension:
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return (0, 1, 2, 3, 4)
elif cache_layout == "NHD":
stride_order = (0, 1, 2, 3)
elif cache_layout == "HND" and include_num_layers_dimension:
# (num_blocks, num_kv_heads, num_layers, block_size, head_size)
return (2, 3, 0, 1, 4)
elif cache_layout == "HND":
stride_order = (0, 2, 1, 3)
else:
@ -131,8 +156,8 @@ class FlashSinkAttentionBackend(AttentionBackend):
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
def supports_head_size(cls, head_size: int) -> bool:
return head_size % 8 == 0 and head_size <= 256
@classmethod
def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
@ -176,12 +201,12 @@ def _get_sliding_window_configs(
sliding_window_configs: set[tuple[int, int] | None] = set()
layers = get_layers_from_vllm_config(vllm_config, Attention)
for layer in layers.values():
assert isinstance(layer.impl, FlashSinkAttentionImpl)
assert isinstance(layer.impl, FlashDiffkvAttentionImpl)
sliding_window_configs.add(layer.impl.sliding_window)
return sliding_window_configs
class FlashSinkAttentionMetadataBuilder(
class FlashDiffkvAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]
):
# FA3:
@ -242,8 +267,8 @@ class FlashSinkAttentionMetadataBuilder(
self.dcp_world_size = 1
self.dcp_rank = 0
self.dcp_kv_cache_interleave_size = (
self.parallel_config.dcp_kv_cache_interleave_size
self.cp_kv_cache_interleave_size = (
self.parallel_config.cp_kv_cache_interleave_size
)
self.use_full_cuda_graph = (
@ -322,7 +347,7 @@ class FlashSinkAttentionMetadataBuilder(
):
cache_dtype = self.cache_config.cache_dtype
if cache_dtype.startswith("fp8"):
qkv_dtype = FlashSinkAttentionBackend.get_fp8_dtype_for_flashattn(
qkv_dtype = FlashDiffkvAttentionBackend.get_fp8_dtype_for_flashattn(
cache_dtype
)
else:
@ -365,7 +390,7 @@ class FlashSinkAttentionMetadataBuilder(
dcp_context_kv_lens_cpu,
self.dcp_world_size,
self.dcp_rank,
self.dcp_kv_cache_interleave_size,
self.cp_kv_cache_interleave_size,
)
dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
max_dcp_context_kv_len = dcp_context_kv_lens.max().item()
@ -450,7 +475,7 @@ class FlashSinkAttentionMetadataBuilder(
return use_cascade_attention(*args, **kwargs)
class FlashSinkAttentionImpl(AttentionImpl):
class FlashDiffkvAttentionImpl(AttentionImpl):
can_return_lse_for_decode: bool = True
def __init__(
@ -466,11 +491,9 @@ class FlashSinkAttentionImpl(AttentionImpl):
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None,
sinks: torch.Tensor | None = None,
head_size_v: int | None = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.head_size_v = head_size_v
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
@ -512,7 +535,7 @@ class FlashSinkAttentionImpl(AttentionImpl):
)
def supports_quant_query_input(self) -> bool:
return False
return True
def forward(
self,
@ -525,10 +548,8 @@ class FlashSinkAttentionImpl(AttentionImpl):
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
sink_key: torch.Tensor | None = None,
sink_value: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with FlashSinkAttention.
"""Forward pass with FlashAttention.
Args:
query: shape = [num_tokens, num_heads, head_size]
@ -537,8 +558,6 @@ class FlashSinkAttentionImpl(AttentionImpl):
kv_cache: shape =
[num_blocks, block_size, num_kv_heads, head_size + head_size_v]
attn_metadata: Metadata for attention.
sink_key: shape = [sink_len, num_kv_heads, head_size]
sink_value: shape = [sink_len, num_kv_heads, head_size_v]
Returns:
shape = [num_tokens, num_heads * head_size_v]
NOTE: FP8 quantization, flash-attn expect the size of
@ -546,14 +565,11 @@ class FlashSinkAttentionImpl(AttentionImpl):
We use torch's .expand() to avoid duplicating values
"""
assert output is not None, "Output tensor must be provided."
assert sink_key is not None and sink_value is not None, (
"sink_key and sink_value must be provided"
)
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not supported yet "
"for FlashSinkAttentionImpl"
"fused output quantization is not yet supported for"
"FlashDiffkvAttentionImpl"
)
if attn_metadata is None:
@ -572,7 +588,6 @@ class FlashSinkAttentionImpl(AttentionImpl):
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
sink_len = sink_key.shape[0]
# Handle encoder attention differently - no KV cache needed
if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
@ -585,11 +600,11 @@ class FlashSinkAttentionImpl(AttentionImpl):
output[:num_actual_tokens],
attn_metadata,
layer,
sink_key,
sink_value,
)
# For decoder and cross-attention, use KV cache as before
key_cache = kv_cache[..., : self.head_size]
value_cache = kv_cache[..., self.head_size :]
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
@ -606,29 +621,6 @@ class FlashSinkAttentionImpl(AttentionImpl):
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
# store sink_key and sink_value in head blocks
key_cache = kv_cache[..., : self.head_size]
value_cache = kv_cache[..., self.head_size :]
block_size = key_cache.shape[1]
assert sink_len % block_size == 0
num_sink_blocks = sink_len // block_size
sink_kv_slot_mapping = torch.arange(
block_size,
sink_len + block_size,
device=attn_metadata.slot_mapping.device,
dtype=attn_metadata.slot_mapping.dtype,
)
triton_reshape_and_cache_flash_diffkv(
sink_key,
sink_value,
kv_cache,
sink_kv_slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
triton_reshape_and_cache_flash_diffkv(
key,
value,
@ -641,7 +633,7 @@ class FlashSinkAttentionImpl(AttentionImpl):
if self.kv_cache_dtype.startswith("fp8"):
# queries are quantized in the attention layer
dtype = FlashSinkAttentionBackend.get_fp8_dtype_for_flashattn(
dtype = FlashDiffkvAttentionBackend.get_fp8_dtype_for_flashattn(
self.kv_cache_dtype
)
key_cache = key_cache.view(dtype)
@ -649,29 +641,28 @@ class FlashSinkAttentionImpl(AttentionImpl):
if not attn_metadata.use_cascade:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens + sink_len
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len + sink_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
scheduler_metadata = attn_metadata.scheduler_metadata
sink_block_table = torch.arange(
1,
num_sink_blocks + 1,
device=block_table.device,
dtype=block_table.dtype,
)
sink_block_table = sink_block_table[None, :].expand(
block_table.shape[0], -1
)
block_table = torch.cat((sink_block_table, block_table), dim=1)
descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
if self.dcp_world_size > 1:
raise ValueError(
"Decode context parallel is not supported yet "
f"for dcp_world_size = {self.dcp_world_size}"
self._forward_with_dcp(
query[:num_actual_tokens],
key[:num_actual_tokens],
value[:num_actual_tokens],
key_cache,
value_cache,
output[:num_actual_tokens],
attn_metadata,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
return output
else:
flash_attn_varlen_func(
q=query[:num_actual_tokens],
@ -724,10 +715,89 @@ class FlashSinkAttentionImpl(AttentionImpl):
k_descale=layer._k_scale,
v_descale=layer._v_scale,
s_aux=self.sinks,
sink_len=sink_len,
)
return output
def _forward_with_dcp(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
output: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
q_descale: torch.Tensor | None = None,
k_descale: torch.Tensor | None = None,
v_descale: torch.Tensor | None = None,
) -> torch.Tensor:
cu_seqlens_q = attn_metadata.query_start_loc
max_seqlen_q = attn_metadata.max_query_len
block_table = attn_metadata.block_table
query = query.contiguous()
query_across_dcp = get_dcp_group().all_gather(query, dim=1)
context_attn_out, context_lse = flash_attn_varlen_func(
q=query_across_dcp,
k=key_cache,
v=value_cache,
out=None,
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=attn_metadata.dcp_context_kv_lens,
max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
softmax_scale=self.scale,
causal=False,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
return_softmax_lse=True,
scheduler_metadata=attn_metadata.scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
# FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ]
context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
context_attn_out,
context_lse.transpose(0, 1),
get_dcp_group(),
return_lse=True,
)
context_lse_cor = context_lse_cor.transpose(0, 1).contiguous()
query_attn_out, query_lse = flash_attn_varlen_func(
q=query,
k=key,
v=value,
out=None,
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
cu_seqlens_k=cu_seqlens_q,
max_seqlen_k=max_seqlen_q,
softmax_scale=self.scale,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
softcap=self.logits_soft_cap,
return_softmax_lse=True,
fa_version=self.vllm_flash_attn_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
assert context_attn_out_cor.shape == query_attn_out.shape
assert context_lse_cor.shape == query_lse.shape
merge_attn_states(
output,
context_attn_out_cor,
context_lse_cor,
query_attn_out,
query_lse,
)
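
merge_attn_states combines the two partial attention results using their log-sum-exp statistics. A minimal sketch of the math it implements per query and head (assumed semantics; the real kernel is fused and batched):

import torch

def merge_two_states(o1: torch.Tensor, lse1: torch.Tensor,
                     o2: torch.Tensor, lse2: torch.Tensor) -> torch.Tensor:
    # combined softmax normalizer of the context and query partials
    lse = torch.logaddexp(lse1, lse2)
    w1 = torch.exp(lse1 - lse)  # weight of the context partial
    w2 = torch.exp(lse2 - lse)  # weight of the query partial
    return w1 * o1 + w2 * o2    # numerically stable weighted sum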
def _forward_encoder_attention(
self,
query: torch.Tensor,
@ -736,20 +806,16 @@ class FlashSinkAttentionImpl(AttentionImpl):
output: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
layer: torch.nn.Module,
sink_key: torch.Tensor,
sink_value: torch.Tensor,
) -> torch.Tensor:
"""Forward pass for encoder attention without KV cache.
Args:
query: shape = [num_encoder_tokens, num_heads, head_size]
key: shape = [num_encoder_tokens, num_kv_heads, head_size]
value: shape = [num_encoder_tokens, num_kv_heads, head_size_v]
value: shape = [num_encoder_tokens, num_kv_heads, head_size]
output: shape = [num_encoder_tokens, num_heads, head_size]
attn_metadata: Encoder attention metadata
layer: The attention layer
sink_key: shape = [sink_len, num_kv_heads, head_size]
sink_value: shape = [sink_len, num_kv_heads, head_size_v]
"""
# For encoder attention, process FP8 quantization if needed
if self.kv_cache_dtype.startswith("fp8"):
@ -758,40 +824,10 @@ class FlashSinkAttentionImpl(AttentionImpl):
)
# Use encoder-specific metadata for sequence information
sink_len = sink_key.shape[0]
key_list = []
value_list = []
for seq_id in range(attn_metadata.block_table.shape[0]):
seq_start = attn_metadata.query_start_loc[seq_id]
seq_end = attn_metadata.query_start_loc[seq_id + 1]
key_list.append(
torch.cat(
[
sink_key,
key[seq_start:seq_end],
],
dim=0,
)
)
value_list.append(
torch.cat(
[
sink_value,
value[seq_start:seq_end],
],
dim=0,
)
)
key = torch.cat(key_list, dim=0).contiguous()
value = torch.cat(value_list, dim=0).contiguous()
cu_seqlens_q = attn_metadata.query_start_loc
cu_seqlens_k = attn_metadata.seq_lens + sink_len
cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(cu_seqlens_k, dim=-1), [1, 0], value=0
).to(torch.int32)
cu_seqlens_k = attn_metadata.query_start_loc
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len + sink_len
max_seqlen_k = attn_metadata.max_query_len
descale_shape = (
cu_seqlens_q.shape[0] - 1, # type: ignore[union-attr]
@ -926,14 +962,12 @@ def cascade_attention(
k_descale: torch.Tensor | None = None,
v_descale: torch.Tensor | None = None,
s_aux: torch.Tensor | None = None,
sink_len: int | None = None,
) -> torch.Tensor:
assert alibi_slopes is None, "Cascade attention does not support ALiBi."
# TODO: Support sliding window.
assert sliding_window == (-1, -1), (
"Cascade attention does not support sliding window."
)
assert sink_len is not None, "sink_len must be provided."
num_tokens = query.shape[0]
block_size = key_cache.shape[-3]
@ -942,22 +976,15 @@ def cascade_attention(
assert num_common_kv_blocks > 0
descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
num_sink_blocks = sink_len // block_size
sink_block_table = torch.arange(
1, num_sink_blocks + 1, device=block_table.device, dtype=block_table.dtype
)
sink_block_table = sink_block_table[None, :].expand(block_table.shape[0], -1)
block_table = torch.cat((sink_block_table, block_table), dim=1)
# Process shared prefix.
prefix_output, prefix_lse = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_prefix_query_lens,
seqused_k=prefix_kv_lens + sink_len,
seqused_k=prefix_kv_lens,
max_seqlen_q=num_tokens,
max_seqlen_k=common_prefix_len + sink_len,
max_seqlen_k=common_prefix_len,
softmax_scale=softmax_scale,
causal=False,
window_size=sliding_window,
@ -989,7 +1016,7 @@ def cascade_attention(
softmax_scale=softmax_scale,
causal=True,
window_size=sliding_window,
block_table=block_table[:, num_sink_blocks + num_common_kv_blocks :],
block_table=block_table[:, num_common_kv_blocks:],
softcap=logits_soft_cap,
return_softmax_lse=True,
scheduler_metadata=suffix_scheduler_metadata,

View File

@ -12,7 +12,7 @@ from vllm.v1.kv_cache_interface import (
ChunkedLocalAttentionSpec,
CrossAttentionSpec,
FullAttentionSpec,
FullSinkAttentionSpec,
FullDiffkvAttentionSpec,
KVCacheSpec,
MambaSpec,
MLAAttentionSpec,
@ -311,7 +311,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(
kv_cache_spec,
(FullAttentionSpec, FullSinkAttentionSpec, ChunkedLocalAttentionSpec),
(FullAttentionSpec, FullDiffkvAttentionSpec, ChunkedLocalAttentionSpec),
), (
"FullAttentionManager can only be used for full attention "
"and chunked local attention groups"
@ -733,7 +733,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager,
FullSinkAttentionSpec: FullAttentionManager,
FullDiffkvAttentionSpec: FullAttentionManager,
MLAAttentionSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,

View File

@ -159,7 +159,7 @@ class FullAttentionSpec(AttentionSpec):
@dataclass(frozen=True)
class FullSinkAttentionSpec(AttentionSpec):
class FullDiffkvAttentionSpec(AttentionSpec):
head_size_v: int
sliding_window: int | None = None
attention_chunk_size: int | None = None
@ -170,7 +170,7 @@ class FullSinkAttentionSpec(AttentionSpec):
window attention are regarded as full attention in KV cache manager
(blocks are allocated for all tokens), while computed as sliding window
attention in model runner.
In this case, we use FullSinkAttentionSpec and record the sliding window size.
In this case, we use FullDiffkvAttentionSpec and record the sliding window size.
Default to None for not using sliding window attention.
"""
@ -198,12 +198,12 @@ class FullSinkAttentionSpec(AttentionSpec):
@classmethod
def merge(cls, specs: list[Self]) -> Self:
"""
Merge a list of FullSinkAttentionSpec objects into a single
FullSinkAttentionSpec object.
Merge a list of FullDiffkvAttentionSpec objects into a single
FullDiffkvAttentionSpec object.
"""
assert all(isinstance(spec, FullSinkAttentionSpec) for spec in specs), (
assert all(isinstance(spec, FullDiffkvAttentionSpec) for spec in specs), (
"All attention layers in the same KV cache group must be "
"FullSinkAttentionSpec."
"FullDiffkvAttentionSpec."
)
sliding_window = set(

View File

@ -90,6 +90,7 @@ class InputBatch:
is_pooling_model: bool = False,
num_speculative_tokens: int = 0,
cp_kv_cache_interleave_size: int = 1,
sink_len: int = 0,
):
self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode
@ -136,7 +137,7 @@ class InputBatch:
# Block table.
self.block_table = MultiGroupBlockTable(
max_num_reqs=max_num_reqs,
max_model_len=max_model_len,
max_model_len=max_model_len + sink_len,
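# rows gain room for the sink blocks that get prepended to every block table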
max_num_batched_tokens=max_num_batched_tokens,
pin_memory=pin_memory,
device=device,

View File

@ -25,6 +25,9 @@ from vllm.attention.backends.abstract import (
AttentionMetadata,
MultipleOf,
)
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash_diffkv,
)
from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
@ -321,6 +324,10 @@ class GPUModelRunner(
self.num_query_heads = model_config.get_num_attention_heads(parallel_config)
self.hidden_size = model_config.get_hidden_size()
self.attention_chunk_size = model_config.attention_chunk_size
self.sink_len = getattr(
self.vllm_config.model_config.hf_config, "param_sink_number", 0
)
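# Sink tokens occupy dedicated, whole KV-cache blocks (written once in
# prepare_sink_kv_cache below), so their count must fill blocks exactly.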
assert self.sink_len % self.cache_config.block_size == 0
# Only relevant for models using ALiBi (e.g, MPT)
self.use_alibi = model_config.uses_alibi
@ -443,6 +450,7 @@ class GPUModelRunner(
logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
is_pooling_model=self.is_pooling_model,
cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
sink_len=self.sink_len,
)
self.use_async_scheduling = self.scheduler_config.async_scheduling
@ -1590,6 +1598,28 @@ class GPUModelRunner(
# graph mode.
blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1)
# Modify blk_table_tensor and seq_lens in-place so that attention knows
# the kv_caches already hold sink_key and sink_value entries
if self.sink_len > 0:
seq_lens[:] = seq_lens + self.sink_len
seq_lens_cpu[:] = seq_lens_cpu + self.sink_len
max_seq_len = max_seq_len + self.sink_len
sink_block_table = torch.arange(
1,
self.sink_len // self.cache_config.block_size + 1,
device=blk_table_tensor.device,
dtype=blk_table_tensor.dtype,
)
sink_block_table = sink_block_table[None, :].expand(
blk_table_tensor.shape[0], -1
)
num_sink_blocks = sink_block_table.shape[1]
blk_table_tensor_clone = blk_table_tensor.clone()
blk_table_tensor[:, num_sink_blocks:] = blk_table_tensor_clone[
:, :-num_sink_blocks
]
blk_table_tensor[:, :num_sink_blocks] = sink_block_table
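
A worked example of the shift above, with assumed numbers block_size=128 and sink_len=128 (one sink block, reserved block id 1; block 0 remains the null block):

# before: blk_table_tensor[i] = [17, 42,  9,  0, ...]
# after:  blk_table_tensor[i] = [ 1, 17, 42,  9, ...]
# seq_lens also grew by sink_len, so the kernels attend over block 1 as well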
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
@ -1624,6 +1654,8 @@ class GPUModelRunner(
if cascade_attn_prefix_lens
else 0
)
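# Sink blocks sit at the front of every block table, so they also extend
# the common prefix seen by cascade attention.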
if self.sink_len > 0:
cascade_attn_prefix_len = cascade_attn_prefix_len + self.sink_len
builder = attn_group.get_metadata_builder()
extra_attn_metadata_args = {}
@ -4838,6 +4870,7 @@ class GPUModelRunner(
logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids,
is_pooling_model=self.is_pooling_model,
num_speculative_tokens=self.num_spec_tokens,
sink_len=self.sink_len,
)
def _allocate_kv_cache_tensors(
@ -5165,6 +5198,7 @@ class GPUModelRunner(
kv_caches = self.initialize_kv_cache_tensors(
kv_cache_config, kernel_block_sizes
)
self.prepare_sink_kv_cache(kv_caches)
if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
@ -5269,3 +5303,36 @@ class GPUModelRunner(
self.transfer_event.record()
self.transfer_event.synchronize()
return pinned.tolist()
def prepare_sink_kv_cache(self, kv_caches: dict[str, torch.Tensor]) -> None:
if self.sink_len == 0:
return
def find_module_by_name(model, target_name: str):
for name, module in model.named_modules():
if name == target_name:
return module
raise KeyError(f"Module '{target_name}' not found")
for layer_name, kv_cache in kv_caches.items():
layer_prefix = layer_name.rsplit(".", 1)[0]
self_attn_module = find_module_by_name(self.model, layer_prefix)
if not hasattr(self_attn_module, "get_sink_kv"):
continue
else:
sink_kv = self_attn_module.get_sink_kv()
sink_kv_slot_mapping = torch.arange(
self.vllm_config.cache_config.block_size,
self.sink_len + self.vllm_config.cache_config.block_size,
device=torch.cuda.current_device(),
dtype=torch.long,
)
triton_reshape_and_cache_flash_diffkv(
sink_kv["sink_key"],
sink_kv["sink_value"],
kv_cache,
sink_kv_slot_mapping,
self_attn_module.attn.kv_cache_dtype,
self_attn_module.attn._k_scale,
self_attn_module.attn._v_scale,
)
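
A sketch of the slot arithmetic above, with assumed values block_size=128 and sink_len=256: arange(128, 384) covers slots 128..383, which is exactly blocks 1 and 2, so the sink K/V land in the blocks that the runner prepends to every block table while block 0 is left untouched as the null block.

# block ids: 0 (null) | 1..2 (sink K/V, written once here) | 3.. (per-request)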