mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-03 09:27:03 +08:00
Refactor code. Extend FLASH_ATTN to support different K and V head sizes, and create a new StaticSinkAttention layer for the sink-token logic
Signed-off-by: yuantao <2422264527@qq.com>
This commit is contained in:
parent
de538d3b8f
commit
b565203d92
@ -42,9 +42,6 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
|
||||
"""
|
||||
|
||||
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
|
||||
FLASH_DIFFKV_ATTN = (
|
||||
"vllm.v1.attention.backends.flash_diffkv_attn.FlashDiffkvAttentionBackend"
|
||||
)
|
||||
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
|
||||
ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
|
||||
ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
|
||||
|
||||
@ -191,6 +191,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: str | None = None,
|
||||
attn_backend: type[AttentionBackend] | None = None,
|
||||
head_size_v: int | None = None,
|
||||
**extra_impl_args,
|
||||
) -> None:
|
||||
"""
|
||||
@ -232,6 +233,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.head_size_v = self.head_size if head_size_v is None else head_size_v
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.sliding_window = sliding_window
|
||||
self.has_sink = extra_impl_args.get("sinks") is not None
|
||||
@ -370,6 +372,10 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
query, _ = self.query_quant(query, self._q_scale)
|
||||
|
||||
if self.use_output:
|
||||
if output_shape is None:
|
||||
output_shape = torch.Size(
|
||||
(*query.shape[:-1], self.num_heads * self.head_size_v)
|
||||
)
|
||||
output_shape = output_shape if output_shape is not None else query.shape
|
||||
output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
|
||||
hidden_size = output_shape[-1]
|
||||
@ -377,11 +383,11 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
# NOTE(woosuk): We do this outside the custom op to minimize the
|
||||
# CPU overheads from the non-CUDA-graph regions.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
output = output.view(-1, self.num_heads, self.head_size)
|
||||
output = output.view(-1, self.num_heads, self.head_size_v)
|
||||
if key is not None:
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
if value is not None:
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size_v)
|
||||
if self.use_direct_call:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
@ -456,6 +462,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
block_size=block_size,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_size=self.head_size,
|
||||
head_size_v=self.head_size_v,
|
||||
dtype=self.kv_cache_torch_dtype,
|
||||
)
|
||||
|
||||
|
||||
225
vllm/attention/layers/static_sink_attention.py
Normal file
225
vllm/attention/layers/static_sink_attention.py
Normal file
@ -0,0 +1,225 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionMetadata,
|
||||
AttentionType,
|
||||
)
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash_diffkv,
|
||||
)
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata,
|
||||
subclass_attention_backend,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
AttentionSpec,
|
||||
KVCacheSpec,
|
||||
SinkFullAttentionSpec,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@functools.lru_cache
def create_static_sink_attention_backend(
    underlying_attn_backend: type[AttentionBackend],
    sink_len: int = 0,
) -> type[AttentionBackend]:
    """Derive a backend whose metadata builder accounts for static sink tokens.

    The returned class is produced by ``subclass_attention_backend`` from
    ``underlying_attn_backend`` with the name prefix ``"StaticSink_"`` and a
    builder that, on every ``build`` call, adds ``sink_len`` to the sequence
    lengths and puts the sink blocks at the front of every block-table row.

    Results are memoized by ``functools.lru_cache`` on
    ``(underlying_attn_backend, sink_len)``.

    Args:
        underlying_attn_backend: Backend class to wrap.
        sink_len: Number of sink tokens seen by every sequence.

    Returns:
        The derived attention backend class.
    """
    prefix = "StaticSink_"
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class StaticSinkAttentionBuilder(underlying_builder):  # type: ignore
        """Metadata builder that injects the sink blocks before delegating."""

        def __init__(
            self,
            kv_cache_spec: AttentionSpec,
            layer_names: list[str],
            vllm_config: VllmConfig,
            device: torch.device,
        ):
            super().__init__(kv_cache_spec, layer_names, vllm_config, device)
            self.sink_len = sink_len
            self.num_sink_blocks = self.sink_len // vllm_config.cache_config.block_size
            # Sink tokens occupy block ids 1..num_sink_blocks.
            self.sink_block_table = torch.arange(
                1,
                self.num_sink_blocks + 1,
                device=device,
                dtype=torch.int32,
            )

        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            """Extend the metadata by the sink tokens, then build as usual.

            NOTE: ``common_attn_metadata`` is modified in place (``seq_lens``
            and ``block_table_tensor`` are written through; ``seq_lens_cpu``
            and ``max_seq_len`` are rebound).
            """
            # Every sequence additionally attends to the sink tokens.
            common_attn_metadata.seq_lens[:] = (
                common_attn_metadata.seq_lens + self.sink_len
            )
            common_attn_metadata.seq_lens_cpu = (
                common_attn_metadata.seq_lens_cpu + self.sink_len
            )
            common_attn_metadata.max_seq_len = (
                common_attn_metadata.max_seq_len + self.sink_len
            )

            # With zero sink blocks `[:, : -0]` is an *empty* slice, so the
            # shift below would fail on a shape mismatch; there is nothing to
            # insert in that case anyway.
            if self.num_sink_blocks > 0:
                blk_table_tensor = common_attn_metadata.block_table_tensor
                sink_block_table = self.sink_block_table[None, :].expand(
                    blk_table_tensor.shape[0], -1
                )
                # Copy first: source and destination ranges overlap. The last
                # `num_sink_blocks` columns of each row are dropped.
                shifted = blk_table_tensor[:, : -self.num_sink_blocks].clone()
                blk_table_tensor[:, self.num_sink_blocks :] = shifted
                blk_table_tensor[:, : self.num_sink_blocks] = sink_block_table

            return super().build(common_prefix_len, common_attn_metadata, fast_build)

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=StaticSinkAttentionBuilder,
    )

    return attn_backend
