mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-31 12:27:11 +08:00
Refactor code, make attn backend focus on diffkv and move sink logic to GPUModelRunner
Signed-off-by: yuantao <2422264527@qq.com>
This commit is contained in:
parent
dff2694aad
commit
315e3f654a
@ -42,8 +42,8 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
|
||||
"""
|
||||
|
||||
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
|
||||
FLASH_SINK_ATTN = (
|
||||
"vllm.v1.attention.backends.flash_sink_attn.FlashSinkAttentionBackend"
|
||||
FLASH_DIFFKV_ATTN = (
|
||||
"vllm.v1.attention.backends.flash_diffkv_attn.FlashDiffkvAttentionBackend"
|
||||
)
|
||||
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
|
||||
ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
|
||||
|
||||
@ -285,8 +285,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
kv_sharing_target_layer_name,
|
||||
**extra_impl_args,
|
||||
)
|
||||
backend_name = self.attn_backend.get_name()
|
||||
self.backend = AttentionBackendEnum.__members__.get(backend_name)
|
||||
self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
|
||||
self.dtype = dtype
|
||||
|
||||
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
|
||||
@ -902,20 +901,10 @@ def unified_attention_with_output(
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
sink_key: torch.Tensor | None = None,
|
||||
sink_value: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
attn_metadata, self, kv_cache = get_attention_context(layer_name)
|
||||
kwargs = {}
|
||||
if sink_key is not None or sink_value is not None:
|
||||
assert sink_key is not None and sink_value is not None, (
|
||||
"Currently, it is only supported when "
|
||||
"sink_key and sink_value are both not None"
|
||||
)
|
||||
kwargs["sink_key"] = sink_key
|
||||
kwargs["sink_value"] = sink_value
|
||||
|
||||
self.impl.forward(
|
||||
self,
|
||||
@ -927,7 +916,6 @@ def unified_attention_with_output(
|
||||
output=output,
|
||||
output_scale=output_scale,
|
||||
output_block_scale=output_block_scale,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@ -937,8 +925,6 @@ def unified_attention_with_output_fake(
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
sink_key: torch.Tensor | None = None,
|
||||
sink_value: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
|
||||
@ -81,12 +81,12 @@ from vllm.model_executor.models.utils import (
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.v1.attention.backends.flash_sink_attn import FlashSinkAttentionBackend
|
||||
from vllm.transformers_utils.config import set_default_rope_theta
|
||||
from vllm.v1.attention.backends.flash_diffkv_attn import FlashDiffkvAttentionBackend
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
FullSinkAttentionSpec,
|
||||
FullDiffkvAttentionSpec,
|
||||
KVCacheSpec,
|
||||
)
|
||||
from vllm.transformers_utils.config import set_default_rope_theta
|
||||
|
||||
|
||||
def check_ffn_act_fn(act_fn: str):
|
||||
@ -96,7 +96,7 @@ def check_ffn_act_fn(act_fn: str):
|
||||
)
|
||||
|
||||
|
||||
class AttentionWithSink(Attention):
|
||||
class DiffkvAttention(Attention):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
@ -138,9 +138,6 @@ class AttentionWithSink(Attention):
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
# For attention with sink, we have sink k, v
|
||||
sink_key: torch.Tensor | None = None,
|
||||
sink_value: torch.Tensor | None = None,
|
||||
output_shape: torch.Size | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@ -194,8 +191,6 @@ class AttentionWithSink(Attention):
|
||||
self_kv_cache,
|
||||
attn_metadata,
|
||||
output=output,
|
||||
sink_key=sink_key,
|
||||
sink_value=sink_value,
|
||||
)
|
||||
else:
|
||||
torch.ops.vllm.unified_attention_with_output(
|
||||
@ -204,13 +199,11 @@ class AttentionWithSink(Attention):
|
||||
value,
|
||||
output,
|
||||
self.layer_name,
|
||||
sink_key=sink_key,
|
||||
sink_value=sink_value,
|
||||
)
|
||||
return output.view(-1, hidden_size)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupport Error, currently only flash_sink_attn "
|
||||
"Unsupport Error, currently only flash_diffkv_attn "
|
||||
"backend with output buffer is supported"
|
||||
)
|
||||
|
||||
@ -221,7 +214,7 @@ class AttentionWithSink(Attention):
|
||||
assert self.attn_type == AttentionType.DECODER
|
||||
# Only support for full attention now.
|
||||
assert self.sliding_window is None
|
||||
return FullSinkAttentionSpec(
|
||||
return FullDiffkvAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_size=self.head_size,
|
||||
@ -682,15 +675,14 @@ class OpenPanguEmbeddedAttention(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
class OpenPanguSinkAttention(nn.Module):
|
||||
class OpenPanguDiffkvAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: dict[str, Any] | None = None,
|
||||
rope_parameters: dict[str, Any] | None = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
bias: bool = False,
|
||||
@ -739,7 +731,6 @@ class OpenPanguSinkAttention(nn.Module):
|
||||
self.k_size = self.num_kv_heads * self.head_dim
|
||||
self.v_size = self.num_kv_heads * self.v_channels
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.param_sink_number = getattr(config, "param_sink_number", 0)
|
||||
@ -770,7 +761,7 @@ class OpenPanguSinkAttention(nn.Module):
|
||||
self.k_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
||||
|
||||
self._init_rotary_emb(
|
||||
config, rope_scaling=rope_scaling, quant_config=quant_config
|
||||
config, rope_parameters=rope_parameters, quant_config=quant_config
|
||||
)
|
||||
|
||||
if hasattr(config, "interleaved_sliding_window"):
|
||||
@ -788,10 +779,8 @@ class OpenPanguSinkAttention(nn.Module):
|
||||
else:
|
||||
sliding_window = None
|
||||
|
||||
FlashSinkAttentionBackend.set_cache_head_size_ratio(
|
||||
(self.head_dim + self.v_channels) / self.head_dim
|
||||
)
|
||||
self.attn = AttentionWithSink(
|
||||
FlashDiffkvAttentionBackend.set_head_size_v(self.v_channels)
|
||||
self.attn = DiffkvAttention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.v_channels,
|
||||
@ -802,7 +791,7 @@ class OpenPanguSinkAttention(nn.Module):
|
||||
per_layer_sliding_window=sliding_window,
|
||||
attn_type=attn_type,
|
||||
prefix=f"{prefix}.attn",
|
||||
attn_backend=FlashSinkAttentionBackend,
|
||||
attn_backend=FlashDiffkvAttentionBackend,
|
||||
)
|
||||
|
||||
if self.param_sink_number > 0:
|
||||
@ -904,13 +893,6 @@ class OpenPanguSinkAttention(nn.Module):
|
||||
|
||||
q = q.view(-1, self.q_size)
|
||||
k = k.view(-1, self.k_size)
|
||||
param_sink_key = self.param_sink_key
|
||||
if (
|
||||
self.param_sink_number > 0
|
||||
and hasattr(self, "k_layernorm")
|
||||
and self.k_layernorm is not None
|
||||
):
|
||||
param_sink_key = self.k_layernorm(param_sink_key)
|
||||
|
||||
attn_output = self.attn(
|
||||
q,
|
||||
@ -919,23 +901,14 @@ class OpenPanguSinkAttention(nn.Module):
|
||||
output_shape=torch.Size(
|
||||
[q.shape[0], q.shape[1] // self.head_dim * self.v_channels]
|
||||
),
|
||||
**(
|
||||
dict(
|
||||
sink_key=param_sink_key,
|
||||
sink_value=self.param_sink_value,
|
||||
)
|
||||
if self.param_sink_number > 0
|
||||
else {}
|
||||
),
|
||||
)
|
||||
attn_output = attn_output.reshape(-1, self.num_heads * self.v_channels)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
def _init_rotary_emb(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
rope_scaling: dict[str, Any] | None,
|
||||
rope_parameters: dict[str, Any] | None,
|
||||
quant_config: QuantizationConfig | None,
|
||||
) -> None:
|
||||
is_neox_style = False
|
||||
@ -944,11 +917,24 @@ class OpenPanguSinkAttention(nn.Module):
|
||||
self.head_dim,
|
||||
rotary_dim=self.qk_rope_dim,
|
||||
max_position=self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
rope_parameters=rope_parameters,
|
||||
is_neox_style=is_neox_style,
|
||||
)
|
||||
|
||||
def get_sink_kv(self) -> dict[str, torch.Tensor]:
    """Return this layer's sink key/value pair as a dict.

    The sink key is passed through ``self.k_layernorm`` when that
    attribute exists and is not ``None``; otherwise the raw
    ``self.param_sink_key`` is used. The sink value is returned unchanged.

    Returns:
        A dict with the entries ``"sink_key"`` and ``"sink_value"``.

    Raises:
        ValueError: If ``self.param_sink_number`` is 0.
    """
    if self.param_sink_number == 0:
        raise ValueError("No sink_key and sink_value when param_sink_number == 0")

    sink_key = self.param_sink_key
    # The key norm is optional: it may be absent or explicitly None.
    key_norm = getattr(self, "k_layernorm", None)
    if key_norm is not None:
        sink_key = key_norm(sink_key)

    return {"sink_key": sink_key, "sink_value": self.param_sink_value}
|
||||
|
||||
|
||||
class OpenPanguDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
@ -1011,15 +997,20 @@ class OpenPanguDecoderLayer(nn.Module):
|
||||
f"is_causal={config.is_causal} is not support "
|
||||
"for attention with sink"
|
||||
)
|
||||
self.self_attn = OpenPanguSinkAttention(
|
||||
rope_parameters = getattr(config, "rope_scaling", None)
|
||||
if rope_parameters is None:
|
||||
rope_parameters = {
|
||||
"rope_type": "default",
|
||||
"rope_theta": config.rope_theta,
|
||||
}
|
||||
self.self_attn = OpenPanguDiffkvAttention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=getattr(
|
||||
config, "num_key_value_heads", config.num_attention_heads
|
||||
),
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=getattr(config, "rope_scaling", None),
|
||||
rope_parameters=rope_parameters,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
bias=attention_bias,
|
||||
|
||||
@ -16,6 +16,7 @@ from vllm.attention.backends.abstract import (
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.ops.common import cp_lse_ag_out_rs
|
||||
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
||||
from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash_diffkv,
|
||||
@ -32,8 +33,9 @@ if is_flash_attn_varlen_func_available():
|
||||
flash_attn_varlen_func,
|
||||
get_scheduler_metadata,
|
||||
)
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.distributed.parallel_state import get_dcp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
@ -54,24 +56,39 @@ from .flash_attn import FlashAttentionMetadata
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashSinkAttentionBackend(AttentionBackend):
|
||||
class FlashDiffkvAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
# NOTE(tdoublep): while in principle, FA supports
|
||||
# MultipleOf(16), these are the block sizes that do not
|
||||
# suffer from the NaN propagation problem described here:
|
||||
# https://github.com/Dao-AILab/flash-attention/issues/1974
|
||||
supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
|
||||
# TODO: Remove hard code
|
||||
cache_head_size_ratio: float = 2.0
|
||||
head_size_v: int = 128
|
||||
|
||||
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
    """Return the kernel block sizes this backend accepts.

    Reads the current vLLM config. For a hybrid model whose mamba SSM cache
    dtype or mamba cache dtype is ``"float32"``, the result is limited to
    the explicit sizes 16, 32 and 64. In every other case any multiple of
    16 is accepted.
    """
    config = get_current_vllm_config()
    model_cfg = config.model_config
    cache_cfg = config.cache_config

    if not (model_cfg and model_cfg.is_hybrid):
        return [MultipleOf(16)]

    uses_fp32_mamba_cache = (
        cache_cfg.mamba_ssm_cache_dtype == "float32"
        or cache_cfg.mamba_cache_dtype == "float32"
    )
    if uses_fp32_mamba_cache:
        # NOTE(tdoublep): while in principle, FA supports
        # MultipleOf(16), these are the block sizes that do not
        # suffer from the NaN propagation problem described here:
        # https://github.com/Dao-AILab/flash-attention/issues/1974
        return [16, 32, 64]
    return [MultipleOf(16)]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASH_SINK_ATTN"
|
||||
return "FLASH_DIFFKV_ATTN"
|
||||
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
"""FlashSinkAttention supports all attention types."""
|
||||
"""FlashDiffkvAttention supports all attention types."""
|
||||
from vllm.attention import AttentionType
|
||||
|
||||
return attn_type in (
|
||||
@ -82,16 +99,16 @@ class FlashSinkAttentionBackend(AttentionBackend):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashSinkAttentionImpl"]:
|
||||
return FlashSinkAttentionImpl
|
||||
def get_impl_cls() -> type["FlashDiffkvAttentionImpl"]:
|
||||
return FlashDiffkvAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashSinkAttentionMetadataBuilder"]:
|
||||
return FlashSinkAttentionMetadataBuilder
|
||||
def get_builder_cls() -> type["FlashDiffkvAttentionMetadataBuilder"]:
|
||||
return FlashDiffkvAttentionMetadataBuilder
|
||||
|
||||
@classmethod
|
||||
def set_cache_head_size_ratio(cls, ratio: float) -> None:
|
||||
cls.cache_head_size_ratio = ratio
|
||||
def set_head_size_v(cls, head_size_v: int) -> None:
|
||||
cls.head_size_v = head_size_v
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
@ -107,16 +124,24 @@ class FlashSinkAttentionBackend(AttentionBackend):
|
||||
num_blocks,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
int(head_size * FlashSinkAttentionBackend.cache_head_size_ratio),
|
||||
head_size + FlashDiffkvAttentionBackend.head_size_v,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_stride_order() -> tuple[int, ...]:
|
||||
def get_kv_cache_stride_order(
|
||||
include_num_layers_dimension: bool = False,
|
||||
) -> tuple[int, ...]:
|
||||
# `stride_order` indicates the permutation that gets
|
||||
# us from `get_kv_cache_shape` to the actual memory layout we want.
|
||||
cache_layout = get_kv_cache_layout()
|
||||
if cache_layout == "NHD":
|
||||
if cache_layout == "NHD" and include_num_layers_dimension:
|
||||
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
|
||||
return (0, 1, 2, 3, 4)
|
||||
elif cache_layout == "NHD":
|
||||
stride_order = (0, 1, 2, 3)
|
||||
elif cache_layout == "HND" and include_num_layers_dimension:
|
||||
# (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
|
||||
return (2, 3, 0, 1, 4)
|
||||
elif cache_layout == "HND":
|
||||
stride_order = (0, 2, 1, 3)
|
||||
else:
|
||||
@ -131,8 +156,8 @@ class FlashSinkAttentionBackend(AttentionBackend):
|
||||
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
def supports_head_size(cls, head_size: int) -> bool:
|
||||
return head_size % 8 == 0 and head_size <= 256
|
||||
|
||||
@classmethod
|
||||
def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
|
||||
@ -176,12 +201,12 @@ def _get_sliding_window_configs(
|
||||
sliding_window_configs: set[tuple[int, int] | None] = set()
|
||||
layers = get_layers_from_vllm_config(vllm_config, Attention)
|
||||
for layer in layers.values():
|
||||
assert isinstance(layer.impl, FlashSinkAttentionImpl)
|
||||
assert isinstance(layer.impl, FlashDiffkvAttentionImpl)
|
||||
sliding_window_configs.add(layer.impl.sliding_window)
|
||||
return sliding_window_configs
|
||||
|
||||
|
||||
class FlashSinkAttentionMetadataBuilder(
|
||||
class FlashDiffkvAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[FlashAttentionMetadata]
|
||||
):
|
||||
# FA3:
|
||||
@ -242,8 +267,8 @@ class FlashSinkAttentionMetadataBuilder(
|
||||
self.dcp_world_size = 1
|
||||
self.dcp_rank = 0
|
||||
|
||||
self.dcp_kv_cache_interleave_size = (
|
||||
self.parallel_config.dcp_kv_cache_interleave_size
|
||||
self.cp_kv_cache_interleave_size = (
|
||||
self.parallel_config.cp_kv_cache_interleave_size
|
||||
)
|
||||
|
||||
self.use_full_cuda_graph = (
|
||||
@ -322,7 +347,7 @@ class FlashSinkAttentionMetadataBuilder(
|
||||
):
|
||||
cache_dtype = self.cache_config.cache_dtype
|
||||
if cache_dtype.startswith("fp8"):
|
||||
qkv_dtype = FlashSinkAttentionBackend.get_fp8_dtype_for_flashattn(
|
||||
qkv_dtype = FlashDiffkvAttentionBackend.get_fp8_dtype_for_flashattn(
|
||||
cache_dtype
|
||||
)
|
||||
else:
|
||||
@ -365,7 +390,7 @@ class FlashSinkAttentionMetadataBuilder(
|
||||
dcp_context_kv_lens_cpu,
|
||||
self.dcp_world_size,
|
||||
self.dcp_rank,
|
||||
self.dcp_kv_cache_interleave_size,
|
||||
self.cp_kv_cache_interleave_size,
|
||||
)
|
||||
dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
|
||||
max_dcp_context_kv_len = dcp_context_kv_lens.max().item()
|
||||
@ -450,7 +475,7 @@ class FlashSinkAttentionMetadataBuilder(
|
||||
return use_cascade_attention(*args, **kwargs)
|
||||
|
||||
|
||||
class FlashSinkAttentionImpl(AttentionImpl):
|
||||
class FlashDiffkvAttentionImpl(AttentionImpl):
|
||||
can_return_lse_for_decode: bool = True
|
||||
|
||||
def __init__(
|
||||
@ -466,11 +491,9 @@ class FlashSinkAttentionImpl(AttentionImpl):
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: str | None = None,
|
||||
sinks: torch.Tensor | None = None,
|
||||
head_size_v: int | None = None,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.head_size_v = head_size_v
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
@ -512,7 +535,7 @@ class FlashSinkAttentionImpl(AttentionImpl):
|
||||
)
|
||||
|
||||
def supports_quant_query_input(self) -> bool:
|
||||
return False
|
||||
return True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -525,10 +548,8 @@ class FlashSinkAttentionImpl(AttentionImpl):
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
sink_key: torch.Tensor | None = None,
|
||||
sink_value: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashSinkAttention.
|
||||
"""Forward pass with FlashAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
@ -537,8 +558,6 @@ class FlashSinkAttentionImpl(AttentionImpl):
|
||||
kv_cache: shape =
|
||||
[num_blocks, block_size, num_kv_heads, head_size + head_size_v]
|
||||
attn_metadata: Metadata for attention.
|
||||
sink_key: shape = [sink_len, num_kv_heads, head_size]
|
||||
sink_value: shape = [sink_len, num_kv_heads, head_size_v]
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size_v]
|
||||
NOTE: FP8 quantization, flash-attn expect the size of
|
||||
@ -546,14 +565,11 @@ class FlashSinkAttentionImpl(AttentionImpl):
|
||||
We use torch's .expand() to avoid duplicating values
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
assert sink_key is not None and sink_value is not None, (
|
||||
"sink_key and sink_value must be provided"
|
||||
)
|
||||
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not supported yet "
|
||||
"for FlashSinkAttentionImpl"
|
||||
"fused output quantization is not yet supported for"
|
||||
"FlashDiffkvAttentionImpl"
|
||||
)
|
||||
|
||||
if attn_metadata is None:
|
||||
@ -572,7 +588,6 @@ class FlashSinkAttentionImpl(AttentionImpl):
|
||||
# performance to make sure it does not introduce any overhead.
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
sink_len = sink_key.shape[0]
|
||||
|
||||
# Handle encoder attention differently - no KV cache needed
|
||||
if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
|
||||
@ -585,11 +600,11 @@ class FlashSinkAttentionImpl(AttentionImpl):
|
||||
output[:num_actual_tokens],
|
||||
attn_metadata,
|
||||
layer,
|
||||
sink_key,
|
||||
sink_value,
|
||||
)
|
||||
|
||||
# For decoder and cross-attention, use KV cache as before
|
||||
key_cache = kv_cache[..., : self.head_size]
|
||||
value_cache = kv_cache[..., self.head_size :]
|
||||
|
||||
# key and value may be None in the case of cross attention. They are
|
||||
# calculated once based on the output from the encoder and then cached
|
||||
@ -606,29 +621,6 @@ class FlashSinkAttentionImpl(AttentionImpl):
|
||||
# and value[:num_actual_tokens] because the reshape_and_cache_flash
|
||||
# op uses the slot_mapping's shape to determine the number of
|
||||
# actual tokens.
|
||||
|
||||
# store sink_key and sink_value in head blocks
|
||||
key_cache = kv_cache[..., : self.head_size]
|
||||
value_cache = kv_cache[..., self.head_size :]
|
||||
block_size = key_cache.shape[1]
|
||||
assert sink_len % block_size == 0
|
||||
num_sink_blocks = sink_len // block_size
|
||||
sink_kv_slot_mapping = torch.arange(
|
||||
block_size,
|
||||
sink_len + block_size,
|
||||
device=attn_metadata.slot_mapping.device,
|
||||
dtype=attn_metadata.slot_mapping.dtype,
|
||||
)
|
||||
triton_reshape_and_cache_flash_diffkv(
|
||||
sink_key,
|
||||
sink_value,
|
||||
kv_cache,
|
||||
sink_kv_slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
triton_reshape_and_cache_flash_diffkv(
|
||||
key,
|
||||
value,
|
||||
@ -641,7 +633,7 @@ class FlashSinkAttentionImpl(AttentionImpl):
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
# queries are quantized in the attention layer
|
||||
dtype = FlashSinkAttentionBackend.get_fp8_dtype_for_flashattn(
|
||||
dtype = FlashDiffkvAttentionBackend.get_fp8_dtype_for_flashattn(
|
||||
self.kv_cache_dtype
|
||||
)
|
||||
key_cache = key_cache.view(dtype)
|
||||
@ -649,29 +641,28 @@ class FlashSinkAttentionImpl(AttentionImpl):
|
||||
|
||||
if not attn_metadata.use_cascade:
|
||||
cu_seqlens_q = attn_metadata.query_start_loc
|
||||
seqused_k = attn_metadata.seq_lens + sink_len
|
||||
seqused_k = attn_metadata.seq_lens
|
||||
max_seqlen_q = attn_metadata.max_query_len
|
||||
max_seqlen_k = attn_metadata.max_seq_len + sink_len
|
||||
max_seqlen_k = attn_metadata.max_seq_len
|
||||
block_table = attn_metadata.block_table
|
||||
scheduler_metadata = attn_metadata.scheduler_metadata
|
||||
sink_block_table = torch.arange(
|
||||
1,
|
||||
num_sink_blocks + 1,
|
||||
device=block_table.device,
|
||||
dtype=block_table.dtype,
|
||||
)
|
||||
sink_block_table = sink_block_table[None, :].expand(
|
||||
block_table.shape[0], -1
|
||||
)
|
||||
block_table = torch.cat((sink_block_table, block_table), dim=1)
|
||||
|
||||
descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
|
||||
|
||||
if self.dcp_world_size > 1:
|
||||
raise ValueError(
|
||||
"Decode context parallel is not supported yet "
|
||||
f"for dcp_world_size = {self.dcp_world_size}"
|
||||
self._forward_with_dcp(
|
||||
query[:num_actual_tokens],
|
||||
key[:num_actual_tokens],
|
||||
value[:num_actual_tokens],
|
||||
key_cache,
|
||||
value_cache,
|
||||
output[:num_actual_tokens],
|
||||
attn_metadata,
|
||||
q_descale=layer._q_scale.expand(descale_shape),
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
)
|
||||
return output
|
||||
else:
|
||||
flash_attn_varlen_func(
|
||||
q=query[:num_actual_tokens],
|
||||
@ -724,10 +715,89 @@ class FlashSinkAttentionImpl(AttentionImpl):
|
||||
k_descale=layer._k_scale,
|
||||
v_descale=layer._v_scale,
|
||||
s_aux=self.sinks,
|
||||
sink_len=sink_len,
|
||||
)
|
||||
return output
|
||||
|
||||
def _forward_with_dcp(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    output: torch.Tensor,
    attn_metadata: FlashAttentionMetadata,
    q_descale: torch.Tensor | None = None,
    k_descale: torch.Tensor | None = None,
    v_descale: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run attention under decode context parallelism (DCP), filling ``output``.

    Two attention passes are computed and then merged:

    1. Context pass: the query is all-gathered across the DCP group and
       attends (non-causally) to the paged ``key_cache`` / ``value_cache``
       using ``attn_metadata.dcp_context_kv_lens``. The partial result and
       its log-sum-exp are then passed through ``cp_lse_ag_out_rs``.
    2. Query pass: the local query attends to the freshly supplied ``key``
       / ``value`` with ``causal=attn_metadata.causal``.

    ``merge_attn_states`` combines both passes into ``output``.

    Args:
        query: Local query tensor; made contiguous before the all-gather.
        key: New key tensor for the current tokens.
        value: New value tensor for the current tokens.
        key_cache: Paged key cache passed as ``k`` to the context pass.
        value_cache: Paged value cache passed as ``v`` to the context pass.
        output: Destination buffer handed to ``merge_attn_states``.
        attn_metadata: Supplies sequence offsets, block table, DCP context
            lengths, scheduler metadata and the causal flag.
        q_descale: Optional query descale forwarded to flash-attn.
        k_descale: Optional key descale forwarded to flash-attn.
        v_descale: Optional value descale forwarded to flash-attn.

    Returns:
        ``output``, the same buffer that was passed in, after merging.
    """
    cu_seqlens_q = attn_metadata.query_start_loc
    max_seqlen_q = attn_metadata.max_query_len
    block_table = attn_metadata.block_table

    query = query.contiguous()
    # NOTE(review): dim=1 is presumably the head dimension — confirm.
    query_across_dcp = get_dcp_group().all_gather(query, dim=1)
    context_attn_out, context_lse = flash_attn_varlen_func(
        q=query_across_dcp,
        k=key_cache,
        v=value_cache,
        out=None,
        cu_seqlens_q=cu_seqlens_q,
        max_seqlen_q=max_seqlen_q,
        seqused_k=attn_metadata.dcp_context_kv_lens,
        max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
        softmax_scale=self.scale,
        causal=False,
        alibi_slopes=self.alibi_slopes,
        window_size=self.sliding_window,
        block_table=block_table,
        softcap=self.logits_soft_cap,
        return_softmax_lse=True,
        scheduler_metadata=attn_metadata.scheduler_metadata,
        fa_version=self.vllm_flash_attn_version,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
    )
    # FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ]
    context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
        context_attn_out,
        context_lse.transpose(0, 1),
        get_dcp_group(),
        return_lse=True,
    )
    # Transpose back so the LSE layout matches the one from the query pass.
    context_lse_cor = context_lse_cor.transpose(0, 1).contiguous()

    query_attn_out, query_lse = flash_attn_varlen_func(
        q=query,
        k=key,
        v=value,
        out=None,
        cu_seqlens_q=cu_seqlens_q,
        max_seqlen_q=max_seqlen_q,
        cu_seqlens_k=cu_seqlens_q,
        max_seqlen_k=max_seqlen_q,
        softmax_scale=self.scale,
        causal=attn_metadata.causal,
        alibi_slopes=self.alibi_slopes,
        window_size=self.sliding_window,
        softcap=self.logits_soft_cap,
        return_softmax_lse=True,
        fa_version=self.vllm_flash_attn_version,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
    )
    assert context_attn_out_cor.shape == query_attn_out.shape
    assert context_lse_cor.shape == query_lse.shape
    merge_attn_states(
        output,
        context_attn_out_cor,
        context_lse_cor,
        query_attn_out,
        query_lse,
    )
    # Honour the declared return type; previously this fell through and
    # returned None despite the `-> torch.Tensor` annotation.
    return output
|
||||
|
||||
def _forward_encoder_attention(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
@ -736,20 +806,16 @@ class FlashSinkAttentionImpl(AttentionImpl):
|
||||
output: torch.Tensor,
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
layer: torch.nn.Module,
|
||||
sink_key: torch.Tensor,
|
||||
sink_value: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass for encoder attention without KV cache.
|
||||
|
||||
Args:
|
||||
query: shape = [num_encoder_tokens, num_heads, head_size]
|
||||
key: shape = [num_encoder_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_encoder_tokens, num_kv_heads, head_size_v]
|
||||
value: shape = [num_encoder_tokens, num_kv_heads, head_size]
|
||||
output: shape = [num_encoder_tokens, num_heads, head_size]
|
||||
attn_metadata: Encoder attention metadata
|
||||
layer: The attention layer
|
||||
sink_key: shape = [sink_len, num_kv_heads, head_size]
|
||||
sink_value: shape = [sink_len, num_kv_heads, head_size_v]
|
||||
"""
|
||||
# For encoder attention, process FP8 quantization if needed
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
@ -758,40 +824,10 @@ class FlashSinkAttentionImpl(AttentionImpl):
|
||||
)
|
||||
|
||||
# Use encoder-specific metadata for sequence information
|
||||
sink_len = sink_key.shape[0]
|
||||
key_list = []
|
||||
value_list = []
|
||||
for seq_id in range(attn_metadata.block_table.shape[0]):
|
||||
seq_start = attn_metadata.query_start_loc[seq_id]
|
||||
seq_end = attn_metadata.query_start_loc[seq_id + 1]
|
||||
key_list.append(
|
||||
torch.cat(
|
||||
[
|
||||
sink_key,
|
||||
key[seq_start:seq_end],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
value_list.append(
|
||||
torch.cat(
|
||||
[
|
||||
sink_value,
|
||||
value[seq_start:seq_end],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
key = torch.cat(key_list, dim=0).contiguous()
|
||||
value = torch.cat(value_list, dim=0).contiguous()
|
||||
|
||||
cu_seqlens_q = attn_metadata.query_start_loc
|
||||
cu_seqlens_k = attn_metadata.seq_lens + sink_len
|
||||
cu_seqlens_k = torch.nn.functional.pad(
|
||||
torch.cumsum(cu_seqlens_k, dim=-1), [1, 0], value=0
|
||||
).to(torch.int32)
|
||||
cu_seqlens_k = attn_metadata.query_start_loc
|
||||
max_seqlen_q = attn_metadata.max_query_len
|
||||
max_seqlen_k = attn_metadata.max_seq_len + sink_len
|
||||
max_seqlen_k = attn_metadata.max_query_len
|
||||
|
||||
descale_shape = (
|
||||
cu_seqlens_q.shape[0] - 1, # type: ignore[union-attr]
|
||||
@ -926,14 +962,12 @@ def cascade_attention(
|
||||
k_descale: torch.Tensor | None = None,
|
||||
v_descale: torch.Tensor | None = None,
|
||||
s_aux: torch.Tensor | None = None,
|
||||
sink_len: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert alibi_slopes is None, "Cascade attention does not support ALiBi."
|
||||
# TODO: Support sliding window.
|
||||
assert sliding_window == (-1, -1), (
|
||||
"Cascade attention does not support sliding window."
|
||||
)
|
||||
assert sink_len is not None, "sink_len must be provided."
|
||||
|
||||
num_tokens = query.shape[0]
|
||||
block_size = key_cache.shape[-3]
|
||||
@ -942,22 +976,15 @@ def cascade_attention(
|
||||
assert num_common_kv_blocks > 0
|
||||
descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
|
||||
|
||||
num_sink_blocks = sink_len // block_size
|
||||
sink_block_table = torch.arange(
|
||||
1, num_sink_blocks + 1, device=block_table.device, dtype=block_table.dtype
|
||||
)
|
||||
sink_block_table = sink_block_table[None, :].expand(block_table.shape[0], -1)
|
||||
block_table = torch.cat((sink_block_table, block_table), dim=1)
|
||||
|
||||
# Process shared prefix.
|
||||
prefix_output, prefix_lse = flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=cu_prefix_query_lens,
|
||||
seqused_k=prefix_kv_lens + sink_len,
|
||||
seqused_k=prefix_kv_lens,
|
||||
max_seqlen_q=num_tokens,
|
||||
max_seqlen_k=common_prefix_len + sink_len,
|
||||
max_seqlen_k=common_prefix_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=False,
|
||||
window_size=sliding_window,
|
||||
@ -989,7 +1016,7 @@ def cascade_attention(
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
window_size=sliding_window,
|
||||
block_table=block_table[:, num_sink_blocks + num_common_kv_blocks :],
|
||||
block_table=block_table[:, num_common_kv_blocks:],
|
||||
softcap=logits_soft_cap,
|
||||
return_softmax_lse=True,
|
||||
scheduler_metadata=suffix_scheduler_metadata,
|
||||
@ -12,7 +12,7 @@ from vllm.v1.kv_cache_interface import (
|
||||
ChunkedLocalAttentionSpec,
|
||||
CrossAttentionSpec,
|
||||
FullAttentionSpec,
|
||||
FullSinkAttentionSpec,
|
||||
FullDiffkvAttentionSpec,
|
||||
KVCacheSpec,
|
||||
MambaSpec,
|
||||
MLAAttentionSpec,
|
||||
@ -311,7 +311,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
assert isinstance(
|
||||
kv_cache_spec,
|
||||
(FullAttentionSpec, FullSinkAttentionSpec, ChunkedLocalAttentionSpec),
|
||||
(FullAttentionSpec, FullDiffkvAttentionSpec, ChunkedLocalAttentionSpec),
|
||||
), (
|
||||
"FullAttentionManager can only be used for full attention "
|
||||
"and chunked local attention groups"
|
||||
@ -733,7 +733,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
|
||||
|
||||
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
|
||||
FullAttentionSpec: FullAttentionManager,
|
||||
FullSinkAttentionSpec: FullAttentionManager,
|
||||
FullDiffkvAttentionSpec: FullAttentionManager,
|
||||
MLAAttentionSpec: FullAttentionManager,
|
||||
SlidingWindowSpec: SlidingWindowManager,
|
||||
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
|
||||
|
||||
@ -159,7 +159,7 @@ class FullAttentionSpec(AttentionSpec):
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FullSinkAttentionSpec(AttentionSpec):
|
||||
class FullDiffkvAttentionSpec(AttentionSpec):
|
||||
head_size_v: int
|
||||
sliding_window: int | None = None
|
||||
attention_chunk_size: int | None = None
|
||||
@ -170,7 +170,7 @@ class FullSinkAttentionSpec(AttentionSpec):
|
||||
window attention are regarded as full attention in KV cache manager
|
||||
(blocks are allocated for all tokens), while computed as sliding window
|
||||
attention in model runner.
|
||||
In this case, we use FullSinkAttentionSpec and record the sliding window size.
|
||||
In this case, we use FullDiffkvAttentionSpec and record the sliding window size.
|
||||
Default to None for not using sliding window attention.
|
||||
"""
|
||||
|
||||
@ -198,12 +198,12 @@ class FullSinkAttentionSpec(AttentionSpec):
|
||||
@classmethod
|
||||
def merge(cls, specs: list[Self]) -> Self:
|
||||
"""
|
||||
Merge a list of FullSinkAttentionSpec objects into a single
|
||||
FullSinkAttentionSpec object.
|
||||
Merge a list of FullDiffkvAttentionSpec objects into a single
|
||||
FullDiffkvAttentionSpec object.
|
||||
"""
|
||||
assert all(isinstance(spec, FullSinkAttentionSpec) for spec in specs), (
|
||||
assert all(isinstance(spec, FullDiffkvAttentionSpec) for spec in specs), (
|
||||
"All attention layers in the same KV cache group must be "
|
||||
"FullSinkAttentionSpec."
|
||||
"FullDiffkvAttentionSpec."
|
||||
)
|
||||
|
||||
sliding_window = set(
|
||||
|
||||
@ -90,6 +90,7 @@ class InputBatch:
|
||||
is_pooling_model: bool = False,
|
||||
num_speculative_tokens: int = 0,
|
||||
cp_kv_cache_interleave_size: int = 1,
|
||||
sink_len: int = 0,
|
||||
):
|
||||
self.is_pooling_model = is_pooling_model
|
||||
self.is_spec_decode = is_spec_decode
|
||||
@ -136,7 +137,7 @@ class InputBatch:
|
||||
# Block table.
|
||||
self.block_table = MultiGroupBlockTable(
|
||||
max_num_reqs=max_num_reqs,
|
||||
max_model_len=max_model_len,
|
||||
max_model_len=max_model_len + sink_len,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
pin_memory=pin_memory,
|
||||
device=device,
|
||||
|
||||
@ -25,6 +25,9 @@ from vllm.attention.backends.abstract import (
|
||||
AttentionMetadata,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash_diffkv,
|
||||
)
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.cuda_graph import CUDAGraphWrapper
|
||||
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
|
||||
@ -321,6 +324,10 @@ class GPUModelRunner(
|
||||
self.num_query_heads = model_config.get_num_attention_heads(parallel_config)
|
||||
self.hidden_size = model_config.get_hidden_size()
|
||||
self.attention_chunk_size = model_config.attention_chunk_size
|
||||
self.sink_len = getattr(
|
||||
self.vllm_config.model_config.hf_config, "param_sink_number", 0
|
||||
)
|
||||
assert self.sink_len % self.cache_config.block_size == 0
|
||||
# Only relevant for models using ALiBi (e.g, MPT)
|
||||
self.use_alibi = model_config.uses_alibi
|
||||
|
||||
@ -443,6 +450,7 @@ class GPUModelRunner(
|
||||
logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
|
||||
is_pooling_model=self.is_pooling_model,
|
||||
cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
|
||||
sink_len=self.sink_len,
|
||||
)
|
||||
|
||||
self.use_async_scheduling = self.scheduler_config.async_scheduling
|
||||
@ -1590,6 +1598,28 @@ class GPUModelRunner(
|
||||
# graph mode.
|
||||
blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1)
|
||||
|
||||
# Modify the blk_table_tensor and seq_lens in-place so that attention will
|
||||
# know there are sink_key and sink_value in kv_caches
|
||||
if self.sink_len > 0:
|
||||
seq_lens[:] = seq_lens + self.sink_len
|
||||
seq_lens_cpu[:] = seq_lens_cpu + self.sink_len
|
||||
max_seq_len = max_seq_len + self.sink_len
|
||||
sink_block_table = torch.arange(
|
||||
1,
|
||||
self.sink_len // self.cache_config.block_size + 1,
|
||||
device=blk_table_tensor.device,
|
||||
dtype=blk_table_tensor.dtype,
|
||||
)
|
||||
sink_block_table = sink_block_table[None, :].expand(
|
||||
blk_table_tensor.shape[0], -1
|
||||
)
|
||||
num_sink_blocks = sink_block_table.shape[1]
|
||||
blk_table_tensor_clone = blk_table_tensor.clone()
|
||||
blk_table_tensor[:, num_sink_blocks:] = blk_table_tensor_clone[
|
||||
:, :-num_sink_blocks
|
||||
]
|
||||
blk_table_tensor[:, :num_sink_blocks] = sink_block_table
|
||||
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
@ -1624,6 +1654,8 @@ class GPUModelRunner(
|
||||
if cascade_attn_prefix_lens
|
||||
else 0
|
||||
)
|
||||
if self.sink_len > 0:
|
||||
cascade_attn_prefix_len = cascade_attn_prefix_len + self.sink_len
|
||||
builder = attn_group.get_metadata_builder()
|
||||
|
||||
extra_attn_metadata_args = {}
|
||||
@ -4838,6 +4870,7 @@ class GPUModelRunner(
|
||||
logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids,
|
||||
is_pooling_model=self.is_pooling_model,
|
||||
num_speculative_tokens=self.num_spec_tokens,
|
||||
sink_len=self.sink_len,
|
||||
)
|
||||
|
||||
def _allocate_kv_cache_tensors(
|
||||
@ -5165,6 +5198,7 @@ class GPUModelRunner(
|
||||
kv_caches = self.initialize_kv_cache_tensors(
|
||||
kv_cache_config, kernel_block_sizes
|
||||
)
|
||||
self.prepare_sink_kv_cache(kv_caches)
|
||||
|
||||
if self.speculative_config and self.speculative_config.use_eagle():
|
||||
assert isinstance(self.drafter, EagleProposer)
|
||||
@ -5269,3 +5303,36 @@ class GPUModelRunner(
|
||||
self.transfer_event.record()
|
||||
self.transfer_event.synchronize()
|
||||
return pinned.tolist()
|
||||
|
||||
def prepare_sink_kv_cache(self, kv_caches) -> None:
    """Copy each layer's sink key/value tensors into its KV cache.

    Does nothing when ``self.sink_len`` is 0. Otherwise, for every
    ``(layer_name, kv_cache)`` pair in ``kv_caches``, the module whose name
    is ``layer_name`` without its last dotted component is looked up in
    ``self.model``. Modules that expose ``get_sink_kv()`` have the returned
    ``"sink_key"`` / ``"sink_value"`` tensors written into ``kv_cache`` at
    slots ``[block_size, block_size + sink_len)``, i.e. starting right
    after the first block.

    Args:
        kv_caches: Mapping from attention layer name to that layer's KV
            cache (presumably the dict returned by
            ``initialize_kv_cache_tensors`` -- confirm against the caller).

    Raises:
        KeyError: If no module in ``self.model`` matches a layer's prefix.
    """
    if self.sink_len == 0:
        return

    # Index the modules once instead of re-walking the whole model for
    # every layer.
    modules_by_name = dict(self.model.named_modules())

    # The slot mapping is identical for all layers; build it on first use so
    # models without any sink-aware module never touch the device.
    sink_kv_slot_mapping = None

    # NOTE: a mapping has ``items()``, not ``item()``; the latter raised
    # AttributeError as soon as sink_len was non-zero.
    for layer_name, kv_cache in kv_caches.items():
        layer_prefix = layer_name.rsplit(".", 1)[0]
        try:
            self_attn_module = modules_by_name[layer_prefix]
        except KeyError:
            raise KeyError(f"Module '{layer_prefix}' not found") from None

        if not hasattr(self_attn_module, "get_sink_kv"):
            continue

        sink_kv = self_attn_module.get_sink_kv()
        if sink_kv_slot_mapping is None:
            block_size = self.vllm_config.cache_config.block_size
            sink_kv_slot_mapping = torch.arange(
                block_size,
                self.sink_len + block_size,
                device=torch.cuda.current_device(),
                dtype=torch.long,
            )

        triton_reshape_and_cache_flash_diffkv(
            sink_kv["sink_key"],
            sink_kv["sink_value"],
            kv_cache,
            sink_kv_slot_mapping,
            self_attn_module.attn.kv_cache_dtype,
            self_attn_module.attn._k_scale,
            self_attn_module.attn._v_scale,
        )
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user