mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-02 09:57:17 +08:00
Refactor code, move static sink logic to builder
Signed-off-by: yuantao <2422264527@qq.com>
This commit is contained in:
parent
4ad3f75875
commit
cf58a62099
@ -17,6 +17,8 @@ from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata,
|
||||
@ -48,9 +50,23 @@ def create_static_sink_attention_backend(
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
model_config = vllm_config.model_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
self.sink_len = sink_len
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.num_sink_blocks = self.sink_len // vllm_config.cache_config.block_size
|
||||
self.sink_block_table = torch.arange(
|
||||
self.max_num_blocks = cdiv(
|
||||
model_config.max_model_len, vllm_config.cache_config.block_size
|
||||
)
|
||||
self.block_table_with_sink = torch.zeros(
|
||||
(
|
||||
scheduler_config.max_num_seqs,
|
||||
self.max_num_blocks + self.num_sink_blocks,
|
||||
),
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
self.block_table_with_sink[:, : self.num_sink_blocks] = torch.arange(
|
||||
1,
|
||||
self.num_sink_blocks + 1,
|
||||
device=device,
|
||||
@ -72,6 +88,14 @@ def create_static_sink_attention_backend(
|
||||
common_attn_metadata.max_seq_len = (
|
||||
common_attn_metadata.max_seq_len + self.sink_len
|
||||
)
|
||||
max_num_blocks = cdiv(common_attn_metadata.max_seq_len, self.block_size)
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
self.block_table_with_sink[
|
||||
:num_reqs, self.num_sink_blocks : self.num_sink_blocks + max_num_blocks
|
||||
] = common_attn_metadata.block_table_tensor[:, :max_num_blocks]
|
||||
common_attn_metadata.block_table_tensor = self.block_table_with_sink[
|
||||
:num_reqs
|
||||
]
|
||||
|
||||
return super().build(common_prefix_len, common_attn_metadata, fast_build)
|
||||
|
||||
@ -84,7 +108,8 @@ def create_static_sink_attention_backend(
|
||||
return attn_backend
|
||||
|
||||
|
||||
class StaticSinkAttention(Attention):
|
||||
@CustomOp.register("static_sink_attention")
|
||||
class StaticSinkAttention(Attention, CustomOp):
|
||||
"""
|
||||
Attention with static sink tokens
|
||||
"""
|
||||
@ -118,7 +143,8 @@ class StaticSinkAttention(Attention):
|
||||
underlying_attn_backend,
|
||||
sink_len=sink_len,
|
||||
)
|
||||
super().__init__(
|
||||
Attention.__init__(
|
||||
self=self,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
@ -126,6 +152,7 @@ class StaticSinkAttention(Attention):
|
||||
attn_backend=attn_backend,
|
||||
**kwargs,
|
||||
)
|
||||
CustomOp.__init__(self)
|
||||
|
||||
self.sink_len = sink_len
|
||||
self.block_size = block_size
|
||||
@ -137,7 +164,7 @@ class StaticSinkAttention(Attention):
|
||||
self.sink_key = sink_key
|
||||
self.sink_value = sink_value
|
||||
|
||||
def forward(
|
||||
def forward_native(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
@ -154,6 +181,18 @@ class StaticSinkAttention(Attention):
|
||||
|
||||
return super().forward(query, key, value, output_shape)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
output_shape: torch.Size | None = None,
|
||||
) -> torch.Tensor:
|
||||
return self.forward_native(query, key, value, output_shape)
|
||||
|
||||
def forward(self, *args, **kwargs):
    """Dispatch the call to ``self._forward_method``.

    All positional and keyword arguments are passed through untouched and
    the callee's result is returned directly.
    NOTE(review): ``_forward_method`` is presumably installed by
    ``CustomOp.__init__`` -- confirm against the CustomOp base class.
    """
    dispatch = self._forward_method
    return dispatch(*args, **kwargs)
|
||||
|
||||
def populate_sink_kv(self, self_kv_cache):
|
||||
sink_kv_slot_mapping = torch.arange(
|
||||
self.block_size,
|
||||
|
||||
@ -801,147 +801,6 @@ class SinkFullAttentionManager(FullAttentionManager):
|
||||
num_sink_block = sink_len // self.block_size
|
||||
self.sink_blocks = self.block_pool.free_block_queue.popleft_n(num_sink_block)
|
||||
|
||||
def get_num_blocks_to_allocate(
    self,
    request_id: str,
    num_tokens: int,
    new_computed_blocks: Sequence[KVCacheBlock],
) -> int:
    """
    Compute how many blocks must be allocated for the request.

    Args:
        request_id: The request ID.
        num_tokens: The total number of tokens that need a slot (including
            tokens that are already allocated).
        new_computed_blocks: The new computed blocks just hitting the
            prefix caching.

    Returns:
        The number of blocks.
    """
    num_held = len(self.req_to_blocks[request_id])
    num_missing = (
        cdiv(num_tokens, self.block_size) - len(new_computed_blocks) - num_held
    )
    if num_held > 0:
        # The request's block list already counts the sink blocks, so add
        # them back to avoid under-reporting the number of new blocks.
        num_missing += len(self.sink_blocks)

    # A computed block that is currently an eviction candidate (in the free
    # queue with ref_cnt == 0) stops being free once the request takes it,
    # so it has to be counted as an allocation as well.
    num_evictable = sum(
        1 for blk in new_computed_blocks if blk.ref_cnt == 0 and not blk.is_null
    )
    return num_missing + num_evictable
|
||||
|
||||
def save_new_computed_blocks(
    self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock]
) -> None:
    """
    Record the prefix-cache hits of a request.

    Args:
        request_id: The request ID.
        new_computed_blocks: The new computed blocks just hitting the
            prefix cache.
    """
    if request_id in self.num_cached_block:
        # A running request must not receive new computed blocks.
        assert len(new_computed_blocks) == 0
        return

    # A new request: its block list has to start out empty.
    req_blocks = self.req_to_blocks[request_id]
    assert len(req_blocks) == 0
    # Sink blocks go first, followed by the prefix-cache hits.
    req_blocks.extend(self.sink_blocks)
    req_blocks.extend(new_computed_blocks)
    # Only the prefix-cache hits count as cached; sink blocks are excluded.
    self.num_cached_block[request_id] = len(new_computed_blocks)
|
||||
|
||||
def allocate_new_blocks(
    self, request_id: str, num_tokens: int
) -> list[KVCacheBlock]:
    """
    Grow the request's block list so it holds at least `num_tokens`
    token slots.

    Args:
        request_id: The request ID.
        num_tokens: The total number of tokens that need a slot (including
            tokens that are already allocated).

    Returns:
        The new allocated blocks.
    """
    req_blocks = self.req_to_blocks[request_id]
    shortfall = cdiv(num_tokens, self.block_size) - len(req_blocks)
    if len(req_blocks) > 0:
        # An existing request's list already includes the sink blocks;
        # compensate so they are not mistaken for token capacity.
        shortfall += len(self.sink_blocks)

    if shortfall <= 0:
        return []

    new_blocks = self.block_pool.get_new_blocks(shortfall)
    if len(req_blocks) == 0:
        # First allocation for this request: put the sink blocks in front.
        req_blocks.extend(self.sink_blocks + new_blocks)
    else:
        req_blocks.extend(new_blocks)
    return new_blocks
|
||||
|
||||
def cache_blocks(self, request: Request, num_tokens: int) -> None:
    """
    Register the request's newly filled blocks with the block pool.

    Args:
        request: The request.
        num_tokens: The total number of tokens that need to be cached
            (including tokens that are already cached).
    """
    req_id = request.request_id
    already_cached = self.num_cached_block.get(req_id, 0)
    full_blocks = num_tokens // self.block_size

    if already_cached < full_blocks:
        # Skip the leading sink blocks; they must not be cached.
        cacheable = self.req_to_blocks[req_id][len(self.sink_blocks) :]
        self.block_pool.cache_full_blocks(
            request=request,
            blocks=cacheable,
            num_cached_blocks=already_cached,
            num_full_blocks=full_blocks,
            block_size=self.block_size,
            kv_cache_group_id=self.kv_cache_group_id,
        )
        self.num_cached_block[req_id] = full_blocks
|
||||
|
||||
def free(self, request_id: str) -> None:
    """
    Return the request's blocks to the block pool.

    Args:
        request_id: The request ID.
    """
    # A request may be freed (aborted) before anything was allocated for
    # it, hence the [] fallback.
    blocks = self.req_to_blocks.pop(request_id, [])
    if blocks:
        # Sink blocks sit at the front of the list and must not be freed.
        blocks = blocks[len(self.sink_blocks) :]

    # Free in reverse order so that the tail blocks are released first.
    self.block_pool.free_blocks(reversed(blocks))
    self.num_cached_block.pop(request_id, None)
|
||||
|
||||
|
||||
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
|
||||
FullAttentionSpec: FullAttentionManager,
|
||||
|
||||
@ -23,7 +23,6 @@ class BlockTable:
|
||||
device: torch.device,
|
||||
kernel_block_size: int,
|
||||
cp_kv_cache_interleave_size: int,
|
||||
sink_len: int = 0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -64,8 +63,6 @@ class BlockTable:
|
||||
self.use_hybrid_blocks = True
|
||||
|
||||
self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block
|
||||
self.sink_block_len = sink_len // self.block_size
|
||||
self.max_num_blocks_per_req = self.max_num_blocks_per_req + self.sink_block_len
|
||||
|
||||
self.block_table = self._make_buffer(
|
||||
self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32
|
||||
@ -154,7 +151,7 @@ class BlockTable:
|
||||
block_table_indices = (
|
||||
req_indices * self.max_num_blocks_per_req
|
||||
+ positions // virtual_block_size
|
||||
) + self.sink_block_len
|
||||
)
|
||||
|
||||
block_numbers = self.block_table.np.ravel()[block_table_indices]
|
||||
# Use virtual_block_size for mask calculation, which marks local
|
||||
@ -180,10 +177,9 @@ class BlockTable:
|
||||
mask, slot_mapping, -1
|
||||
)
|
||||
else:
|
||||
# When self.sink_block_len > 0, we need to shift block table indices
|
||||
block_table_indices = (
|
||||
req_indices * self.max_num_blocks_per_req + positions // self.block_size
|
||||
) + self.sink_block_len
|
||||
)
|
||||
|
||||
block_numbers = self.block_table.np.ravel()[block_table_indices]
|
||||
block_offsets = positions % self.block_size
|
||||
@ -267,7 +263,6 @@ class MultiGroupBlockTable:
|
||||
kernel_block_sizes: list[int],
|
||||
num_speculative_tokens: int = 0,
|
||||
cp_kv_cache_interleave_size: int = 1,
|
||||
sink_len: int = 0,
|
||||
) -> None:
|
||||
# Note(hc): each dcp rank only store
|
||||
# (max_model_len//dcp_world_size) tokens in kvcache,
|
||||
@ -305,7 +300,6 @@ class MultiGroupBlockTable:
|
||||
device,
|
||||
kernel_block_size,
|
||||
cp_kv_cache_interleave_size,
|
||||
sink_len=sink_len,
|
||||
)
|
||||
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
|
||||
]
|
||||
|
||||
@ -101,28 +101,16 @@ def _reshape_kv_cache(
|
||||
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
|
||||
|
||||
attn_backend = attn_backends[layer_name]
|
||||
if (
|
||||
hasattr(kv_cache_spec, "head_size_v")
|
||||
and kv_cache_spec.head_size_v != kv_cache_spec.head_size
|
||||
):
|
||||
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
|
||||
stride_kwargs = {"diff_kv": True}
|
||||
else:
|
||||
kwargs = {}
|
||||
stride_kwargs = {}
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
num_blocks,
|
||||
kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# FIXME(woosuk): Add kv_cache_stride_order to all attention backends.
|
||||
try:
|
||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
|
||||
**stride_kwargs
|
||||
)
|
||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
|
||||
assert len(kv_cache_stride_order) == len(kv_cache_shape)
|
||||
except (AttributeError, NotImplementedError):
|
||||
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
|
||||
|
||||
@ -96,7 +96,6 @@ class InputBatch:
|
||||
is_pooling_model: bool = False,
|
||||
num_speculative_tokens: int = 0,
|
||||
cp_kv_cache_interleave_size: int = 1,
|
||||
sink_len: int = 0,
|
||||
):
|
||||
self.is_pooling_model = is_pooling_model
|
||||
self.is_spec_decode = is_spec_decode
|
||||
@ -151,7 +150,6 @@ class InputBatch:
|
||||
kernel_block_sizes=kernel_block_sizes,
|
||||
num_speculative_tokens=num_speculative_tokens,
|
||||
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
|
||||
sink_len=sink_len,
|
||||
)
|
||||
|
||||
# Sampling-related.
|
||||
|
||||
@ -332,10 +332,6 @@ class GPUModelRunner(
|
||||
self.num_query_heads = model_config.get_num_attention_heads(parallel_config)
|
||||
self.inputs_embeds_size = model_config.get_inputs_embeds_size()
|
||||
self.attention_chunk_size = model_config.attention_chunk_size
|
||||
self.sink_len = getattr(
|
||||
self.vllm_config.model_config.hf_config, "param_sink_number", 0
|
||||
)
|
||||
assert self.sink_len % self.cache_config.block_size == 0
|
||||
# Only relevant for models using ALiBi (e.g, MPT)
|
||||
self.use_alibi = model_config.uses_alibi
|
||||
|
||||
@ -459,7 +455,6 @@ class GPUModelRunner(
|
||||
logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
|
||||
is_pooling_model=self.is_pooling_model,
|
||||
cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
|
||||
sink_len=self.sink_len,
|
||||
)
|
||||
|
||||
self.use_async_scheduling = self.scheduler_config.async_scheduling
|
||||
@ -5079,7 +5074,7 @@ class GPUModelRunner(
|
||||
logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids,
|
||||
is_pooling_model=self.is_pooling_model,
|
||||
num_speculative_tokens=self.num_spec_tokens,
|
||||
sink_len=self.sink_len,
|
||||
# sink_len=self.sink_len,
|
||||
)
|
||||
|
||||
def _allocate_kv_cache_tensors(
|
||||
@ -5206,28 +5201,16 @@ class GPUModelRunner(
|
||||
)
|
||||
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
|
||||
|
||||
if (
|
||||
hasattr(kv_cache_spec, "head_size_v")
|
||||
and kv_cache_spec.head_size_v != kv_cache_spec.head_size
|
||||
):
|
||||
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
|
||||
stride_kwargs = {"diff_kv": True}
|
||||
else:
|
||||
kwargs = {}
|
||||
stride_kwargs = {}
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
kernel_num_blocks,
|
||||
kernel_block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
cache_dtype_str=self.cache_config.cache_dtype,
|
||||
**kwargs,
|
||||
)
|
||||
dtype = kv_cache_spec.dtype
|
||||
try:
|
||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
|
||||
**stride_kwargs
|
||||
)
|
||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
|
||||
assert len(kv_cache_stride_order) == len(kv_cache_shape)
|
||||
except (AttributeError, NotImplementedError):
|
||||
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
|
||||
|
||||
@ -190,28 +190,17 @@ class KVConnectorModelRunnerMixin:
|
||||
return False
|
||||
|
||||
attn_backend = attn_group.backend
|
||||
if (
|
||||
hasattr(kv_cache_spec, "head_size_v")
|
||||
and kv_cache_spec.head_size_v != kv_cache_spec.head_size
|
||||
):
|
||||
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
|
||||
stride_kwargs = {"diff_kv": True}
|
||||
else:
|
||||
kwargs = {}
|
||||
stride_kwargs = {}
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
1234,
|
||||
kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
cache_dtype_str=cache_dtype,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
try:
|
||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
|
||||
include_num_layers_dimension=True,
|
||||
**stride_kwargs,
|
||||
)
|
||||
except (AttributeError, NotImplementedError):
|
||||
return False
|
||||
@ -268,22 +257,12 @@ class KVConnectorModelRunnerMixin:
|
||||
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
|
||||
|
||||
attn_backend = attn_group.backend
|
||||
if (
|
||||
hasattr(kv_cache_spec, "head_size_v")
|
||||
and kv_cache_spec.head_size_v != kv_cache_spec.head_size
|
||||
):
|
||||
kwargs = {"head_size_v": kv_cache_spec.head_size_v}
|
||||
stride_kwargs = {"diff_kv": True}
|
||||
else:
|
||||
kwargs = {}
|
||||
stride_kwargs = {}
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
kernel_num_blocks,
|
||||
kernel_block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
cache_dtype_str=cache_dtype,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# prepend a num_layers dimension into the shape
|
||||
@ -292,7 +271,6 @@ class KVConnectorModelRunnerMixin:
|
||||
try:
|
||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
|
||||
include_num_layers_dimension=True,
|
||||
**stride_kwargs,
|
||||
)
|
||||
assert len(kv_cache_stride_order) == len(kv_cache_shape)
|
||||
except (AttributeError, NotImplementedError):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user