|
||||
|
||||
|
||||
class StaticSinkAttention(Attention):
    """Attention layer with static sink tokens stored in the KV cache.

    The sink key/value tensors are supplied through ``update_sink_kv`` and
    written into the layer's KV cache once, on the first ``forward`` call that
    sees a non-empty cache. The layer uses a backend derived via
    ``create_static_sink_attention_backend``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        sink_len: int,
        attn_backend: type[AttentionBackend] | None = None,
        cache_config: CacheConfig | None = None,
        **kwargs,
    ) -> None:
        """Wrap the chosen (or auto-selected) backend and initialise the layer.

        Args:
            num_heads: Number of query heads.
            head_size: Per-head size of query/key.
            scale: Attention scale, forwarded to ``Attention``.
            sink_len: Number of sink tokens.
            attn_backend: Backend to wrap; selected with ``get_attn_backend``
                when ``None``.
            cache_config: Supplies the cache dtype and block size; defaults
                of ``"auto"`` and ``16`` are used when ``None``.
            **kwargs: Forwarded unchanged to ``Attention.__init__``.
        """
        dtype = torch.get_default_dtype()

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        if attn_backend is not None:
            underlying_attn_backend = attn_backend
        else:
            underlying_attn_backend = get_attn_backend(
                head_size, dtype, kv_cache_dtype, block_size
            )
        attn_backend = create_static_sink_attention_backend(
            underlying_attn_backend,
            sink_len=sink_len,
        )
        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            cache_config=cache_config,
            attn_backend=attn_backend,
            **kwargs,
        )

        self.sink_len = sink_len
        # NOTE(review): captured at construction time, whereas
        # get_kv_cache_spec() re-reads the block size from vllm_config because
        # it "may get updated after model loading" - confirm this value cannot
        # go stale before populate_sink_kv() runs.
        self.block_size = block_size
        self.sink_populated = False
        self.sink_key = None
        self.sink_value = None

    def update_sink_kv(self, sink_key, sink_value) -> None:
        """Store the sink key/value that will be written into the KV cache."""
        self.sink_key = sink_key
        self.sink_value = sink_value

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """Populate the sink KV on first use, then run regular attention.

        Raises:
            AssertionError: If ``update_sink_kv`` has not been called yet.
        """
        assert self.sink_key is not None and self.sink_value is not None, (
            "sink_key and sink_value have not been prepared"
        )
        if not self.sink_populated:
            forward_context: ForwardContext = get_forward_context()
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            # Routed through a custom op; the op re-checks `sink_populated`
            # and skips empty caches.
            torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)

        return super().forward(query, key, value, output_shape)

    def populate_sink_kv(self, self_kv_cache: torch.Tensor) -> None:
        """Write the sink key/value into the cache slots reserved for them."""
        # Slots [block_size, block_size + sink_len): the sink tokens start at
        # the first slot of block 1. The mapping is created on the same device
        # as the cache it indexes rather than on whatever the current CUDA
        # device happens to be.
        sink_kv_slot_mapping = torch.arange(
            self.block_size,
            self.sink_len + self.block_size,
            device=self_kv_cache.device,
            dtype=torch.long,
        )
        triton_reshape_and_cache_flash_diffkv(
            self.sink_key,
            self.sink_value,
            self_kv_cache,
            sink_kv_slot_mapping,
            self.kv_cache_dtype,
            self._k_scale,
            self._v_scale,
        )
        # We only populate the sink_key and sink_value once
        self.sink_populated = True

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Return the ``SinkFullAttentionSpec`` describing this layer's cache."""
        # Block size may get updated after model loading, refresh it
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER

        return SinkFullAttentionSpec(
            block_size=block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            head_size_v=self.head_size_v,
            sink_len=self.sink_len,
            dtype=self.kv_cache_torch_dtype,
        )
|
||||
|
||||
|
||||
def maybe_populate_sink(
    self_kv_cache: torch.Tensor,
    layer_name: str,
) -> None:
    """Write the sink KV of layer ``layer_name`` into ``self_kv_cache`` once.

    The layer is looked up in the forward context's ``no_compile_layers``.
    Nothing happens when the layer has already populated its sink or when the
    cache tensor has no elements.
    """
    layer = get_forward_context().no_compile_layers[layer_name]
    needs_population = not layer.sink_populated and self_kv_cache.numel() != 0
    if needs_population:
        layer.populate_sink_kv(self_kv_cache)
