mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 02:37:04 +08:00
Refacotr code. Extent FLASH_ATTN to support different KV size and create a new StaticSinkAttention for sink token logics
Signed-off-by: yuantao <2422264527@qq.com>
This commit is contained in:
parent
de538d3b8f
commit
b565203d92
@ -42,9 +42,6 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
|
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
|
||||||
FLASH_DIFFKV_ATTN = (
|
|
||||||
"vllm.v1.attention.backends.flash_diffkv_attn.FlashDiffkvAttentionBackend"
|
|
||||||
)
|
|
||||||
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
|
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
|
||||||
ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
|
ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
|
||||||
ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
|
ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
|
||||||
|
|||||||
@ -191,6 +191,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
|||||||
attn_type: str = AttentionType.DECODER,
|
attn_type: str = AttentionType.DECODER,
|
||||||
kv_sharing_target_layer_name: str | None = None,
|
kv_sharing_target_layer_name: str | None = None,
|
||||||
attn_backend: type[AttentionBackend] | None = None,
|
attn_backend: type[AttentionBackend] | None = None,
|
||||||
|
head_size_v: int | None = None,
|
||||||
**extra_impl_args,
|
**extra_impl_args,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@ -232,6 +233,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
|||||||
|
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
|
self.head_size_v = self.head_size if head_size_v is None else head_size_v
|
||||||
self.num_kv_heads = num_kv_heads
|
self.num_kv_heads = num_kv_heads
|
||||||
self.sliding_window = sliding_window
|
self.sliding_window = sliding_window
|
||||||
self.has_sink = extra_impl_args.get("sinks") is not None
|
self.has_sink = extra_impl_args.get("sinks") is not None
|
||||||
@ -370,6 +372,10 @@ class Attention(nn.Module, AttentionLayerBase):
|
|||||||
query, _ = self.query_quant(query, self._q_scale)
|
query, _ = self.query_quant(query, self._q_scale)
|
||||||
|
|
||||||
if self.use_output:
|
if self.use_output:
|
||||||
|
if output_shape is None:
|
||||||
|
output_shape = torch.Size(
|
||||||
|
(*query.shape[:-1], self.num_heads * self.head_size_v)
|
||||||
|
)
|
||||||
output_shape = output_shape if output_shape is not None else query.shape
|
output_shape = output_shape if output_shape is not None else query.shape
|
||||||
output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
|
output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
|
||||||
hidden_size = output_shape[-1]
|
hidden_size = output_shape[-1]
|
||||||
@ -377,11 +383,11 @@ class Attention(nn.Module, AttentionLayerBase):
|
|||||||
# NOTE(woosuk): We do this outside the custom op to minimize the
|
# NOTE(woosuk): We do this outside the custom op to minimize the
|
||||||
# CPU overheads from the non-CUDA-graph regions.
|
# CPU overheads from the non-CUDA-graph regions.
|
||||||
query = query.view(-1, self.num_heads, self.head_size)
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
output = output.view(-1, self.num_heads, self.head_size)
|
output = output.view(-1, self.num_heads, self.head_size_v)
|
||||||
if key is not None:
|
if key is not None:
|
||||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||||
if value is not None:
|
if value is not None:
|
||||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
value = value.view(-1, self.num_kv_heads, self.head_size_v)
|
||||||
if self.use_direct_call:
|
if self.use_direct_call:
|
||||||
forward_context: ForwardContext = get_forward_context()
|
forward_context: ForwardContext = get_forward_context()
|
||||||
attn_metadata = forward_context.attn_metadata
|
attn_metadata = forward_context.attn_metadata
|
||||||
@ -456,6 +462,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
|||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
head_size=self.head_size,
|
head_size=self.head_size,
|
||||||
|
head_size_v=self.head_size_v,
|
||||||
dtype=self.kv_cache_torch_dtype,
|
dtype=self.kv_cache_torch_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
225
vllm/attention/layers/static_sink_attention.py
Normal file
225
vllm/attention/layers/static_sink_attention.py
Normal file
@ -0,0 +1,225 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import functools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import (
|
||||||
|
AttentionBackend,
|
||||||
|
AttentionMetadata,
|
||||||
|
AttentionType,
|
||||||
|
)
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
|
from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||||
|
triton_reshape_and_cache_flash_diffkv,
|
||||||
|
)
|
||||||
|
from vllm.attention.selector import get_attn_backend
|
||||||
|
from vllm.config import CacheConfig, VllmConfig
|
||||||
|
from vllm.forward_context import ForwardContext, get_forward_context
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
from vllm.v1.attention.backends.utils import (
|
||||||
|
CommonAttentionMetadata,
|
||||||
|
subclass_attention_backend,
|
||||||
|
)
|
||||||
|
from vllm.v1.kv_cache_interface import (
|
||||||
|
AttentionSpec,
|
||||||
|
KVCacheSpec,
|
||||||
|
SinkFullAttentionSpec,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache
|
||||||
|
def create_static_sink_attention_backend(
|
||||||
|
underlying_attn_backend: type[AttentionBackend],
|
||||||
|
sink_len: int = 0,
|
||||||
|
) -> type[AttentionBackend]:
|
||||||
|
prefix = "StaticSink_"
|
||||||
|
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||||
|
|
||||||
|
class StaticSinkAttentionBuilder(underlying_builder): # type: ignore
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
kv_cache_spec: AttentionSpec,
|
||||||
|
layer_names: list[str],
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
device: torch.device,
|
||||||
|
):
|
||||||
|
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||||
|
self.sink_len = sink_len
|
||||||
|
self.num_sink_blocks = self.sink_len // vllm_config.cache_config.block_size
|
||||||
|
self.sink_block_table = torch.arange(
|
||||||
|
1,
|
||||||
|
self.num_sink_blocks + 1,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
|
||||||
|
def build(
|
||||||
|
self,
|
||||||
|
common_prefix_len: int,
|
||||||
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
|
fast_build: bool = False,
|
||||||
|
) -> AttentionMetadata:
|
||||||
|
common_attn_metadata.seq_lens[:] = (
|
||||||
|
common_attn_metadata.seq_lens + self.sink_len
|
||||||
|
)
|
||||||
|
common_attn_metadata.seq_lens_cpu = (
|
||||||
|
common_attn_metadata.seq_lens_cpu + self.sink_len
|
||||||
|
)
|
||||||
|
common_attn_metadata.max_seq_len = (
|
||||||
|
common_attn_metadata.max_seq_len + self.sink_len
|
||||||
|
)
|
||||||
|
|
||||||
|
blk_table_tensor = common_attn_metadata.block_table_tensor
|
||||||
|
sink_block_table = self.sink_block_table[None, :].expand(
|
||||||
|
blk_table_tensor.shape[0], -1
|
||||||
|
)
|
||||||
|
blk_table_tensor_clone = blk_table_tensor.clone()
|
||||||
|
blk_table_tensor[:, self.num_sink_blocks :] = blk_table_tensor_clone[
|
||||||
|
:, : -self.num_sink_blocks
|
||||||
|
]
|
||||||
|
blk_table_tensor[:, : self.num_sink_blocks] = sink_block_table
|
||||||
|
|
||||||
|
return super().build(common_prefix_len, common_attn_metadata, fast_build)
|
||||||
|
|
||||||
|
attn_backend = subclass_attention_backend(
|
||||||
|
name_prefix=prefix,
|
||||||
|
attention_backend_cls=underlying_attn_backend,
|
||||||
|
builder_cls=StaticSinkAttentionBuilder,
|
||||||
|
)
|
||||||
|
|
||||||
|
return attn_backend
|
||||||
|
|
||||||
|
|
||||||
|
class StaticSinkAttention(Attention):
|
||||||
|
"""
|
||||||
|
Attention with static sink tokens
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
scale: float,
|
||||||
|
sink_len: int,
|
||||||
|
attn_backend: type[AttentionBackend] | None = None,
|
||||||
|
cache_config: CacheConfig | None = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
dtype = torch.get_default_dtype()
|
||||||
|
|
||||||
|
if cache_config is not None:
|
||||||
|
kv_cache_dtype = cache_config.cache_dtype
|
||||||
|
block_size = cache_config.block_size
|
||||||
|
else:
|
||||||
|
kv_cache_dtype = "auto"
|
||||||
|
block_size = 16
|
||||||
|
|
||||||
|
if attn_backend is not None:
|
||||||
|
underlying_attn_backend = attn_backend
|
||||||
|
else:
|
||||||
|
underlying_attn_backend = get_attn_backend(
|
||||||
|
head_size, dtype, kv_cache_dtype, block_size
|
||||||
|
)
|
||||||
|
attn_backend = create_static_sink_attention_backend(
|
||||||
|
underlying_attn_backend,
|
||||||
|
sink_len=sink_len,
|
||||||
|
)
|
||||||
|
super().__init__(
|
||||||
|
num_heads=num_heads,
|
||||||
|
head_size=head_size,
|
||||||
|
scale=scale,
|
||||||
|
cache_config=cache_config,
|
||||||
|
attn_backend=attn_backend,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.sink_len = sink_len
|
||||||
|
self.block_size = block_size
|
||||||
|
self.sink_populated = False
|
||||||
|
self.sink_key = None
|
||||||
|
self.sink_value = None
|
||||||
|
|
||||||
|
def update_sink_kv(self, sink_key, sink_value) -> None:
|
||||||
|
self.sink_key = sink_key
|
||||||
|
self.sink_value = sink_value
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
output_shape: torch.size | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
assert self.sink_key is not None and self.sink_value is not None, (
|
||||||
|
"sink_key and sink_value have not been prepared"
|
||||||
|
)
|
||||||
|
if not self.sink_populated:
|
||||||
|
forward_context: ForwardContext = get_forward_context()
|
||||||
|
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||||
|
torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)
|
||||||
|
|
||||||
|
return super().forward(query, key, value, output_shape)
|
||||||
|
|
||||||
|
def populate_sink_kv(self, self_kv_cache):
|
||||||
|
sink_kv_slot_mapping = torch.arange(
|
||||||
|
self.block_size,
|
||||||
|
self.sink_len + self.block_size,
|
||||||
|
device=torch.cuda.current_device(),
|
||||||
|
dtype=torch.long,
|
||||||
|
)
|
||||||
|
triton_reshape_and_cache_flash_diffkv(
|
||||||
|
self.sink_key,
|
||||||
|
self.sink_value,
|
||||||
|
self_kv_cache,
|
||||||
|
sink_kv_slot_mapping,
|
||||||
|
self.kv_cache_dtype,
|
||||||
|
self._k_scale,
|
||||||
|
self._v_scale,
|
||||||
|
)
|
||||||
|
# We only populate the sink_key and sink_value once
|
||||||
|
self.sink_populated = True
|
||||||
|
|
||||||
|
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
|
||||||
|
# Block size may get updated after model loading, refresh it
|
||||||
|
block_size = vllm_config.cache_config.block_size
|
||||||
|
# Should not be called for enc-dec or encoder-only attention.
|
||||||
|
assert self.attn_type == AttentionType.DECODER
|
||||||
|
|
||||||
|
return SinkFullAttentionSpec(
|
||||||
|
block_size=block_size,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
head_size=self.head_size,
|
||||||
|
head_size_v=self.head_size_v,
|
||||||
|
sink_len=self.sink_len,
|
||||||
|
dtype=self.kv_cache_torch_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_populate_sink(
|
||||||
|
self_kv_cache: torch.Tensor,
|
||||||
|
layer_name: str,
|
||||||
|
) -> None:
|
||||||
|
forward_context: ForwardContext = get_forward_context()
|
||||||
|
self = forward_context.no_compile_layers[layer_name]
|
||||||
|
if self.sink_populated or self_kv_cache.numel() == 0:
|
||||||
|
return
|
||||||
|
self.populate_sink_kv(self_kv_cache)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_populate_sink_fake(
|
||||||
|
self_kv_cache: torch.Tensor,
|
||||||
|
layer_name: str,
|
||||||
|
) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="maybe_populate_sink",
|
||||||
|
op_func=maybe_populate_sink,
|
||||||
|
mutates_args=["self_kv_cache"],
|
||||||
|
fake_impl=maybe_populate_sink_fake,
|
||||||
|
)
|
||||||
@ -186,17 +186,20 @@ def triton_reshape_and_cache_flash(
|
|||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def reshape_and_cache_kernel_flash_diffkv(
|
def reshape_and_cache_kernel_flash_diffkv(
|
||||||
kv_ptr, # [num_tokens, num_heads, head_size + head_size_v]
|
key_ptr, # [num_tokens, num_heads, head_size]
|
||||||
|
value_ptr, # [num_tokens, num_heads, head_size_v]
|
||||||
kv_cache_ptr, # [num_blocks, block_size, num_heads, head_size + head_size_v]
|
kv_cache_ptr, # [num_blocks, block_size, num_heads, head_size + head_size_v]
|
||||||
slot_mapping_ptr, # [num_tokens]
|
slot_mapping_ptr, # [num_tokens]
|
||||||
k_scale, # float32
|
k_scale, # float32
|
||||||
v_scale, # float32
|
v_scale, # float32
|
||||||
# strides
|
# strides
|
||||||
kv_stride: tl.int64,
|
key_stride: tl.int64,
|
||||||
|
value_stride: tl.int64,
|
||||||
block_stride: tl.int64,
|
block_stride: tl.int64,
|
||||||
page_stride: tl.int64,
|
page_stride: tl.int64,
|
||||||
num_heads: tl.constexpr,
|
num_heads: tl.constexpr,
|
||||||
head_size_kv: tl.constexpr,
|
head_size_k: tl.constexpr,
|
||||||
|
head_size_v: tl.constexpr,
|
||||||
block_size: tl.constexpr,
|
block_size: tl.constexpr,
|
||||||
# FP8 flags
|
# FP8 flags
|
||||||
FP8_KV_CACHE: tl.constexpr,
|
FP8_KV_CACHE: tl.constexpr,
|
||||||
@ -211,24 +214,51 @@ def reshape_and_cache_kernel_flash_diffkv(
|
|||||||
|
|
||||||
tile_i = tl.program_id(axis=1)
|
tile_i = tl.program_id(axis=1)
|
||||||
tile_offs = tl.arange(0, TILE_SIZE)
|
tile_offs = tl.arange(0, TILE_SIZE)
|
||||||
tile_pos = tile_i * TILE_SIZE + tile_offs
|
|
||||||
|
|
||||||
block_idx = slot_idx // block_size
|
block_idx = slot_idx // block_size
|
||||||
block_offset = slot_idx % block_size
|
block_offset = slot_idx % block_size
|
||||||
|
|
||||||
src_kv_idx = token_idx * kv_stride
|
src_key_idx = token_idx * key_stride + tile_i * head_size_k
|
||||||
|
src_value_idx = token_idx * value_stride + tile_i * head_size_v
|
||||||
|
|
||||||
tgt_idx = block_idx * block_stride + block_offset * page_stride
|
tgt_idx = (
|
||||||
|
block_idx * block_stride
|
||||||
# [TILE_SIZE]
|
+ block_offset * page_stride
|
||||||
kv_tile = tl.load(
|
+ tile_i * (head_size_k + head_size_v)
|
||||||
kv_ptr + src_kv_idx + tile_pos, mask=tile_pos < (num_heads * head_size_kv)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# [TILE_SIZE]
|
||||||
|
key_load = tl.load(key_ptr + src_key_idx + tile_offs, mask=tile_offs < head_size_k)
|
||||||
|
if FP8_KV_CACHE:
|
||||||
|
# tl.store will do the correct implicit cast to fp8,
|
||||||
|
# based on the key_cache_ptr.dtype.element_ty
|
||||||
|
key_tile = key_load if key_load.dtype.is_fp8() else key_load / tl.load(k_scale)
|
||||||
|
else:
|
||||||
|
key_tile = key_load
|
||||||
|
|
||||||
|
# [TILE_SIZE]
|
||||||
|
value_load = tl.load(
|
||||||
|
value_ptr + src_value_idx + tile_offs, mask=tile_offs * head_size_v
|
||||||
|
)
|
||||||
|
if FP8_KV_CACHE:
|
||||||
|
if value_load.dtype.is_fp8():
|
||||||
|
value_tile = value_load
|
||||||
|
else:
|
||||||
|
# tl.store will do the correct implicit cast to fp8,
|
||||||
|
# based on the value_cache_ptr.dtype.element_ty
|
||||||
|
value_tile = value_load / tl.load(v_scale)
|
||||||
|
else:
|
||||||
|
value_tile = value_load
|
||||||
|
|
||||||
tl.store(
|
tl.store(
|
||||||
kv_cache_ptr + tgt_idx + tile_pos,
|
kv_cache_ptr + tgt_idx + tile_offs,
|
||||||
kv_tile,
|
key_tile,
|
||||||
mask=tile_pos < (num_heads * head_size_kv),
|
mask=tile_offs < head_size_k,
|
||||||
|
)
|
||||||
|
tl.store(
|
||||||
|
kv_cache_ptr + tgt_idx + head_size_k + tile_offs,
|
||||||
|
value_tile,
|
||||||
|
mask=tile_offs < head_size_v,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -243,13 +273,13 @@ def triton_reshape_and_cache_flash_diffkv(
|
|||||||
k_scale: torch.Tensor, # float32
|
k_scale: torch.Tensor, # float32
|
||||||
v_scale: torch.Tensor, # float32
|
v_scale: torch.Tensor, # float32
|
||||||
):
|
):
|
||||||
kv = torch.cat([key, value], dim=-1).contiguous()
|
num_heads = key.shape[1]
|
||||||
num_heads = kv.shape[1]
|
head_size_k = key.shape[2]
|
||||||
head_size_kv = kv.shape[2]
|
head_size_v = value.shape[2]
|
||||||
block_size = kv_cache.shape[1]
|
block_size = kv_cache.shape[1]
|
||||||
n = num_heads * head_size_kv
|
|
||||||
|
|
||||||
kv_stride = kv.stride()[0]
|
k_stride = key.stride()[0]
|
||||||
|
v_stride = value.stride()[0]
|
||||||
block_stride = kv_cache.stride()[0]
|
block_stride = kv_cache.stride()[0]
|
||||||
page_stride = kv_cache.stride()[1]
|
page_stride = kv_cache.stride()[1]
|
||||||
|
|
||||||
@ -272,13 +302,20 @@ def triton_reshape_and_cache_flash_diffkv(
|
|||||||
)
|
)
|
||||||
|
|
||||||
FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
|
FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
|
||||||
assert not FP8_KV_CACHE, (
|
assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
|
||||||
|
torch.float8_e4m3fn,
|
||||||
|
torch.float8_e5m2,
|
||||||
|
torch.uint8,
|
||||||
|
torch.float8_e4m3fnuz,
|
||||||
|
], (
|
||||||
"unsupported dtype of KV cache tensor, got "
|
"unsupported dtype of KV cache tensor, got "
|
||||||
"{kv_cache_torch_dtype}. Supported kv cache dtypes: bfloat16, float16, float32."
|
"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, "
|
||||||
|
"fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz."
|
||||||
)
|
)
|
||||||
|
|
||||||
# heuristics instead of autotuning
|
# heuristics instead of autotuning
|
||||||
TILE_SIZE = min(2048, triton.next_power_of_2(n))
|
TILE_SIZE = max(head_size_k, head_size_v)
|
||||||
|
TILE_SIZE = triton.next_power_of_2(TILE_SIZE)
|
||||||
if current_platform.is_rocm() or current_platform.is_xpu():
|
if current_platform.is_rocm() or current_platform.is_xpu():
|
||||||
num_stages = 4
|
num_stages = 4
|
||||||
num_warps = 8
|
num_warps = 8
|
||||||
@ -292,21 +329,24 @@ def triton_reshape_and_cache_flash_diffkv(
|
|||||||
# using cudagraphs
|
# using cudagraphs
|
||||||
grid = lambda meta: (
|
grid = lambda meta: (
|
||||||
slot_mapping.shape[0],
|
slot_mapping.shape[0],
|
||||||
triton.cdiv(n, meta["TILE_SIZE"]),
|
num_heads,
|
||||||
)
|
)
|
||||||
|
|
||||||
reshape_and_cache_kernel_flash_diffkv[grid](
|
reshape_and_cache_kernel_flash_diffkv[grid](
|
||||||
kv_ptr=kv,
|
key_ptr=key,
|
||||||
|
value_ptr=value,
|
||||||
kv_cache_ptr=kv_cache,
|
kv_cache_ptr=kv_cache,
|
||||||
slot_mapping_ptr=slot_mapping,
|
slot_mapping_ptr=slot_mapping,
|
||||||
k_scale=k_scale,
|
k_scale=k_scale,
|
||||||
v_scale=v_scale,
|
v_scale=v_scale,
|
||||||
# strides
|
# strides
|
||||||
kv_stride=kv_stride,
|
key_stride=k_stride,
|
||||||
|
value_stride=v_stride,
|
||||||
block_stride=block_stride,
|
block_stride=block_stride,
|
||||||
page_stride=page_stride,
|
page_stride=page_stride,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
head_size_kv=head_size_kv,
|
head_size_k=head_size_k,
|
||||||
|
head_size_v=head_size_v,
|
||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
# FP8 flags
|
# FP8 flags
|
||||||
FP8_KV_CACHE=FP8_KV_CACHE,
|
FP8_KV_CACHE=FP8_KV_CACHE,
|
||||||
|
|||||||
@ -29,8 +29,8 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
|
from vllm.attention.layer import Attention, AttentionType
|
||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layers.static_sink_attention import StaticSinkAttention
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
|
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
|
||||||
from vllm.distributed import (
|
from vllm.distributed import (
|
||||||
@ -41,7 +41,6 @@ from vllm.distributed import (
|
|||||||
get_tp_group,
|
get_tp_group,
|
||||||
tensor_model_parallel_all_gather,
|
tensor_model_parallel_all_gather,
|
||||||
)
|
)
|
||||||
from vllm.forward_context import ForwardContext, get_forward_context
|
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
|
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
@ -82,11 +81,7 @@ from vllm.model_executor.models.utils import (
|
|||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.transformers_utils.config import set_default_rope_theta
|
from vllm.transformers_utils.config import set_default_rope_theta
|
||||||
from vllm.v1.attention.backends.flash_diffkv_attn import FlashDiffkvAttentionBackend
|
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
||||||
from vllm.v1.kv_cache_interface import (
|
|
||||||
FullDiffkvAttentionSpec,
|
|
||||||
KVCacheSpec,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def check_ffn_act_fn(act_fn: str):
|
def check_ffn_act_fn(act_fn: str):
|
||||||
@ -96,133 +91,6 @@ def check_ffn_act_fn(act_fn: str):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class DiffkvAttention(Attention):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
num_heads: int,
|
|
||||||
head_size: int,
|
|
||||||
head_size_v: int,
|
|
||||||
scale: float,
|
|
||||||
num_kv_heads: int | None = None,
|
|
||||||
alibi_slopes: list[float] | None = None,
|
|
||||||
cache_config: CacheConfig | None = None,
|
|
||||||
quant_config: QuantizationConfig | None = None,
|
|
||||||
logits_soft_cap: float | None = None,
|
|
||||||
per_layer_sliding_window: int | None = None,
|
|
||||||
prefix: str = "",
|
|
||||||
attn_type: str = AttentionType.DECODER,
|
|
||||||
kv_sharing_target_layer_name: str | None = None,
|
|
||||||
attn_backend: type[AttentionBackend] | None = None,
|
|
||||||
**extra_impl_args,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(
|
|
||||||
num_heads,
|
|
||||||
head_size,
|
|
||||||
scale,
|
|
||||||
num_kv_heads,
|
|
||||||
alibi_slopes,
|
|
||||||
cache_config,
|
|
||||||
quant_config,
|
|
||||||
logits_soft_cap,
|
|
||||||
per_layer_sliding_window,
|
|
||||||
prefix,
|
|
||||||
attn_type,
|
|
||||||
kv_sharing_target_layer_name,
|
|
||||||
attn_backend,
|
|
||||||
**extra_impl_args,
|
|
||||||
)
|
|
||||||
self.head_size_v = head_size_v
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
query: torch.Tensor,
|
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
output_shape: torch.Size | None = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
The KV cache is stored inside this class and is accessed via
|
|
||||||
`self.kv_cache`.
|
|
||||||
|
|
||||||
Attention metadata (`attn_metadata`) is set using a context manager in
|
|
||||||
the model runner's `execute_model` method. It is accessed via forward
|
|
||||||
context using
|
|
||||||
`vllm.forward_context.get_forward_context().attn_metadata`.
|
|
||||||
"""
|
|
||||||
if self.calculate_kv_scales:
|
|
||||||
torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
|
|
||||||
output_dtype = query.dtype
|
|
||||||
if self.query_quant is not None:
|
|
||||||
# quantizing with a simple torch operation enables
|
|
||||||
# torch.compile to fuse this into previous ops
|
|
||||||
# which reduces overheads during decoding.
|
|
||||||
# Otherwise queries are quantized using custom ops
|
|
||||||
# which causes decoding overheads
|
|
||||||
assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
|
|
||||||
|
|
||||||
# check if query quantization is supported
|
|
||||||
if self.impl.supports_quant_query_input():
|
|
||||||
query, _ = self.query_quant(query, self._q_scale)
|
|
||||||
|
|
||||||
if self.use_output:
|
|
||||||
output_shape = output_shape if output_shape is not None else query.shape
|
|
||||||
output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
|
|
||||||
hidden_size = output_shape[-1]
|
|
||||||
# Reshape the query, key, and value tensors.
|
|
||||||
# NOTE(woosuk): We do this outside the custom op to minimize the
|
|
||||||
# CPU overheads from the non-CUDA-graph regions.
|
|
||||||
query = query.view(-1, self.num_heads, self.head_size)
|
|
||||||
output = output.view(-1, self.num_heads, self.head_size_v)
|
|
||||||
if key is not None:
|
|
||||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
|
||||||
if value is not None:
|
|
||||||
value = value.view(-1, self.num_kv_heads, self.head_size_v)
|
|
||||||
if self.use_direct_call:
|
|
||||||
forward_context: ForwardContext = get_forward_context()
|
|
||||||
attn_metadata = forward_context.attn_metadata
|
|
||||||
if isinstance(attn_metadata, dict):
|
|
||||||
attn_metadata = attn_metadata[self.layer_name]
|
|
||||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
|
||||||
self.impl.forward(
|
|
||||||
self,
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
self_kv_cache,
|
|
||||||
attn_metadata,
|
|
||||||
output=output,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
torch.ops.vllm.unified_attention_with_output(
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
output,
|
|
||||||
self.layer_name,
|
|
||||||
)
|
|
||||||
return output.view(-1, hidden_size)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Unsupport Error, currently only flash_diffkv_attn "
|
|
||||||
"backend with output buffer is supported"
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
|
|
||||||
# Block size may get updated after model loading, refresh it
|
|
||||||
block_size = vllm_config.cache_config.block_size
|
|
||||||
# Should not be called for enc-dec or encoder-only attention.
|
|
||||||
assert self.attn_type == AttentionType.DECODER
|
|
||||||
# Only support for full attention now.
|
|
||||||
assert self.sliding_window is None
|
|
||||||
return FullDiffkvAttentionSpec(
|
|
||||||
block_size=block_size,
|
|
||||||
num_kv_heads=self.num_kv_heads,
|
|
||||||
head_size=self.head_size,
|
|
||||||
head_size_v=self.head_size_v,
|
|
||||||
dtype=self.kv_cache_torch_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class OpenPanguMLP(nn.Module):
|
class OpenPanguMLP(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -673,7 +541,7 @@ class OpenPanguEmbeddedAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class OpenPanguDiffkvAttention(nn.Module):
|
class OpenPanguSinkAttention(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
@ -777,19 +645,19 @@ class OpenPanguDiffkvAttention(nn.Module):
|
|||||||
else:
|
else:
|
||||||
sliding_window = None
|
sliding_window = None
|
||||||
|
|
||||||
FlashDiffkvAttentionBackend.set_head_size_v(self.v_channels)
|
self.attn = StaticSinkAttention(
|
||||||
self.attn = DiffkvAttention(
|
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.v_channels,
|
|
||||||
self.scaling,
|
self.scaling,
|
||||||
|
sink_len=self.param_sink_number,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
per_layer_sliding_window=sliding_window,
|
per_layer_sliding_window=sliding_window,
|
||||||
attn_type=attn_type,
|
attn_type=attn_type,
|
||||||
prefix=f"{prefix}.attn",
|
prefix=f"{prefix}.attn",
|
||||||
attn_backend=FlashDiffkvAttentionBackend,
|
attn_backend=FlashAttentionBackend,
|
||||||
|
head_size_v=self.v_channels,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.param_sink_number > 0:
|
if self.param_sink_number > 0:
|
||||||
@ -919,19 +787,13 @@ class OpenPanguDiffkvAttention(nn.Module):
|
|||||||
is_neox_style=is_neox_style,
|
is_neox_style=is_neox_style,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_sink_kv(self) -> dict[str, torch.Tensor]:
|
def post_weight_load(self) -> None:
|
||||||
if self.param_sink_number == 0:
|
|
||||||
raise ValueError("No sink_key and sink_value when param_sink_number == 0")
|
|
||||||
|
|
||||||
if hasattr(self, "k_layernorm") and self.k_layernorm is not None:
|
if hasattr(self, "k_layernorm") and self.k_layernorm is not None:
|
||||||
param_sink_key = self.k_layernorm(self.param_sink_key)
|
param_sink_key = self.k_layernorm(self.param_sink_key)
|
||||||
else:
|
else:
|
||||||
param_sink_key = self.param_sink_key
|
param_sink_key = self.param_sink_key
|
||||||
|
|
||||||
return {
|
self.attn.update_sink_kv(param_sink_key, self.param_sink_value)
|
||||||
"sink_key": param_sink_key,
|
|
||||||
"sink_value": self.param_sink_value,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class OpenPanguDecoderLayer(nn.Module):
|
class OpenPanguDecoderLayer(nn.Module):
|
||||||
@ -1001,7 +863,7 @@ class OpenPanguDecoderLayer(nn.Module):
|
|||||||
"rope_type": "default",
|
"rope_type": "default",
|
||||||
"rope_theta": config.rope_theta,
|
"rope_theta": config.rope_theta,
|
||||||
}
|
}
|
||||||
self.self_attn = OpenPanguDiffkvAttention(
|
self.self_attn = OpenPanguSinkAttention(
|
||||||
config=config,
|
config=config,
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
num_heads=config.num_attention_heads,
|
num_heads=config.num_attention_heads,
|
||||||
@ -1359,8 +1221,17 @@ class OpenPanguModel(nn.Module):
|
|||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
loaded_params.add(name)
|
loaded_params.add(name)
|
||||||
|
|
||||||
|
self.post_weight_load()
|
||||||
return loaded_params
|
return loaded_params
|
||||||
|
|
||||||
|
def post_weight_load(self) -> None:
|
||||||
|
for name, module in self.named_modules():
|
||||||
|
if module is self:
|
||||||
|
continue
|
||||||
|
if hasattr(module, "post_weight_load"):
|
||||||
|
module.post_weight_load()
|
||||||
|
|
||||||
|
|
||||||
class OpenPanguModelBase(nn.Module, SupportsPP, SupportsLoRA):
|
class OpenPanguModelBase(nn.Module, SupportsPP, SupportsLoRA):
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
|
|||||||
@ -18,6 +18,9 @@ from vllm.attention.backends.abstract import (
|
|||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.attention.ops.common import cp_lse_ag_out_rs
|
from vllm.attention.ops.common import cp_lse_ag_out_rs
|
||||||
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
||||||
|
from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||||
|
triton_reshape_and_cache_flash_diffkv,
|
||||||
|
)
|
||||||
from vllm.attention.utils.fa_utils import (
|
from vllm.attention.utils.fa_utils import (
|
||||||
flash_attn_supports_fp8,
|
flash_attn_supports_fp8,
|
||||||
get_flash_attn_version,
|
get_flash_attn_version,
|
||||||
@ -105,28 +108,48 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
cache_dtype_str: str = "auto",
|
cache_dtype_str: str = "auto",
|
||||||
|
head_size_v: int | None = None,
|
||||||
) -> tuple[int, ...]:
|
) -> tuple[int, ...]:
|
||||||
if block_size % 16 != 0:
|
if block_size % 16 != 0:
|
||||||
raise ValueError("Block size must be a multiple of 16.")
|
raise ValueError("Block size must be a multiple of 16.")
|
||||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
if head_size_v is None or head_size == head_size_v:
|
||||||
|
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
num_blocks,
|
||||||
|
block_size,
|
||||||
|
num_kv_heads,
|
||||||
|
head_size + head_size_v,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_kv_cache_stride_order(
|
def get_kv_cache_stride_order(
|
||||||
include_num_layers_dimension: bool = False,
|
include_num_layers_dimension: bool = False,
|
||||||
|
diff_kv: bool = False,
|
||||||
) -> tuple[int, ...]:
|
) -> tuple[int, ...]:
|
||||||
# `stride_order` indicates the permutation that gets
|
# `stride_order` indicates the permutation that gets
|
||||||
# us from `get_kv_cache_shape` to the actual memory layout we want.
|
# us from `get_kv_cache_shape` to the actual memory layout we want.
|
||||||
cache_layout = get_kv_cache_layout()
|
cache_layout = get_kv_cache_layout()
|
||||||
if cache_layout == "NHD" and include_num_layers_dimension:
|
if cache_layout == "NHD" and include_num_layers_dimension:
|
||||||
# (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
|
if not diff_kv:
|
||||||
return (2, 0, 1, 3, 4, 5)
|
# (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
|
||||||
|
return (2, 0, 1, 3, 4, 5)
|
||||||
|
else:
|
||||||
|
# (num_blocks, num_layers, block_size,
|
||||||
|
# num_kv_heads, head_size + head_size_v)
|
||||||
|
return (0, 1, 2, 3, 4)
|
||||||
elif cache_layout == "NHD":
|
elif cache_layout == "NHD":
|
||||||
stride_order = (0, 1, 2, 3, 4)
|
stride_order = (0, 1, 2, 3, 4) if not diff_kv else (0, 1, 2, 3)
|
||||||
elif cache_layout == "HND" and include_num_layers_dimension:
|
elif cache_layout == "HND" and include_num_layers_dimension:
|
||||||
# (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
|
if not diff_kv:
|
||||||
return (2, 4, 0, 1, 3, 5)
|
# (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
|
||||||
|
return (2, 4, 0, 1, 3, 5)
|
||||||
|
else:
|
||||||
|
# (num_blocks, num_kv_heads, num_layers,
|
||||||
|
# block_size, head_size + head_size_v)
|
||||||
|
return (2, 3, 0, 1, 4)
|
||||||
elif cache_layout == "HND":
|
elif cache_layout == "HND":
|
||||||
stride_order = (0, 1, 3, 2, 4)
|
stride_order = (0, 1, 3, 2, 4) if not diff_kv else (0, 2, 1, 3)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown cache layout format {cache_layout}.")
|
raise ValueError(f"Unknown cache layout format {cache_layout}.")
|
||||||
return stride_order
|
return stride_order
|
||||||
@ -576,11 +599,14 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
query: shape = [num_tokens, num_heads, head_size]
|
query: shape = [num_tokens, num_heads, head_size]
|
||||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||||
|
or [num_tokens, num_kv_heads, head_size_v]
|
||||||
kv_cache: shape =
|
kv_cache: shape =
|
||||||
[2, num_blocks, block_size, num_kv_heads, head_size]
|
[2, num_blocks, block_size, num_kv_heads, head_size]
|
||||||
|
or [num_blocks, block_size, num_kv_heads, head_size + head_size_v]
|
||||||
attn_metadata: Metadata for attention.
|
attn_metadata: Metadata for attention.
|
||||||
Returns:
|
Returns:
|
||||||
shape = [num_tokens, num_heads * head_size]
|
shape = [num_tokens, num_heads * head_size]
|
||||||
|
or [num_tokens, num_heads * head_size_v]
|
||||||
NOTE: FP8 quantization, flash-attn expect the size of
|
NOTE: FP8 quantization, flash-attn expect the size of
|
||||||
{q,k,v}_descale to be (num_sequences, num_kv_heads).
|
{q,k,v}_descale to be (num_sequences, num_kv_heads).
|
||||||
We use torch's .expand() to avoid duplicating values
|
We use torch's .expand() to avoid duplicating values
|
||||||
@ -623,7 +649,13 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# For decoder and cross-attention, use KV cache as before
|
# For decoder and cross-attention, use KV cache as before
|
||||||
key_cache, value_cache = kv_cache.unbind(0)
|
if self.head_size == kv_cache.shape[-1]:
|
||||||
|
# Same head_size for K and V
|
||||||
|
key_cache, value_cache = kv_cache.unbind(0)
|
||||||
|
else:
|
||||||
|
# Different head_size for K and V
|
||||||
|
key_cache = kv_cache[..., : self.head_size]
|
||||||
|
value_cache = kv_cache[..., self.head_size :]
|
||||||
|
|
||||||
# key and value may be None in the case of cross attention. They are
|
# key and value may be None in the case of cross attention. They are
|
||||||
# calculated once based on the output from the encoder and then cached
|
# calculated once based on the output from the encoder and then cached
|
||||||
@ -640,16 +672,29 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
# and value[:num_actual_tokens] because the reshape_and_cache_flash
|
# and value[:num_actual_tokens] because the reshape_and_cache_flash
|
||||||
# op uses the slot_mapping's shape to determine the number of
|
# op uses the slot_mapping's shape to determine the number of
|
||||||
# actual tokens.
|
# actual tokens.
|
||||||
reshape_and_cache_flash(
|
if self.head_size == kv_cache.shape[-1]:
|
||||||
key,
|
# kv_cache update for same head_size K and V
|
||||||
value,
|
reshape_and_cache_flash(
|
||||||
key_cache,
|
key,
|
||||||
value_cache,
|
value,
|
||||||
attn_metadata.slot_mapping,
|
key_cache,
|
||||||
self.kv_cache_dtype,
|
value_cache,
|
||||||
layer._k_scale,
|
attn_metadata.slot_mapping,
|
||||||
layer._v_scale,
|
self.kv_cache_dtype,
|
||||||
)
|
layer._k_scale,
|
||||||
|
layer._v_scale,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# kv_cache update for different head_size K and V
|
||||||
|
triton_reshape_and_cache_flash_diffkv(
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
kv_cache,
|
||||||
|
attn_metadata.slot_mapping,
|
||||||
|
self.kv_cache_dtype,
|
||||||
|
layer._k_scale,
|
||||||
|
layer._v_scale,
|
||||||
|
)
|
||||||
|
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
if self.kv_cache_dtype.startswith("fp8"):
|
||||||
# queries are quantized in the attention layer
|
# queries are quantized in the attention layer
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -210,11 +210,6 @@ class Scheduler(SchedulerInterface):
|
|||||||
hash_block_size=self.block_size,
|
hash_block_size=self.block_size,
|
||||||
metrics_collector=self.kv_metrics_collector,
|
metrics_collector=self.kv_metrics_collector,
|
||||||
)
|
)
|
||||||
sink_len = getattr(vllm_config.model_config.hf_config, "param_sink_number", 0)
|
|
||||||
if sink_len > 0:
|
|
||||||
assert sink_len % self.block_size == 0
|
|
||||||
num_sink_block = sink_len // self.block_size
|
|
||||||
self.kv_cache_manager.block_pool.free_block_queue.popleft_n(num_sink_block)
|
|
||||||
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
|
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
|
||||||
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
|
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
|
||||||
|
|
||||||
|
|||||||
@ -12,10 +12,10 @@ from vllm.v1.kv_cache_interface import (
|
|||||||
ChunkedLocalAttentionSpec,
|
ChunkedLocalAttentionSpec,
|
||||||
CrossAttentionSpec,
|
CrossAttentionSpec,
|
||||||
FullAttentionSpec,
|
FullAttentionSpec,
|
||||||
FullDiffkvAttentionSpec,
|
|
||||||
KVCacheSpec,
|
KVCacheSpec,
|
||||||
MambaSpec,
|
MambaSpec,
|
||||||
MLAAttentionSpec,
|
MLAAttentionSpec,
|
||||||
|
SinkFullAttentionSpec,
|
||||||
SlidingWindowSpec,
|
SlidingWindowSpec,
|
||||||
)
|
)
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
@ -317,8 +317,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
|
|||||||
pcp_world_size: int = 1,
|
pcp_world_size: int = 1,
|
||||||
) -> tuple[list[KVCacheBlock], ...]:
|
) -> tuple[list[KVCacheBlock], ...]:
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
kv_cache_spec,
|
kv_cache_spec, FullAttentionSpec | ChunkedLocalAttentionSpec
|
||||||
FullAttentionSpec | ChunkedLocalAttentionSpec | FullDiffkvAttentionSpec,
|
|
||||||
), (
|
), (
|
||||||
"FullAttentionManager can only be used for full attention "
|
"FullAttentionManager can only be used for full attention "
|
||||||
"and chunked local attention groups"
|
"and chunked local attention groups"
|
||||||
@ -785,14 +784,35 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
|
|||||||
raise NotImplementedError("CrossAttentionManager does not support caching")
|
raise NotImplementedError("CrossAttentionManager does not support caching")
|
||||||
|
|
||||||
|
|
||||||
|
class SinkFullAttentionManager(FullAttentionManager):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
kv_cache_spec: KVCacheSpec,
|
||||||
|
block_pool: BlockPool,
|
||||||
|
kv_cache_group_id: int,
|
||||||
|
dcp_world_size: int = 1,
|
||||||
|
pcp_world_size: int = 1,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
kv_cache_spec, block_pool, kv_cache_group_id, dcp_world_size, pcp_world_size
|
||||||
|
)
|
||||||
|
sink_len = kv_cache_spec.sink_len
|
||||||
|
if sink_len > 0:
|
||||||
|
assert sink_len % self.block_size == 0
|
||||||
|
num_sink_block = sink_len // self.block_size
|
||||||
|
self.sink_blocks = self.block_pool.free_block_queue.popleft_n(
|
||||||
|
num_sink_block
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
|
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
|
||||||
FullAttentionSpec: FullAttentionManager,
|
FullAttentionSpec: FullAttentionManager,
|
||||||
FullDiffkvAttentionSpec: FullAttentionManager,
|
|
||||||
MLAAttentionSpec: FullAttentionManager,
|
MLAAttentionSpec: FullAttentionManager,
|
||||||
SlidingWindowSpec: SlidingWindowManager,
|
SlidingWindowSpec: SlidingWindowManager,
|
||||||
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
|
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
|
||||||
MambaSpec: MambaManager,
|
MambaSpec: MambaManager,
|
||||||
CrossAttentionSpec: CrossAttentionManager,
|
CrossAttentionSpec: CrossAttentionManager,
|
||||||
|
SinkFullAttentionSpec: SinkFullAttentionManager,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -80,6 +80,7 @@ class AttentionSpec(KVCacheSpec):
|
|||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class FullAttentionSpec(AttentionSpec):
|
class FullAttentionSpec(AttentionSpec):
|
||||||
|
head_size_v: int | None = None
|
||||||
sliding_window: int | None = None
|
sliding_window: int | None = None
|
||||||
attention_chunk_size: int | None = None
|
attention_chunk_size: int | None = None
|
||||||
"""
|
"""
|
||||||
@ -92,6 +93,10 @@ class FullAttentionSpec(AttentionSpec):
|
|||||||
Default to None for not using sliding window attention.
|
Default to None for not using sliding window attention.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.head_size_v is None:
|
||||||
|
object.__setattr__(self, "head_size_v", self.head_size)
|
||||||
|
|
||||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||||
max_model_len = vllm_config.model_config.max_model_len
|
max_model_len = vllm_config.model_config.max_model_len
|
||||||
dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
|
dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
|
||||||
@ -124,88 +129,6 @@ class FullAttentionSpec(AttentionSpec):
|
|||||||
"All attention layers in the same KV cache group must be FullAttentionSpec."
|
"All attention layers in the same KV cache group must be FullAttentionSpec."
|
||||||
)
|
)
|
||||||
|
|
||||||
sliding_window = set(
|
|
||||||
spec.sliding_window for spec in specs if spec.sliding_window is not None
|
|
||||||
)
|
|
||||||
attention_chunk_size = set(
|
|
||||||
spec.attention_chunk_size
|
|
||||||
for spec in specs
|
|
||||||
if spec.attention_chunk_size is not None
|
|
||||||
)
|
|
||||||
assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
|
|
||||||
"MLAAttentionSpec should be merged in MLAAttentionSpec.merge"
|
|
||||||
)
|
|
||||||
merged_spec = cls(
|
|
||||||
block_size=specs[0].block_size,
|
|
||||||
num_kv_heads=specs[0].num_kv_heads,
|
|
||||||
head_size=specs[0].head_size,
|
|
||||||
dtype=specs[0].dtype,
|
|
||||||
sliding_window=cls.merge_window_sizes(sliding_window),
|
|
||||||
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
|
|
||||||
)
|
|
||||||
for spec in specs:
|
|
||||||
for f in fields(AttentionSpec):
|
|
||||||
assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
|
|
||||||
"All attention layers in the same KV cache group must have "
|
|
||||||
"the same attention spec."
|
|
||||||
)
|
|
||||||
assert (merged_spec.sliding_window is not None) + (
|
|
||||||
merged_spec.attention_chunk_size is not None
|
|
||||||
) <= 1, (
|
|
||||||
"Model with both sliding window layers and chunked local attention "
|
|
||||||
"layers is not supported."
|
|
||||||
)
|
|
||||||
return merged_spec
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class FullDiffkvAttentionSpec(AttentionSpec):
|
|
||||||
head_size_v: int
|
|
||||||
sliding_window: int | None = None
|
|
||||||
attention_chunk_size: int | None = None
|
|
||||||
|
|
||||||
"""
|
|
||||||
When hybrid allocator is disabled and the model contains both full
|
|
||||||
attention layers and sliding window attention layers, sliding
|
|
||||||
window attention are regarded as full attention in KV cache manager
|
|
||||||
(blocks are allocated for all tokens), while computed as sliding window
|
|
||||||
attention in model runner.
|
|
||||||
In this case, we use FullDiffkvAttentionSpec and record the sliding window size.
|
|
||||||
Default to None for not using sliding window attention.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
|
||||||
max_model_len = vllm_config.model_config.max_model_len
|
|
||||||
dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
|
|
||||||
# Note(hc): each dcp rank only need save
|
|
||||||
# (max_model_len//dcp_world_size) tokens locally.
|
|
||||||
if dcp_world_size > 1:
|
|
||||||
max_model_len = cdiv(max_model_len, dcp_world_size)
|
|
||||||
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def merge_window_sizes(cls, window_sizes: set[int]) -> int | None:
|
|
||||||
if len(window_sizes) == 0:
|
|
||||||
return None
|
|
||||||
elif len(window_sizes) == 1:
|
|
||||||
return window_sizes.pop()
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"All attention layers in the same KV cache group must have the "
|
|
||||||
"same window size."
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def merge(cls, specs: list[Self]) -> Self:
|
|
||||||
"""
|
|
||||||
Merge a list of FullDiffkvAttentionSpec objects into a single
|
|
||||||
FullDiffkvAttentionSpec object.
|
|
||||||
"""
|
|
||||||
assert all(isinstance(spec, FullDiffkvAttentionSpec) for spec in specs), (
|
|
||||||
"All attention layers in the same KV cache group must be "
|
|
||||||
"FullDiffkvAttentionSpec."
|
|
||||||
)
|
|
||||||
|
|
||||||
sliding_window = set(
|
sliding_window = set(
|
||||||
spec.sliding_window for spec in specs if spec.sliding_window is not None
|
spec.sliding_window for spec in specs if spec.sliding_window is not None
|
||||||
)
|
)
|
||||||
@ -376,6 +299,56 @@ class CrossAttentionSpec(AttentionSpec):
|
|||||||
return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes
|
return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(forzen=True)
|
||||||
|
class SinkFullAttentionSpec(FullAttentionSpec):
|
||||||
|
sink_len: int | None = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def merge(cls, specs: list[Self]) -> Self:
|
||||||
|
"""
|
||||||
|
Merge a list of FullAttentionSpec objects into a single
|
||||||
|
FullAttentionSpec object.
|
||||||
|
"""
|
||||||
|
assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
|
||||||
|
"All attention layers in the same KV cache group must be FullAttentionSpec."
|
||||||
|
)
|
||||||
|
|
||||||
|
sliding_window = set(
|
||||||
|
spec.sliding_window for spec in specs if spec.sliding_window is not None
|
||||||
|
)
|
||||||
|
attention_chunk_size = set(
|
||||||
|
spec.attention_chunk_size
|
||||||
|
for spec in specs
|
||||||
|
if spec.attention_chunk_size is not None
|
||||||
|
)
|
||||||
|
assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
|
||||||
|
"MLAAttentionSpec should be merged in MLAAttentionSpec.merge"
|
||||||
|
)
|
||||||
|
merged_spec = cls(
|
||||||
|
block_size=specs[0].block_size,
|
||||||
|
num_kv_heads=specs[0].num_kv_heads,
|
||||||
|
head_size=specs[0].head_size,
|
||||||
|
head_size_v=specs[0].head_size_v,
|
||||||
|
sink_len=specs[0].sink_len,
|
||||||
|
dtype=specs[0].dtype,
|
||||||
|
sliding_window=cls.merge_window_sizes(sliding_window),
|
||||||
|
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
|
||||||
|
)
|
||||||
|
for spec in specs:
|
||||||
|
for f in fields(AttentionSpec):
|
||||||
|
assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
|
||||||
|
"All attention layers in the same KV cache group must have "
|
||||||
|
"the same attention spec."
|
||||||
|
)
|
||||||
|
assert (merged_spec.sliding_window is not None) + (
|
||||||
|
merged_spec.attention_chunk_size is not None
|
||||||
|
) <= 1, (
|
||||||
|
"Model with both sliding window layers and chunked local attention "
|
||||||
|
"layers is not supported."
|
||||||
|
)
|
||||||
|
return merged_spec
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class UniformTypeKVCacheSpecs(KVCacheSpec):
|
class UniformTypeKVCacheSpecs(KVCacheSpec):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -263,6 +263,7 @@ class MultiGroupBlockTable:
|
|||||||
kernel_block_sizes: list[int],
|
kernel_block_sizes: list[int],
|
||||||
num_speculative_tokens: int = 0,
|
num_speculative_tokens: int = 0,
|
||||||
cp_kv_cache_interleave_size: int = 1,
|
cp_kv_cache_interleave_size: int = 1,
|
||||||
|
sink_len: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Note(hc): each dcp rank only store
|
# Note(hc): each dcp rank only store
|
||||||
# (max_model_len//dcp_world_size) tokens in kvcache,
|
# (max_model_len//dcp_world_size) tokens in kvcache,
|
||||||
@ -292,7 +293,7 @@ class MultiGroupBlockTable:
|
|||||||
block_size,
|
block_size,
|
||||||
max_num_reqs,
|
max_num_reqs,
|
||||||
max(
|
max(
|
||||||
cdiv(max_model_len, block_size * total_cp_world_size),
|
cdiv(max_model_len + sink_len, block_size * total_cp_world_size),
|
||||||
1 + num_speculative_tokens,
|
1 + num_speculative_tokens,
|
||||||
),
|
),
|
||||||
max_num_batched_tokens,
|
max_num_batched_tokens,
|
||||||
|
|||||||
@ -101,16 +101,25 @@ def _reshape_kv_cache(
|
|||||||
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
|
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
|
||||||
|
|
||||||
attn_backend = attn_backends[layer_name]
|
attn_backend = attn_backends[layer_name]
|
||||||
|
if hasattr(kv_cache_spec, "head_size_v"):
|
||||||
|
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
|
||||||
|
stride_kwargs = {"diff_kv": True}
|
||||||
|
else:
|
||||||
|
kwargs = {}
|
||||||
|
stride_kwargs = {}
|
||||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||||
num_blocks,
|
num_blocks,
|
||||||
kv_cache_spec.block_size,
|
kv_cache_spec.block_size,
|
||||||
kv_cache_spec.num_kv_heads,
|
kv_cache_spec.num_kv_heads,
|
||||||
kv_cache_spec.head_size,
|
kv_cache_spec.head_size,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# FIXME(woosuk): Add kv_cache_stride_order to all attention backends.
|
# FIXME(woosuk): Add kv_cache_stride_order to all attention backends.
|
||||||
try:
|
try:
|
||||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
|
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
|
||||||
|
**stride_kwargs
|
||||||
|
)
|
||||||
assert len(kv_cache_stride_order) == len(kv_cache_shape)
|
assert len(kv_cache_stride_order) == len(kv_cache_shape)
|
||||||
except (AttributeError, NotImplementedError):
|
except (AttributeError, NotImplementedError):
|
||||||
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
|
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
|
||||||
|
|||||||
@ -143,7 +143,7 @@ class InputBatch:
|
|||||||
# Block table.
|
# Block table.
|
||||||
self.block_table = MultiGroupBlockTable(
|
self.block_table = MultiGroupBlockTable(
|
||||||
max_num_reqs=max_num_reqs,
|
max_num_reqs=max_num_reqs,
|
||||||
max_model_len=max_model_len + sink_len,
|
max_model_len=max_model_len,
|
||||||
max_num_batched_tokens=max_num_batched_tokens,
|
max_num_batched_tokens=max_num_batched_tokens,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
device=device,
|
device=device,
|
||||||
@ -151,6 +151,7 @@ class InputBatch:
|
|||||||
kernel_block_sizes=kernel_block_sizes,
|
kernel_block_sizes=kernel_block_sizes,
|
||||||
num_speculative_tokens=num_speculative_tokens,
|
num_speculative_tokens=num_speculative_tokens,
|
||||||
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
|
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
|
||||||
|
sink_len=sink_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sampling-related.
|
# Sampling-related.
|
||||||
|
|||||||
@ -27,9 +27,6 @@ from vllm.attention.backends.abstract import (
|
|||||||
MultipleOf,
|
MultipleOf,
|
||||||
)
|
)
|
||||||
from vllm.attention.layer import Attention, MLAAttention
|
from vllm.attention.layer import Attention, MLAAttention
|
||||||
from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
|
||||||
triton_reshape_and_cache_flash_diffkv,
|
|
||||||
)
|
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.cuda_graph import CUDAGraphStat, CUDAGraphWrapper
|
from vllm.compilation.cuda_graph import CUDAGraphStat, CUDAGraphWrapper
|
||||||
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
|
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
|
||||||
@ -5209,16 +5206,25 @@ class GPUModelRunner(
|
|||||||
)
|
)
|
||||||
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
|
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
|
||||||
|
|
||||||
|
if hasattr(kv_cache_spec, "head_size_v"):
|
||||||
|
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
|
||||||
|
stride_kwargs = {"diff_kv": True}
|
||||||
|
else:
|
||||||
|
kwargs = {}
|
||||||
|
stride_kwargs = {}
|
||||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||||
kernel_num_blocks,
|
kernel_num_blocks,
|
||||||
kernel_block_size,
|
kernel_block_size,
|
||||||
kv_cache_spec.num_kv_heads,
|
kv_cache_spec.num_kv_heads,
|
||||||
kv_cache_spec.head_size,
|
kv_cache_spec.head_size,
|
||||||
cache_dtype_str=self.cache_config.cache_dtype,
|
cache_dtype_str=self.cache_config.cache_dtype,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
dtype = kv_cache_spec.dtype
|
dtype = kv_cache_spec.dtype
|
||||||
try:
|
try:
|
||||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
|
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
|
||||||
|
**stride_kwargs
|
||||||
|
)
|
||||||
assert len(kv_cache_stride_order) == len(kv_cache_shape)
|
assert len(kv_cache_stride_order) == len(kv_cache_shape)
|
||||||
except (AttributeError, NotImplementedError):
|
except (AttributeError, NotImplementedError):
|
||||||
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
|
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
|
||||||
@ -5410,7 +5416,6 @@ class GPUModelRunner(
|
|||||||
kv_caches = self.initialize_kv_cache_tensors(
|
kv_caches = self.initialize_kv_cache_tensors(
|
||||||
kv_cache_config, kernel_block_sizes
|
kv_cache_config, kernel_block_sizes
|
||||||
)
|
)
|
||||||
self.prepare_sink_kv_cache(kv_caches)
|
|
||||||
|
|
||||||
if self.speculative_config and self.speculative_config.use_eagle():
|
if self.speculative_config and self.speculative_config.use_eagle():
|
||||||
assert isinstance(self.drafter, EagleProposer)
|
assert isinstance(self.drafter, EagleProposer)
|
||||||
@ -5501,36 +5506,3 @@ class GPUModelRunner(
|
|||||||
self.transfer_event.record()
|
self.transfer_event.record()
|
||||||
self.transfer_event.synchronize()
|
self.transfer_event.synchronize()
|
||||||
return pinned.tolist()
|
return pinned.tolist()
|
||||||
|
|
||||||
def prepare_sink_kv_cache(self, kv_caches) -> None:
|
|
||||||
if self.sink_len == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
def find_module_by_name(model, target_name: str):
|
|
||||||
for name, module in model.named_modules():
|
|
||||||
if name == target_name:
|
|
||||||
return module
|
|
||||||
raise KeyError(f"Module '{target_name}' not found")
|
|
||||||
|
|
||||||
for layer_name, kv_cache in kv_caches.items():
|
|
||||||
layer_prefix = layer_name.rsplit(".", 1)[0]
|
|
||||||
self_attn_module = find_module_by_name(self.model, layer_prefix)
|
|
||||||
if not hasattr(self_attn_module, "get_sink_kv"):
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
sink_kv = self_attn_module.get_sink_kv()
|
|
||||||
sink_kv_slot_mapping = torch.arange(
|
|
||||||
self.vllm_config.cache_config.block_size,
|
|
||||||
self.sink_len + self.vllm_config.cache_config.block_size,
|
|
||||||
device=torch.cuda.current_device(),
|
|
||||||
dtype=torch.long,
|
|
||||||
)
|
|
||||||
triton_reshape_and_cache_flash_diffkv(
|
|
||||||
sink_kv["sink_key"],
|
|
||||||
sink_kv["sink_value"],
|
|
||||||
kv_cache,
|
|
||||||
sink_kv_slot_mapping,
|
|
||||||
self_attn_module.attn.kv_cache_dtype,
|
|
||||||
self_attn_module.attn._k_scale,
|
|
||||||
self_attn_module.attn._v_scale,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -190,17 +190,25 @@ class KVConnectorModelRunnerMixin:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
attn_backend = attn_group.backend
|
attn_backend = attn_group.backend
|
||||||
|
if hasattr(kv_cache_spec, "head_size_v"):
|
||||||
|
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
|
||||||
|
stride_kwargs = {"diff_kv": True}
|
||||||
|
else:
|
||||||
|
kwargs = {}
|
||||||
|
stride_kwargs = {}
|
||||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||||
1234,
|
1234,
|
||||||
kv_cache_spec.block_size,
|
kv_cache_spec.block_size,
|
||||||
kv_cache_spec.num_kv_heads,
|
kv_cache_spec.num_kv_heads,
|
||||||
kv_cache_spec.head_size,
|
kv_cache_spec.head_size,
|
||||||
cache_dtype_str=cache_dtype,
|
cache_dtype_str=cache_dtype,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
|
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
|
||||||
include_num_layers_dimension=True
|
include_num_layers_dimension=True,
|
||||||
|
**stride_kwargs,
|
||||||
)
|
)
|
||||||
except (AttributeError, NotImplementedError):
|
except (AttributeError, NotImplementedError):
|
||||||
return False
|
return False
|
||||||
@ -257,12 +265,19 @@ class KVConnectorModelRunnerMixin:
|
|||||||
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
|
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
|
||||||
|
|
||||||
attn_backend = attn_group.backend
|
attn_backend = attn_group.backend
|
||||||
|
if hasattr(kv_cache_spec, "head_size_v"):
|
||||||
|
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
|
||||||
|
stride_kwargs = {"diff_kv": True}
|
||||||
|
else:
|
||||||
|
kwargs = {}
|
||||||
|
stride_kwargs = {}
|
||||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||||
kernel_num_blocks,
|
kernel_num_blocks,
|
||||||
kernel_block_size,
|
kernel_block_size,
|
||||||
kv_cache_spec.num_kv_heads,
|
kv_cache_spec.num_kv_heads,
|
||||||
kv_cache_spec.head_size,
|
kv_cache_spec.head_size,
|
||||||
cache_dtype_str=cache_dtype,
|
cache_dtype_str=cache_dtype,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# prepend a num_layers dimension into the shape
|
# prepend a num_layers dimension into the shape
|
||||||
@ -270,7 +285,8 @@ class KVConnectorModelRunnerMixin:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
|
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
|
||||||
include_num_layers_dimension=True
|
include_num_layers_dimension=True,
|
||||||
|
**stride_kwargs,
|
||||||
)
|
)
|
||||||
assert len(kv_cache_stride_order) == len(kv_cache_shape)
|
assert len(kv_cache_stride_order) == len(kv_cache_shape)
|
||||||
except (AttributeError, NotImplementedError):
|
except (AttributeError, NotImplementedError):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user