[Optimization] Make new_block_ids None if empty (#23262)

Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>

commit b029de9902 (parent bbea1cefdd)
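What the change does, in short: the per-request new_block_ids entries that the scheduler hands to the model runners become Optional, with None standing in for "no new KV-cache blocks this step". During decoding most requests allocate nothing, so this replaces a tuple of empty lists per request with a single cheap sentinel. A toy illustration of the new shape (values invented, not vLLM API):

from typing import Optional

# One entry per request; inside an entry, one id list per KV cache group.
new_block_ids: list[Optional[tuple[list[int], ...]]] = [
    None,        # request 0: decode step, no new blocks allocated
    ([42], []),  # request 1: one new block in KV cache group 0
    None,        # request 2: no new blocks either
]

# Consumers branch on the sentinel instead of looping over empty lists.
for entry in new_block_ids:
    if entry is None:
        continue  # nothing to append for this request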
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from dataclasses import dataclass
-from typing import Optional
+from typing import Literal, Optional, overload
 
 from vllm.distributed.kv_events import KVCacheEvent
 from vllm.logger import init_logger
@@ -37,7 +37,24 @@ class KVCacheBlocks:
             tuple(blk1 + blk2
                   for blk1, blk2 in zip(self.blocks, other.blocks)))
 
-    def get_block_ids(self) -> tuple[list[int], ...]:
+    @overload
+    def get_block_ids(
+        self,
+        allow_none: Literal[False] = False,
+    ) -> tuple[list[int], ...]:
+        ...
+
+    @overload
+    def get_block_ids(
+        self,
+        allow_none: Literal[True] = True,
+    ) -> Optional[tuple[list[int], ...]]:
+        ...
+
+    def get_block_ids(
+        self,
+        allow_none: bool = False,
+    ):
         """
         Converts the KVCacheBlocks instance to block_ids.
 
@@ -46,6 +63,8 @@ class KVCacheBlocks:
         * the outer tuple corresponds to KV cache groups
         * each inner list contains the block_ids of the blocks in that group
         """
+        if allow_none and all(len(group) == 0 for group in self.blocks):
+            return None
         return tuple([blk.block_id for blk in group] for group in self.blocks)
 
     def get_unhashed_block_ids(self) -> list[int]:
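A note on the overload pair above: the two @overload stubs let a type checker prove that the return value can only be None when the call site passes allow_none=True, so existing callers keep the non-Optional return type with no new None checks. A self-contained sketch of the same Literal-based pattern (toy class, not the vLLM one):

from typing import Literal, Optional, overload


class Blocks:
    def __init__(self, groups: tuple[list[int], ...]) -> None:
        self.groups = groups

    # Default call: the result is never None, callers need no None check.
    @overload
    def ids(self, allow_none: Literal[False] = False) -> tuple[list[int], ...]:
        ...

    # Opt-in call: the return type widens to Optional.
    @overload
    def ids(self, allow_none: Literal[True]) -> Optional[tuple[list[int], ...]]:
        ...

    def ids(self, allow_none: bool = False):
        if allow_none and all(len(g) == 0 for g in self.groups):
            return None
        return self.groups


assert Blocks(([], [])).ids() == ([], [])
assert Blocks(([], [])).ids(allow_none=True) is None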
@@ -348,10 +367,13 @@ class KVCacheManager:
         """
         return self.block_pool.take_events()
 
+    def get_blocks(self, request_id: str) -> KVCacheBlocks:
+        """Get the blocks of a request."""
+        return KVCacheBlocks(self.coordinator.get_blocks(request_id))
+
     def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
         """Get the block ids of a request."""
-        return KVCacheBlocks(
-            self.coordinator.get_blocks(request_id)).get_block_ids()
+        return self.get_blocks(request_id).get_block_ids()
 
     def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
        """Cache the blocks for the request, if enabled."""
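The new get_blocks accessor returns the KVCacheBlocks wrapper instead of eagerly converting to ids, and get_block_ids becomes a one-line delegation. The point is that callers can hold the wrapper and pick the conversion mode later: strict for new requests, allow_none=True for cached ones. A minimal sketch of that delegation shape (stand-in classes, not the real ones):

class FakeBlocks:
    """Stand-in for KVCacheBlocks holding per-group id lists."""

    def __init__(self, groups: tuple[list[int], ...]) -> None:
        self.groups = groups

    def get_block_ids(self, allow_none: bool = False):
        if allow_none and all(not g for g in self.groups):
            return None
        return self.groups


class FakeManager:
    """Stand-in for KVCacheManager after this refactor."""

    def __init__(self, blocks: dict[str, FakeBlocks]) -> None:
        self._blocks = blocks

    def get_blocks(self, request_id: str) -> FakeBlocks:
        return self._blocks[request_id]

    def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
        # Thin delegation, mirroring the hunk above.
        return self.get_blocks(request_id).get_block_ids()


mgr = FakeManager({"req-0": FakeBlocks(([3, 7], []))})
assert mgr.get_block_ids("req-0") == ([3, 7], [])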
@@ -91,7 +91,7 @@ class CachedRequestData:
     # NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
     # When PP is not used, new_token_ids will be empty.
     new_token_ids: list[list[int]]
-    new_block_ids: list[tuple[list[int], ...]]
+    new_block_ids: list[Optional[tuple[list[int], ...]]]
     num_computed_tokens: list[int]
 
     @property
@@ -19,7 +19,7 @@ from vllm.logger import init_logger
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
                                                 compute_encoder_budget)
-from vllm.v1.core.kv_cache_manager import KVCacheManager
+from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
 from vllm.v1.core.sched.interface import SchedulerInterface
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                        SchedulerOutput)
@@ -185,7 +185,7 @@ class Scheduler(SchedulerInterface):
         # uses structured decoding.
         structured_output_request_ids: dict[str, int] = {}
 
-        req_to_new_block_ids: dict[str, tuple[list[int], ...]] = {}
+        req_to_new_blocks: dict[str, KVCacheBlocks] = {}
         num_scheduled_tokens: dict[str, int] = {}
         token_budget = self.max_num_scheduled_tokens
         # Encoder-related.
@@ -288,8 +288,7 @@ class Scheduler(SchedulerInterface):
                 # Therefore, we might introduce some additional
                 # cycle to fill in the bitmask, which could be a big no-op.
                 structured_output_request_ids[request.request_id] = req_index
-            req_to_new_block_ids[request.request_id] = (
-                new_blocks.get_block_ids())
+            req_to_new_blocks[request.request_id] = new_blocks
             num_scheduled_tokens[request.request_id] = num_new_tokens
             token_budget -= num_new_tokens
             req_index += 1
@@ -496,8 +495,8 @@ class Scheduler(SchedulerInterface):
 
                 if self.lora_config and request.lora_request:
                     scheduled_loras.add(request.lora_request.lora_int_id)
-                req_to_new_block_ids[request.request_id] = (
-                    self.kv_cache_manager.get_block_ids(request.request_id))
+                req_to_new_blocks[request.request_id] = (
+                    self.kv_cache_manager.get_blocks(request.request_id))
                 num_scheduled_tokens[request.request_id] = num_new_tokens
                 token_budget -= num_new_tokens
                 request.status = RequestStatus.RUNNING
@@ -546,8 +545,8 @@ class Scheduler(SchedulerInterface):
         )
         # Construct the scheduler output.
         new_reqs_data = [
-            NewRequestData.from_request(req,
-                                        req_to_new_block_ids[req.request_id])
+            NewRequestData.from_request(
+                req, req_to_new_blocks[req.request_id].get_block_ids())
             for req in scheduled_new_reqs
         ]
         cached_reqs_data = self._make_cached_request_data(
@@ -555,7 +554,7 @@ class Scheduler(SchedulerInterface):
             scheduled_resumed_reqs,
             num_scheduled_tokens,
             scheduled_spec_decode_tokens,
-            req_to_new_block_ids,
+            req_to_new_blocks,
         )
         scheduler_output = SchedulerOutput(
             scheduled_new_reqs=new_reqs_data,
@@ -628,11 +627,11 @@ class Scheduler(SchedulerInterface):
         resumed_reqs: list[Request],
         num_scheduled_tokens: dict[str, int],
         spec_decode_tokens: dict[str, list[int]],
-        req_to_new_block_ids: dict[str, tuple[list[int], ...]],
+        req_to_new_blocks: dict[str, KVCacheBlocks],
     ) -> CachedRequestData:
         req_ids: list[str] = []
         new_token_ids: list[list[int]] = []
-        new_block_ids: list[tuple[list[int], ...]] = []
+        new_block_ids: list[Optional[tuple[list[int], ...]]] = []
         num_computed_tokens: list[int] = []
 
         use_connector = self.connector is not None
@@ -655,7 +654,8 @@ class Scheduler(SchedulerInterface):
             # out of bounds errors. TODO: Remove this once the KVConnector
             # is updated to handle token IDs properly.
             new_token_ids.append([])
-            new_block_ids.append(req_to_new_block_ids[req_id])
+            new_block_ids.append(
+                req_to_new_blocks[req_id].get_block_ids(allow_none=True))
             num_computed_tokens.append(req.num_computed_tokens)
         # Because resumed_reqs is usually empty, it is more efficient to do
         # in-place appending so that we don't need to allocate a new list.
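This call site is where the None entries actually get produced: allow_none=True collapses the all-groups-empty case (the typical decode step) into None before it is appended to CachedRequestData.new_block_ids. The rule, reduced to a runnable form (helper name invented):

from typing import Optional


def collapse_empty(
    per_group_ids: tuple[list[int], ...],
) -> Optional[tuple[list[int], ...]]:
    # Same rule as get_block_ids(allow_none=True): every group empty -> None,
    # otherwise pass the ids through unchanged.
    if all(len(group) == 0 for group in per_group_ids):
        return None
    return per_group_ids


assert collapse_empty(([], [])) is None
assert collapse_empty(([3, 7], [])) == ([3, 7], [])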
@@ -574,11 +574,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         # Update the block IDs.
         if not resumed_from_preemption:
-            # Append the new blocks to the existing block IDs.
-            for block_ids, new_ids in zip(req_state.block_ids,
-                                          new_block_ids):
-                block_ids.extend(new_ids)
+            if new_block_ids is not None:
+                # Append the new blocks to the existing block IDs.
+                for block_ids, new_ids in zip(req_state.block_ids,
+                                              new_block_ids):
+                    block_ids.extend(new_ids)
         else:
+            assert new_block_ids is not None
             # The request is resumed from preemption.
             # Replace the existing block IDs with the new ones.
             req_state.block_ids = new_block_ids
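On the consumer side, None simply means "nothing to append", while the preemption-resume path asserts it never sees None, since a resumed request must carry a full replacement block list. A toy version of the runner-side update logic in this hunk:

from typing import Optional


def apply_new_block_ids(
    existing: tuple[list[int], ...],
    new_block_ids: Optional[tuple[list[int], ...]],
    resumed_from_preemption: bool,
) -> tuple[list[int], ...]:
    if not resumed_from_preemption:
        if new_block_ids is not None:
            # Append the new blocks to the existing per-group id lists.
            for block_ids, new_ids in zip(existing, new_block_ids):
                block_ids.extend(new_ids)
        return existing
    # A resumed request always ships a full (non-None) block list.
    assert new_block_ids is not None
    return new_block_ids


state = ([1, 2], [9])
assert apply_new_block_ids(state, None, False) is state  # no-op decode step
apply_new_block_ids(state, ([3], []), False)
assert state == ([1, 2, 3], [9])                         # ids appended in place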
@@ -594,7 +596,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Update the persistent batch.
         self.input_batch.num_computed_tokens_cpu[req_index] = (
             num_computed_tokens)
-        self.input_batch.block_table.append_row(new_block_ids, req_index)
+        if new_block_ids is not None:
+            self.input_batch.block_table.append_row(
+                new_block_ids, req_index)
 
         # For the last rank, we don't need to update the token_ids_cpu
         # because the sampled tokens are already cached.
@@ -418,11 +418,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Update the cached states.
         req_state.num_computed_tokens = num_computed_tokens
         if not resumed_from_preemption:
-            # Append the new blocks to the existing block IDs.
-            for block_ids, new_ids in zip(req_state.block_ids,
-                                          new_block_ids):
-                block_ids.extend(new_ids)
+            if new_block_ids is not None:
+                # Append the new blocks to the existing block IDs.
+                for block_ids, new_ids in zip(req_state.block_ids,
+                                              new_block_ids):
+                    block_ids.extend(new_ids)
         else:
+            assert new_block_ids is not None
             # The request is resumed from preemption.
             # Replace the existing block IDs with the new ones.
             req_state.block_ids = new_block_ids
@@ -438,7 +440,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Update the persistent batch.
         self.input_batch.num_computed_tokens_cpu[req_index] = (
             num_computed_tokens)
-        self.input_batch.block_table.append_row(new_block_ids, req_index)
+        if new_block_ids is not None:
+            self.input_batch.block_table.append_row(
+                new_block_ids, req_index)
 
         # Add the new or resumed requests to the persistent batch.
         # The smaller empty indices are filled first.