|
||||
|
||||
|
||||
def maybe_populate_sink_fake(
    self_kv_cache: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation of ``maybe_populate_sink``: does nothing."""
    return None
|
||||
|
||||
|
||||
# Register the sink-population hook as a custom op; StaticSinkAttention.forward
# invokes it as `torch.ops.vllm.maybe_populate_sink`. The KV cache argument is
# declared as mutated because the real implementation writes into it.
direct_register_custom_op(
    op_name="maybe_populate_sink",
    op_func=maybe_populate_sink,
    fake_impl=maybe_populate_sink_fake,
    mutates_args=["self_kv_cache"],
)
|
||||
@ -186,17 +186,20 @@ def triton_reshape_and_cache_flash(
|
||||
|
||||
@triton.jit
|
||||
def reshape_and_cache_kernel_flash_diffkv(
|
||||
kv_ptr, # [num_tokens, num_heads, head_size + head_size_v]
|
||||
key_ptr, # [num_tokens, num_heads, head_size]
|
||||
value_ptr, # [num_tokens, num_heads, head_size_v]
|
||||
kv_cache_ptr, # [num_blocks, block_size, num_heads, head_size + head_size_v]
|
||||
slot_mapping_ptr, # [num_tokens]
|
||||
k_scale, # float32
|
||||
v_scale, # float32
|
||||
# strides
|
||||
kv_stride: tl.int64,
|
||||
key_stride: tl.int64,
|
||||
value_stride: tl.int64,
|
||||
block_stride: tl.int64,
|
||||
page_stride: tl.int64,
|
||||
num_heads: tl.constexpr,
|
||||
head_size_kv: tl.constexpr,
|
||||
head_size_k: tl.constexpr,
|
||||
head_size_v: tl.constexpr,
|
||||
block_size: tl.constexpr,
|
||||
# FP8 flags
|
||||
FP8_KV_CACHE: tl.constexpr,
|
||||
@ -211,24 +214,51 @@ def reshape_and_cache_kernel_flash_diffkv(
|
||||
|
||||
tile_i = tl.program_id(axis=1)
|
||||
tile_offs = tl.arange(0, TILE_SIZE)
|
||||
tile_pos = tile_i * TILE_SIZE + tile_offs
|
||||
|
||||
block_idx = slot_idx // block_size
|
||||
block_offset = slot_idx % block_size
|
||||
|
||||
src_kv_idx = token_idx * kv_stride
|
||||
src_key_idx = token_idx * key_stride + tile_i * head_size_k
|
||||
src_value_idx = token_idx * value_stride + tile_i * head_size_v
|
||||
|
||||
tgt_idx = block_idx * block_stride + block_offset * page_stride
|
||||
|
||||
# [TILE_SIZE]
|
||||
kv_tile = tl.load(
|
||||
kv_ptr + src_kv_idx + tile_pos, mask=tile_pos < (num_heads * head_size_kv)
|
||||
tgt_idx = (
|
||||
block_idx * block_stride
|
||||
+ block_offset * page_stride
|
||||
+ tile_i * (head_size_k + head_size_v)
|
||||
)
|
||||
|
||||
# [TILE_SIZE]
|
||||
key_load = tl.load(key_ptr + src_key_idx + tile_offs, mask=tile_offs < head_size_k)
|
||||
if FP8_KV_CACHE:
|
||||
# tl.store will do the correct implicit cast to fp8,
|
||||
# based on the key_cache_ptr.dtype.element_ty
|
||||
key_tile = key_load if key_load.dtype.is_fp8() else key_load / tl.load(k_scale)
|
||||
else:
|
||||
key_tile = key_load
|
||||
|
||||
# [TILE_SIZE]
|
||||
# The mask must be a bounds check on the lane offsets, matching the key load
# and both stores; `tile_offs * head_size_v` is an integer product, not a
# comparison against the value head size.
value_load = tl.load(
    value_ptr + src_value_idx + tile_offs, mask=tile_offs < head_size_v
)
|
||||
if FP8_KV_CACHE:
|
||||
if value_load.dtype.is_fp8():
|
||||
value_tile = value_load
|
||||
else:
|
||||
# tl.store will do the correct implicit cast to fp8,
|
||||
# based on the value_cache_ptr.dtype.element_ty
|
||||
value_tile = value_load / tl.load(v_scale)
|
||||
else:
|
||||
value_tile = value_load
|
||||
|
||||
tl.store(
|
||||
kv_cache_ptr + tgt_idx + tile_pos,
|
||||
kv_tile,
|
||||
mask=tile_pos < (num_heads * head_size_kv),
|
||||
kv_cache_ptr + tgt_idx + tile_offs,
|
||||
key_tile,
|
||||
mask=tile_offs < head_size_k,
|
||||
)
|
||||
tl.store(
|
||||
kv_cache_ptr + tgt_idx + head_size_k + tile_offs,
|
||||
value_tile,
|
||||
mask=tile_offs < head_size_v,
|
||||
)
|
||||
return
|
||||
|
||||
@ -243,13 +273,13 @@ def triton_reshape_and_cache_flash_diffkv(
|
||||
k_scale: torch.Tensor, # float32
|
||||
v_scale: torch.Tensor, # float32
|
||||
):
|
||||
kv = torch.cat([key, value], dim=-1).contiguous()
|
||||
num_heads = kv.shape[1]
|
||||
head_size_kv = kv.shape[2]
|
||||
num_heads = key.shape[1]
|
||||
head_size_k = key.shape[2]
|
||||
head_size_v = value.shape[2]
|
||||
block_size = kv_cache.shape[1]
|
||||
n = num_heads * head_size_kv
|
||||
|
||||
kv_stride = kv.stride()[0]
|
||||
k_stride = key.stride()[0]
|
||||
v_stride = value.stride()[0]
|
||||
block_stride = kv_cache.stride()[0]
|
||||
page_stride = kv_cache.stride()[1]
|
||||
|
||||
@ -272,13 +302,20 @@ def triton_reshape_and_cache_flash_diffkv(
|
||||
)
|
||||
|
||||
FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
|
||||
assert not FP8_KV_CACHE, (
|
||||
assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2,
|
||||
torch.uint8,
|
||||
torch.float8_e4m3fnuz,
|
||||
], (
|
||||
"unsupported dtype of KV cache tensor, got "
|
||||
"{kv_cache_torch_dtype}. Supported kv cache dtypes: bfloat16, float16, float32."
|
||||
"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, "
|
||||
"fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz."
|
||||
)
|
||||
|
||||
# heuristics instead of autotuning
|
||||
TILE_SIZE = min(2048, triton.next_power_of_2(n))
|
||||
TILE_SIZE = max(head_size_k, head_size_v)
|
||||
TILE_SIZE = triton.next_power_of_2(TILE_SIZE)
|
||||
if current_platform.is_rocm() or current_platform.is_xpu():
|
||||
num_stages = 4
|
||||
num_warps = 8
|
||||
@ -292,21 +329,24 @@ def triton_reshape_and_cache_flash_diffkv(
|
||||
# using cudagraphs
|
||||
grid = lambda meta: (
|
||||
slot_mapping.shape[0],
|
||||
triton.cdiv(n, meta["TILE_SIZE"]),
|
||||
num_heads,
|
||||
)
|
||||
|
||||
reshape_and_cache_kernel_flash_diffkv[grid](
|
||||
kv_ptr=kv,
|
||||
key_ptr=key,
|
||||
value_ptr=value,
|
||||
kv_cache_ptr=kv_cache,
|
||||
slot_mapping_ptr=slot_mapping,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
# strides
|
||||
kv_stride=kv_stride,
|
||||
key_stride=k_stride,
|
||||
value_stride=v_stride,
|
||||
block_stride=block_stride,
|
||||
page_stride=page_stride,
|
||||
num_heads=num_heads,
|
||||
head_size_kv=head_size_kv,
|
||||
head_size_k=head_size_k,
|
||||
head_size_v=head_size_v,
|
||||
block_size=block_size,
|
||||
# FP8 flags
|
||||
FP8_KV_CACHE=FP8_KV_CACHE,
|
||||
|
||||
@ -29,8 +29,8 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.layer import Attention, AttentionType
|
||||
from vllm.attention.layers.static_sink_attention import StaticSinkAttention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
|
||||
from vllm.distributed import (
|
||||
@ -41,7 +41,6 @@ from vllm.distributed import (
|
||||
get_tp_group,
|
||||
tensor_model_parallel_all_gather,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@ -82,11 +81,7 @@ from vllm.model_executor.models.utils import (
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.config import set_default_rope_theta
|
||||
from vllm.v1.attention.backends.flash_diffkv_attn import FlashDiffkvAttentionBackend
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
FullDiffkvAttentionSpec,
|
||||
KVCacheSpec,
|
||||
)
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
||||
|
||||
|
||||
def check_ffn_act_fn(act_fn: str):
|
||||
@ -96,133 +91,6 @@ def check_ffn_act_fn(act_fn: str):
|
||||
)
|
||||
|
||||
|
||||
class DiffkvAttention(Attention):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
head_size_v: int,
|
||||
scale: float,
|
||||
num_kv_heads: int | None = None,
|
||||
alibi_slopes: list[float] | None = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
logits_soft_cap: float | None = None,
|
||||
per_layer_sliding_window: int | None = None,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: str | None = None,
|
||||
attn_backend: type[AttentionBackend] | None = None,
|
||||
**extra_impl_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
cache_config,
|
||||
quant_config,
|
||||
logits_soft_cap,
|
||||
per_layer_sliding_window,
|
||||
prefix,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
attn_backend,
|
||||
**extra_impl_args,
|
||||
)
|
||||
self.head_size_v = head_size_v
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
output_shape: torch.Size | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
The KV cache is stored inside this class and is accessed via
|
||||
`self.kv_cache`.
|
||||
|
||||
Attention metadata (`attn_metadata`) is set using a context manager in
|
||||
the model runner's `execute_model` method. It is accessed via forward
|
||||
context using
|
||||
`vllm.forward_context.get_forward_context().attn_metadata`.
|
||||
"""
|
||||
if self.calculate_kv_scales:
|
||||
torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
|
||||
output_dtype = query.dtype
|
||||
if self.query_quant is not None:
|
||||
# quantizing with a simple torch operation enables
|
||||
# torch.compile to fuse this into previous ops
|
||||
# which reduces overheads during decoding.
|
||||
# Otherwise queries are quantized using custom ops
|
||||
# which causes decoding overheads
|
||||
assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
|
||||
|
||||
# check if query quantization is supported
|
||||
if self.impl.supports_quant_query_input():
|
||||
query, _ = self.query_quant(query, self._q_scale)
|
||||
|
||||
if self.use_output:
|
||||
output_shape = output_shape if output_shape is not None else query.shape
|
||||
output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
|
||||
hidden_size = output_shape[-1]
|
||||
# Reshape the query, key, and value tensors.
|
||||
# NOTE(woosuk): We do this outside the custom op to minimize the
|
||||
# CPU overheads from the non-CUDA-graph regions.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
output = output.view(-1, self.num_heads, self.head_size_v)
|
||||
if key is not None:
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
if value is not None:
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size_v)
|
||||
if self.use_direct_call:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[self.layer_name]
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
self.impl.forward(
|
||||
self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self_kv_cache,
|
||||
attn_metadata,
|
||||
output=output,
|
||||
)
|
||||
else:
|
||||
torch.ops.vllm.unified_attention_with_output(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
self.layer_name,
|
||||
)
|
||||
return output.view(-1, hidden_size)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupport Error, currently only flash_diffkv_attn "
|
||||
"backend with output buffer is supported"
|
||||
)
|
||||
|
||||
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
|
||||
# Block size may get updated after model loading, refresh it
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
# Should not be called for enc-dec or encoder-only attention.
|
||||
assert self.attn_type == AttentionType.DECODER
|
||||
# Only support for full attention now.
|
||||
assert self.sliding_window is None
|
||||
return FullDiffkvAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_size=self.head_size,
|
||||
head_size_v=self.head_size_v,
|
||||
dtype=self.kv_cache_torch_dtype,
|
||||
)
|
||||
|
||||
|
||||
class OpenPanguMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -673,7 +541,7 @@ class OpenPanguEmbeddedAttention(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
class OpenPanguDiffkvAttention(nn.Module):
|
||||
class OpenPanguSinkAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
@ -777,19 +645,19 @@ class OpenPanguDiffkvAttention(nn.Module):
|
||||
else:
|
||||
sliding_window = None
|
||||
|
||||
FlashDiffkvAttentionBackend.set_head_size_v(self.v_channels)
|
||||
self.attn = DiffkvAttention(
|
||||
self.attn = StaticSinkAttention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.v_channels,
|
||||
self.scaling,
|
||||
sink_len=self.param_sink_number,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
per_layer_sliding_window=sliding_window,
|
||||
attn_type=attn_type,
|
||||
prefix=f"{prefix}.attn",
|
||||
attn_backend=FlashDiffkvAttentionBackend,
|
||||
attn_backend=FlashAttentionBackend,
|
||||
head_size_v=self.v_channels,
|
||||
)
|
||||
|
||||
if self.param_sink_number > 0:
|
||||
@ -919,19 +787,13 @@ class OpenPanguDiffkvAttention(nn.Module):
|
||||
is_neox_style=is_neox_style,
|
||||
)
|
||||
|
||||
def get_sink_kv(self) -> dict[str, torch.Tensor]:
|
||||
if self.param_sink_number == 0:
|
||||
raise ValueError("No sink_key and sink_value when param_sink_number == 0")
|
||||
|
||||
def post_weight_load(self) -> None:
|
||||
if hasattr(self, "k_layernorm") and self.k_layernorm is not None:
|
||||
param_sink_key = self.k_layernorm(self.param_sink_key)
|
||||
else:
|
||||
param_sink_key = self.param_sink_key
|
||||
|
||||
return {
|
||||
"sink_key": param_sink_key,
|
||||
"sink_value": self.param_sink_value,
|
||||
}
|
||||
self.attn.update_sink_kv(param_sink_key, self.param_sink_value)
|
||||
|
||||
|
||||
class OpenPanguDecoderLayer(nn.Module):
|
||||
@ -1001,7 +863,7 @@ class OpenPanguDecoderLayer(nn.Module):
|
||||
"rope_type": "default",
|
||||
"rope_theta": config.rope_theta,
|
||||
}
|
||||
self.self_attn = OpenPanguDiffkvAttention(
|
||||
self.self_attn = OpenPanguSinkAttention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
@ -1359,8 +1221,17 @@ class OpenPanguModel(nn.Module):
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
|
||||
self.post_weight_load()
|
||||
return loaded_params
|
||||
|
||||
def post_weight_load(self) -> None:
|
||||
for name, module in self.named_modules():
|
||||
if module is self:
|
||||
continue
|
||||
if hasattr(module, "post_weight_load"):
|
||||
module.post_weight_load()
|
||||
|
||||
|
||||
class OpenPanguModelBase(nn.Module, SupportsPP, SupportsLoRA):
|
||||
packed_modules_mapping = {
|
||||
|
||||
@ -18,6 +18,9 @@ from vllm.attention.backends.abstract import (
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.ops.common import cp_lse_ag_out_rs
|
||||
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
||||
from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash_diffkv,
|
||||
)
|
||||
from vllm.attention.utils.fa_utils import (
|
||||
flash_attn_supports_fp8,
|
||||
get_flash_attn_version,
|
||||
@ -105,28 +108,48 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
head_size_v: int | None = None,
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
if head_size_v is None or head_size == head_size_v:
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
else:
|
||||
return (
|
||||
num_blocks,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size + head_size_v,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_stride_order(
|
||||
include_num_layers_dimension: bool = False,
|
||||
diff_kv: bool = False,
|
||||
) -> tuple[int, ...]:
|
||||
# `stride_order` indicates the permutation that gets
|
||||
# us from `get_kv_cache_shape` to the actual memory layout we want.
|
||||
cache_layout = get_kv_cache_layout()
|
||||
if cache_layout == "NHD" and include_num_layers_dimension:
|
||||
# (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
|
||||
return (2, 0, 1, 3, 4, 5)
|
||||
if not diff_kv:
|
||||
# (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
|
||||
return (2, 0, 1, 3, 4, 5)
|
||||
else:
|
||||
# (num_blocks, num_layers, block_size,
|
||||
# num_kv_heads, head_size + head_size_v)
|
||||
return (0, 1, 2, 3, 4)
|
||||
elif cache_layout == "NHD":
|
||||
stride_order = (0, 1, 2, 3, 4)
|
||||
stride_order = (0, 1, 2, 3, 4) if not diff_kv else (0, 1, 2, 3)
|
||||
elif cache_layout == "HND" and include_num_layers_dimension:
|
||||
# (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
|
||||
return (2, 4, 0, 1, 3, 5)
|
||||
if not diff_kv:
|
||||
# (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
|
||||
return (2, 4, 0, 1, 3, 5)
|
||||
else:
|
||||
# (num_blocks, num_kv_heads, num_layers,
|
||||
# block_size, head_size + head_size_v)
|
||||
return (2, 3, 0, 1, 4)
|
||||
elif cache_layout == "HND":
|
||||
stride_order = (0, 1, 3, 2, 4)
|
||||
stride_order = (0, 1, 3, 2, 4) if not diff_kv else (0, 2, 1, 3)
|
||||
else:
|
||||
raise ValueError(f"Unknown cache layout format {cache_layout}.")
|
||||
return stride_order
|
||||
@ -576,11 +599,14 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
or [num_tokens, num_kv_heads, head_size_v]
|
||||
kv_cache: shape =
|
||||
[2, num_blocks, block_size, num_kv_heads, head_size]
|
||||
or [num_blocks, block_size, num_kv_heads, head_size + head_size_v]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
or [num_tokens, num_heads * head_size_v]
|
||||
NOTE: FP8 quantization, flash-attn expect the size of
|
||||
{q,k,v}_descale to be (num_sequences, num_kv_heads).
|
||||
We use torch's .expand() to avoid duplicating values
|
||||
@ -623,7 +649,13 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
)
|
||||
|
||||
# For decoder and cross-attention, use KV cache as before
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
if self.head_size == kv_cache.shape[-1]:
|
||||
# Same head_size for K and V
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
else:
|
||||
# Different head_size for K and V
|
||||
key_cache = kv_cache[..., : self.head_size]
|
||||
value_cache = kv_cache[..., self.head_size :]
|
||||
|
||||
# key and value may be None in the case of cross attention. They are
|
||||
# calculated once based on the output from the encoder and then cached
|
||||
@ -640,16 +672,29 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
# and value[:num_actual_tokens] because the reshape_and_cache_flash
|
||||
# op uses the slot_mapping's shape to determine the number of
|
||||
# actual tokens.
|
||||
reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
if self.head_size == kv_cache.shape[-1]:
|
||||
# kv_cache update for same head_size K and V
|
||||
reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
else:
|
||||
# kv_cache update for different head_size K and V
|
||||
triton_reshape_and_cache_flash_diffkv(
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
# queries are quantized in the attention layer
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -210,11 +210,6 @@ class Scheduler(SchedulerInterface):
|
||||
hash_block_size=self.block_size,
|
||||
metrics_collector=self.kv_metrics_collector,
|
||||
)
|
||||
sink_len = getattr(vllm_config.model_config.hf_config, "param_sink_number", 0)
|
||||
if sink_len > 0:
|
||||
assert sink_len % self.block_size == 0
|
||||
num_sink_block = sink_len // self.block_size
|
||||
self.kv_cache_manager.block_pool.free_block_queue.popleft_n(num_sink_block)
|
||||
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
|
||||
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
|
||||
|
||||
|
||||
@ -12,10 +12,10 @@ from vllm.v1.kv_cache_interface import (
|
||||
ChunkedLocalAttentionSpec,
|
||||
CrossAttentionSpec,
|
||||
FullAttentionSpec,
|
||||
FullDiffkvAttentionSpec,
|
||||
KVCacheSpec,
|
||||
MambaSpec,
|
||||
MLAAttentionSpec,
|
||||
SinkFullAttentionSpec,
|
||||
SlidingWindowSpec,
|
||||
)
|
||||
from vllm.v1.request import Request
|
||||
@ -317,8 +317,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
|
||||
pcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
assert isinstance(
|
||||
kv_cache_spec,
|
||||
FullAttentionSpec | ChunkedLocalAttentionSpec | FullDiffkvAttentionSpec,
|
||||
kv_cache_spec, FullAttentionSpec | ChunkedLocalAttentionSpec
|
||||
), (
|
||||
"FullAttentionManager can only be used for full attention "
|
||||
"and chunked local attention groups"
|
||||
@ -785,14 +784,35 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
|
||||
raise NotImplementedError("CrossAttentionManager does not support caching")
|
||||
|
||||
|
||||
class SinkFullAttentionManager(FullAttentionManager):
    """Full-attention KV cache manager that sets aside blocks for sink tokens.

    After the regular ``FullAttentionManager`` initialisation, the first
    ``sink_len // block_size`` entries are taken out of the block pool's free
    queue (via ``popleft_n``) and kept in ``self.sink_blocks``.
    """

    def __init__(
        self,
        kv_cache_spec: KVCacheSpec,
        block_pool: BlockPool,
        kv_cache_group_id: int,
        dcp_world_size: int = 1,
        pcp_world_size: int = 1,
    ):
        """Initialise the manager and reserve the sink blocks.

        Args:
            kv_cache_spec: Spec of the KV cache group; must expose ``sink_len``.
            block_pool: Pool whose free queue the sink blocks are taken from.
            kv_cache_group_id: Forwarded unchanged to the parent constructor.
            dcp_world_size: Forwarded unchanged to the parent constructor.
            pcp_world_size: Forwarded unchanged to the parent constructor.

        Raises:
            AssertionError: If ``sink_len`` is positive but not a multiple of
                the block size.
        """
        super().__init__(
            kv_cache_spec, block_pool, kv_cache_group_id, dcp_world_size, pcp_world_size
        )
        # `sink_len` is declared as `int | None` on the spec; treat an unset
        # value as "no sink tokens" instead of evaluating `None > 0`.
        sink_len = kv_cache_spec.sink_len or 0
        # Define the attribute on every path so that reading it never raises
        # AttributeError when no sink tokens are configured.
        # NOTE(review): assumes `popleft_n` returns a list-like of blocks, so an
        # empty list is the matching "nothing reserved" value; confirm.
        self.sink_blocks = []
        if sink_len > 0:
            assert sink_len % self.block_size == 0
            num_sink_block = sink_len // self.block_size
            self.sink_blocks = self.block_pool.free_block_queue.popleft_n(
                num_sink_block
            )
|
||||
|
||||
|
||||
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
|
||||
FullAttentionSpec: FullAttentionManager,
|
||||
FullDiffkvAttentionSpec: FullAttentionManager,
|
||||
MLAAttentionSpec: FullAttentionManager,
|
||||
SlidingWindowSpec: SlidingWindowManager,
|
||||
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
|
||||
MambaSpec: MambaManager,
|
||||
CrossAttentionSpec: CrossAttentionManager,
|
||||
SinkFullAttentionSpec: SinkFullAttentionManager,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -80,6 +80,7 @@ class AttentionSpec(KVCacheSpec):
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FullAttentionSpec(AttentionSpec):
|
||||
head_size_v: int | None = None
|
||||
sliding_window: int | None = None
|
||||
attention_chunk_size: int | None = None
|
||||
"""
|
||||
@ -92,6 +93,10 @@ class FullAttentionSpec(AttentionSpec):
|
||||
Default to None for not using sliding window attention.
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
if self.head_size_v is None:
|
||||
object.__setattr__(self, "head_size_v", self.head_size)
|
||||
|
||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
|
||||
@ -124,88 +129,6 @@ class FullAttentionSpec(AttentionSpec):
|
||||
"All attention layers in the same KV cache group must be FullAttentionSpec."
|
||||
)
|
||||
|
||||
sliding_window = set(
|
||||
spec.sliding_window for spec in specs if spec.sliding_window is not None
|
||||
)
|
||||
attention_chunk_size = set(
|
||||
spec.attention_chunk_size
|
||||
for spec in specs
|
||||
if spec.attention_chunk_size is not None
|
||||
)
|
||||
assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
|
||||
"MLAAttentionSpec should be merged in MLAAttentionSpec.merge"
|
||||
)
|
||||
merged_spec = cls(
|
||||
block_size=specs[0].block_size,
|
||||
num_kv_heads=specs[0].num_kv_heads,
|
||||
head_size=specs[0].head_size,
|
||||
dtype=specs[0].dtype,
|
||||
sliding_window=cls.merge_window_sizes(sliding_window),
|
||||
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
|
||||
)
|
||||
for spec in specs:
|
||||
for f in fields(AttentionSpec):
|
||||
assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
|
||||
"All attention layers in the same KV cache group must have "
|
||||
"the same attention spec."
|
||||
)
|
||||
assert (merged_spec.sliding_window is not None) + (
|
||||
merged_spec.attention_chunk_size is not None
|
||||
) <= 1, (
|
||||
"Model with both sliding window layers and chunked local attention "
|
||||
"layers is not supported."
|
||||
)
|
||||
return merged_spec
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FullDiffkvAttentionSpec(AttentionSpec):
|
||||
head_size_v: int
|
||||
sliding_window: int | None = None
|
||||
attention_chunk_size: int | None = None
|
||||
|
||||
"""
|
||||
When hybrid allocator is disabled and the model contains both full
|
||||
attention layers and sliding window attention layers, sliding
|
||||
window attention are regarded as full attention in KV cache manager
|
||||
(blocks are allocated for all tokens), while computed as sliding window
|
||||
attention in model runner.
|
||||
In this case, we use FullDiffkvAttentionSpec and record the sliding window size.
|
||||
Default to None for not using sliding window attention.
|
||||
"""
|
||||
|
||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
|
||||
# Note(hc): each dcp rank only need save
|
||||
# (max_model_len//dcp_world_size) tokens locally.
|
||||
if dcp_world_size > 1:
|
||||
max_model_len = cdiv(max_model_len, dcp_world_size)
|
||||
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
|
||||
|
||||
@classmethod
|
||||
def merge_window_sizes(cls, window_sizes: set[int]) -> int | None:
|
||||
if len(window_sizes) == 0:
|
||||
return None
|
||||
elif len(window_sizes) == 1:
|
||||
return window_sizes.pop()
|
||||
else:
|
||||
raise ValueError(
|
||||
"All attention layers in the same KV cache group must have the "
|
||||
"same window size."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def merge(cls, specs: list[Self]) -> Self:
|
||||
"""
|
||||
Merge a list of FullDiffkvAttentionSpec objects into a single
|
||||
FullDiffkvAttentionSpec object.
|
||||
"""
|
||||
assert all(isinstance(spec, FullDiffkvAttentionSpec) for spec in specs), (
|
||||
"All attention layers in the same KV cache group must be "
|
||||
"FullDiffkvAttentionSpec."
|
||||
)
|
||||
|
||||
sliding_window = set(
|
||||
spec.sliding_window for spec in specs if spec.sliding_window is not None
|
||||
)
|
||||
@ -376,6 +299,56 @@ class CrossAttentionSpec(AttentionSpec):
|
||||
return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SinkFullAttentionSpec(FullAttentionSpec):
    """FullAttentionSpec that additionally records the sink-token length."""

    # Number of sink tokens; None when not set.
    sink_len: int | None = None

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of SinkFullAttentionSpec objects into a single
        SinkFullAttentionSpec object.

        ``block_size``, ``num_kv_heads``, ``head_size``, ``head_size_v``,
        ``sink_len`` and ``dtype`` are taken from ``specs[0]``. Only the fields
        declared on ``AttentionSpec`` are then checked for equality across all
        specs; ``head_size_v`` and ``sink_len`` of the remaining specs are not
        compared here.

        Raises:
            AssertionError: If a spec is not a FullAttentionSpec, if any spec
                is an MLAAttentionSpec, if an ``AttentionSpec`` field differs
                between specs, or if the merged spec would carry both a
                sliding window and an attention chunk size.
            ValueError: From ``merge_window_sizes`` when the specs disagree on
                the window or chunk size.
        """
        assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be FullAttentionSpec."
        )

        # Collect the distinct non-None window / chunk sizes across the group.
        sliding_window = set(
            spec.sliding_window for spec in specs if spec.sliding_window is not None
        )
        attention_chunk_size = set(
            spec.attention_chunk_size
            for spec in specs
            if spec.attention_chunk_size is not None
        )
        assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "MLAAttentionSpec should be merged in MLAAttentionSpec.merge"
        )
        merged_spec = cls(
            block_size=specs[0].block_size,
            num_kv_heads=specs[0].num_kv_heads,
            head_size=specs[0].head_size,
            head_size_v=specs[0].head_size_v,
            sink_len=specs[0].sink_len,
            dtype=specs[0].dtype,
            sliding_window=cls.merge_window_sizes(sliding_window),
            attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
        )
        # Every spec must agree with the merged spec on the base AttentionSpec
        # fields.
        for spec in specs:
            for f in fields(AttentionSpec):
                assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
                    "All attention layers in the same KV cache group must have "
                    "the same attention spec."
                )
        # A bool sum: at most one of the two attention variants may be set.
        assert (merged_spec.sliding_window is not None) + (
            merged_spec.attention_chunk_size is not None
        ) <= 1, (
            "Model with both sliding window layers and chunked local attention "
            "layers is not supported."
        )
        return merged_spec
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UniformTypeKVCacheSpecs(KVCacheSpec):
|
||||
"""
|
||||
|
||||
@ -263,6 +263,7 @@ class MultiGroupBlockTable:
|
||||
kernel_block_sizes: list[int],
|
||||
num_speculative_tokens: int = 0,
|
||||
cp_kv_cache_interleave_size: int = 1,
|
||||
sink_len: int = 0,
|
||||
) -> None:
|
||||
# Note(hc): each dcp rank only store
|
||||
# (max_model_len//dcp_world_size) tokens in kvcache,
|
||||
@ -292,7 +293,7 @@ class MultiGroupBlockTable:
|
||||
block_size,
|
||||
max_num_reqs,
|
||||
max(
|
||||
cdiv(max_model_len, block_size * total_cp_world_size),
|
||||
cdiv(max_model_len + sink_len, block_size * total_cp_world_size),
|
||||
1 + num_speculative_tokens,
|
||||
),
|
||||
max_num_batched_tokens,
|
||||
|
||||
@ -101,16 +101,25 @@ def _reshape_kv_cache(
|
||||
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
|
||||
|
||||
attn_backend = attn_backends[layer_name]
|
||||
if hasattr(kv_cache_spec, "head_size_v"):
|
||||
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
|
||||
stride_kwargs = {"diff_kv": True}
|
||||
else:
|
||||
kwargs = {}
|
||||
stride_kwargs = {}
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
num_blocks,
|
||||
kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# FIXME(woosuk): Add kv_cache_stride_order to all attention backends.
|
||||
try:
|
||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
|
||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
|
||||
**stride_kwargs
|
||||
)
|
||||
assert len(kv_cache_stride_order) == len(kv_cache_shape)
|
||||
except (AttributeError, NotImplementedError):
|
||||
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
|
||||
|
||||
@ -143,7 +143,7 @@ class InputBatch:
|
||||
# Block table.
|
||||
self.block_table = MultiGroupBlockTable(
|
||||
max_num_reqs=max_num_reqs,
|
||||
max_model_len=max_model_len + sink_len,
|
||||
max_model_len=max_model_len,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
pin_memory=pin_memory,
|
||||
device=device,
|
||||
@ -151,6 +151,7 @@ class InputBatch:
|
||||
kernel_block_sizes=kernel_block_sizes,
|
||||
num_speculative_tokens=num_speculative_tokens,
|
||||
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
|
||||
sink_len=sink_len,
|
||||
)
|
||||
|
||||
# Sampling-related.
|
||||
|
||||
@ -27,9 +27,6 @@ from vllm.attention.backends.abstract import (
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.attention.layer import Attention, MLAAttention
|
||||
from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash_diffkv,
|
||||
)
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.cuda_graph import CUDAGraphStat, CUDAGraphWrapper
|
||||
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
|
||||
@ -5209,16 +5206,25 @@ class GPUModelRunner(
|
||||
)
|
||||
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
|
||||
|
||||
if hasattr(kv_cache_spec, "head_size_v"):
|
||||
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
|
||||
stride_kwargs = {"diff_kv": True}
|
||||
else:
|
||||
kwargs = {}
|
||||
stride_kwargs = {}
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
kernel_num_blocks,
|
||||
kernel_block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
cache_dtype_str=self.cache_config.cache_dtype,
|
||||
**kwargs,
|
||||
)
|
||||
dtype = kv_cache_spec.dtype
|
||||
try:
|
||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
|
||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
|
||||
**stride_kwargs
|
||||
)
|
||||
assert len(kv_cache_stride_order) == len(kv_cache_shape)
|
||||
except (AttributeError, NotImplementedError):
|
||||
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
|
||||
@ -5410,7 +5416,6 @@ class GPUModelRunner(
|
||||
kv_caches = self.initialize_kv_cache_tensors(
|
||||
kv_cache_config, kernel_block_sizes
|
||||
)
|
||||
self.prepare_sink_kv_cache(kv_caches)
|
||||
|
||||
if self.speculative_config and self.speculative_config.use_eagle():
|
||||
assert isinstance(self.drafter, EagleProposer)
|
||||
@ -5501,36 +5506,3 @@ class GPUModelRunner(
|
||||
self.transfer_event.record()
|
||||
self.transfer_event.synchronize()
|
||||
return pinned.tolist()
|
||||
|
||||
def prepare_sink_kv_cache(self, kv_caches) -> None:
    """Write each layer's sink key/value tensors into that layer's KV cache.

    For every ``(layer_name, kv_cache)`` pair, the module named by
    ``layer_name`` without its last dotted component is looked up in
    ``self.model``. Modules that do not provide ``get_sink_kv`` are skipped.
    For the others, ``sink_key`` / ``sink_value`` are written into slots
    ``[block_size, block_size + sink_len)`` of the cache.

    Args:
        kv_caches: Mapping from attention layer name to its KV cache.

    Raises:
        KeyError: If no module with the derived name exists in the model.
    """
    if self.sink_len == 0:
        return

    # Index the modules once; scanning named_modules() again for every layer
    # makes the lookup quadratic in the number of modules.
    modules_by_name = dict(self.model.named_modules())

    for layer_name, kv_cache in kv_caches.items():
        layer_prefix = layer_name.rsplit(".", 1)[0]
        if layer_prefix not in modules_by_name:
            raise KeyError(f"Module '{layer_prefix}' not found")
        self_attn_module = modules_by_name[layer_prefix]
        if not hasattr(self_attn_module, "get_sink_kv"):
            continue

        sink_kv = self_attn_module.get_sink_kv()
        block_size = self.vllm_config.cache_config.block_size
        # The slot range starts at `block_size`, i.e. the first block's slots
        # are not written.
        # NOTE(review): presumably block 0 is reserved elsewhere; confirm.
        sink_kv_slot_mapping = torch.arange(
            block_size,
            self.sink_len + block_size,
            device=torch.cuda.current_device(),
            dtype=torch.long,
        )
        triton_reshape_and_cache_flash_diffkv(
            sink_kv["sink_key"],
            sink_kv["sink_value"],
            kv_cache,
            sink_kv_slot_mapping,
            self_attn_module.attn.kv_cache_dtype,
            self_attn_module.attn._k_scale,
            self_attn_module.attn._v_scale,
        )
|
||||
|
||||
@ -190,17 +190,25 @@ class KVConnectorModelRunnerMixin:
|
||||
return False
|
||||
|
||||
attn_backend = attn_group.backend
|
||||
if hasattr(kv_cache_spec, "head_size_v"):
|
||||
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
|
||||
stride_kwargs = {"diff_kv": True}
|
||||
else:
|
||||
kwargs = {}
|
||||
stride_kwargs = {}
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
1234,
|
||||
kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
cache_dtype_str=cache_dtype,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
try:
|
||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
|
||||
include_num_layers_dimension=True
|
||||
include_num_layers_dimension=True,
|
||||
**stride_kwargs,
|
||||
)
|
||||
except (AttributeError, NotImplementedError):
|
||||
return False
|
||||
@ -257,12 +265,19 @@ class KVConnectorModelRunnerMixin:
|
||||
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
|
||||
|
||||
attn_backend = attn_group.backend
|
||||
if hasattr(kv_cache_spec, "head_size_v"):
|
||||
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
|
||||
stride_kwargs = {"diff_kv": True}
|
||||
else:
|
||||
kwargs = {}
|
||||
stride_kwargs = {}
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
kernel_num_blocks,
|
||||
kernel_block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
cache_dtype_str=cache_dtype,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# prepend a num_layers dimension into the shape
|
||||
@ -270,7 +285,8 @@ class KVConnectorModelRunnerMixin:
|
||||
|
||||
try:
|
||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
|
||||
include_num_layers_dimension=True
|
||||
include_num_layers_dimension=True,
|
||||
**stride_kwargs,
|
||||
)
|
||||
assert len(kv_cache_stride_order) == len(kv_cache_shape)
|
||||
except (AttributeError, NotImplementedError):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user