[Bugfix][Core] Prefix caching causes incorrect outputs due to outdated ComputedBlocksTracker (#18957)
Signed-off-by: 刘全 <quan.liu2@dbappsecurity.com.cn>
Co-authored-by: 刘全 <quan.liu2@dbappsecurity.com.cn>
parent c6703d1e0d
commit 92183b41f3
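
The underlying issue: with prefix caching enabled, the ComputedBlocksTracker keeps per-sequence state (_seq_id_to_blocks_hashes and _seq_id_to_num_tokens_computed). Several scheduling branches defer, preempt, or ignore a sequence group without clearing that state, so a later scheduling pass can pick up outdated entries and produce incorrect outputs. The fix adds a remove_seq_from_computed_blocks_tracker() hook and calls it in every such branch. A minimal sketch of the idea (illustrative only, not the actual vLLM class; the attribute and method names follow the diff below):

# Simplified sketch of the tracker state this commit cleans up
# (hypothetical class; vLLM's real ComputedBlocksTracker holds more).
class ComputedBlocksTrackerSketch:
    def __init__(self) -> None:
        # seq_id -> block hashes recorded for that sequence
        self._seq_id_to_blocks_hashes: dict[int, list[int]] = {}
        # seq_id -> number of prompt tokens already counted as computed
        self._seq_id_to_num_tokens_computed: dict[int, int] = {}

    def remove_seq(self, seq_id: int) -> None:
        # Drop both cached views so the next scheduling pass recomputes
        # them instead of reusing stale values for a deferred sequence.
        self._seq_id_to_blocks_hashes.pop(seq_id, None)
        self._seq_id_to_num_tokens_computed.pop(seq_id, None)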
@@ -1041,3 +1041,297 @@ def test_no_batches_mixed_with_prompt_tokens_and_prompt_embeds():
        for seq in scheduled_seq_group.seq_group.seqs:
            seq.status = SequenceStatus.FINISHED_STOPPED
        scheduler.free_finished_seq_groups()


def test_remove_seq_from_computed_blocks_tracker():
    """
    Test that computed_blocks_tracker correctly removes stale sequences
    during scheduling.

    The test covers 9 scheduling branches where stale sequences are removed:
    - 1 in _schedule_swapped
    - 1 in _schedule_priority_preemption
    - 7 in _schedule_prefills

    Each branch is tested to ensure proper cleanup of
    _seq_id_to_num_tokens_computed.
    """
    # Budget cannot schedule in swapped
    block_size = 2
    max_seq_group = 3
    seq_tokens_with_swapped: list[list[int]] = []
    blocks_to_swap_out: list[tuple[int, int]] = []
    curr_loras: set[int] = set()

    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=64,
        num_gpu_blocks=16,
        max_num_seqs=max_seq_group,
        enable_prefix_caching=True,
    )
    budget = create_token_budget(token_budget=15)

    seq_length = 16
    num_seqs = 3
    for i in range(num_seqs):
        seq_tokens_with_swapped.append([i] * seq_length)

    seq_and_seq_groups = [
        create_dummy_prompt(f"{i}",
                            prompt_tokens=seq_tokens_with_swapped[i],
                            block_size=block_size)
        for i in range(len(seq_tokens_with_swapped))
    ]

    for _, seq_group in seq_and_seq_groups:
        scheduler._allocate_and_set_running(seq_group)
        scheduler._swap_out(seq_group, blocks_to_swap_out)
        scheduler._add_seq_group_to_swapped(seq_group)

    scheduler._schedule_swapped(budget, curr_loras)
    seq_id_to_num_tokens_computed = (
        scheduler.block_manager._computed_blocks_tracker.
        _seq_id_to_num_tokens_computed.get(1))
    assert seq_id_to_num_tokens_computed is None

    # Prefill schedule doesn't have space for another LoRA, so
    # we ignore this request for now.
    block_size = 4
    lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
    scheduler = initialize_scheduler(lora_config=lora_config,
                                     block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64,
                                     enable_prefix_caching=True)
    budget = create_token_budget(token_budget=120)
    num_seqs = 2
    for i in range(num_seqs):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=seq_length,
                                           block_size=block_size,
                                           lora_request=LoRARequest(
                                               lora_name=str(i),
                                               lora_int_id=i + 1,
                                               lora_path="abc"))
        scheduler.add_seq_group(seq_group)

    scheduler._schedule_prefills(budget, curr_loras)
    seq_id_to_num_tokens_computed = (
        scheduler.block_manager._computed_blocks_tracker.
        _seq_id_to_num_tokens_computed.get(1))
    assert seq_id_to_num_tokens_computed is None

    # Priority preemption schedule
    scheduler._schedule_priority_preemption(budget)
    seq_id_to_num_tokens_computed = (
        scheduler.block_manager._computed_blocks_tracker.
        _seq_id_to_num_tokens_computed.get(1))
    assert seq_id_to_num_tokens_computed is None

    # Prefill scheduler does not schedule batches with prompt tokens and
    # prompt embeddings co-mingled.
    block_size = 2
    max_seq_group = 3
    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=16,
        num_gpu_blocks=16,
        max_num_seqs=max_seq_group,
        max_model_len=100,
        enable_prefix_caching=True,
    )
    seq_length = 7
    embedding_size = 5
    seq_tokens_with_embedding: list[list[int]] = []
    seq_embeds: list[Optional[torch.Tensor]] = []

    seq_tokens_with_embedding.append(list(range(seq_length)))
    seq_embeds.append(None)
    seq_tokens_with_embedding.append([0] * seq_length)
    seq_embeds.append(torch.rand(embedding_size))

    seq_and_seq_groups = [
        create_dummy_prompt(f"{i}",
                            prompt_tokens=seq_tokens_with_embedding[i],
                            prompt_embeds=seq_embeds[i],
                            block_size=block_size)
        for i in range(len(seq_tokens_with_embedding))
    ]

    for _, seq_group in seq_and_seq_groups:
        scheduler.add_seq_group(seq_group)

    scheduler._schedule_default()
    seq_id_to_num_tokens_computed = (
        scheduler.block_manager._computed_blocks_tracker.
        _seq_id_to_num_tokens_computed.get(1))
    assert seq_id_to_num_tokens_computed is None

    # Prefill scheduler budget num_batched_tokens
    # >= scheduler_config max_num_batched_tokens
    block_size = 2
    max_seq_group = 3
    seq_tokens_prefill_budget: list[list[int]] = []

    scheduler = initialize_scheduler(
        block_size=block_size,
        max_token_budget=8,
        num_cpu_blocks=16,
        num_gpu_blocks=16,
        max_num_seqs=max_seq_group,
        max_model_len=5,
        enable_prefix_caching=True,
    )
    seq_length = 4
    num_seqs = 3
    for i in range(num_seqs):
        seq_tokens_prefill_budget.append([i] * seq_length)

    seq_and_seq_groups = [
        create_dummy_prompt(f"{i}",
                            prompt_tokens=seq_tokens_prefill_budget[i],
                            block_size=block_size)
        for i in range(len(seq_tokens_prefill_budget))
    ]

    for _, seq_group in seq_and_seq_groups:
        scheduler.add_seq_group(seq_group)

    scheduler._schedule_default()
    seq_id_to_num_tokens_computed = (
        scheduler.block_manager._computed_blocks_tracker.
        _seq_id_to_num_tokens_computed.get(2))
    assert seq_id_to_num_tokens_computed is None

    # Budget cannot schedule in waiting
    block_size = 2
    max_seq_group = 3

    scheduler = initialize_scheduler(
        block_size=block_size,
        max_token_budget=30,
        num_cpu_blocks=16,
        num_gpu_blocks=16,
        max_num_seqs=max_seq_group,
        max_model_len=30,
        enable_prefix_caching=True,
    )
    seq_length = 16
    num_seqs = 3
    seq_tokens_prefill_budget_waiting: list[list[int]] = []

    for i in range(num_seqs):
        seq_tokens_prefill_budget_waiting.append(list(range(seq_length)))

    seq_and_seq_groups = [
        create_dummy_prompt(f"{i}",
                            prompt_tokens=seq_tokens_prefill_budget_waiting[i],
                            block_size=block_size)
        for i in range(len(seq_tokens_prefill_budget_waiting))
    ]

    for _, seq_group in seq_and_seq_groups:
        scheduler.add_seq_group(seq_group)

    scheduler._schedule_default()
    seq_id_to_num_tokens_computed = (
        scheduler.block_manager._computed_blocks_tracker.
        _seq_id_to_num_tokens_computed.get(1))
    assert seq_id_to_num_tokens_computed is None

    # Sequence with num_new_tokens > prompt_limit is marked FINISHED_IGNORED
    block_size = 2
    max_seq_group = 3
    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=16,
        num_gpu_blocks=16,
        max_num_seqs=max_seq_group,
        max_model_len=30,
        enable_prefix_caching=True,
    )

    seq_length = 31
    seq_tokens_prompt_limit: list[list[int]] = []
    seq_tokens_prompt_limit.append(list(range(seq_length)))
    seq_and_seq_groups = [
        create_dummy_prompt("0",
                            prompt_tokens=seq_tokens_prompt_limit[0],
                            block_size=block_size)
    ]
    for _, seq_group in seq_and_seq_groups:
        scheduler.add_seq_group(seq_group)
    scheduler._schedule_default()
    seq_id_to_num_tokens_computed = (
        scheduler.block_manager._computed_blocks_tracker.
        _seq_id_to_num_tokens_computed.get(0))
    assert seq_id_to_num_tokens_computed is None

    # Cannot allocate (AllocStatus is NEVER); seq is marked FINISHED_IGNORED
    block_size = 2
    max_seq_group = 3
    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=160,
        num_gpu_blocks=160,
        max_num_seqs=max_seq_group,
        max_model_len=320,
        enable_prefix_caching=True,
    )

    seq_length = 320
    num_seqs = 1
    seq_tokens_never: list[list[int]] = []
    for i in range(num_seqs):
        seq_tokens_never.append(list(range(seq_length)))

    seq_and_seq_groups = [
        create_dummy_prompt(f"{i}",
                            prompt_tokens=seq_tokens_never[i],
                            block_size=block_size)
        for i in range(len(seq_tokens_never))
    ]

    for _, seq_group in seq_and_seq_groups:
        scheduler.add_seq_group(seq_group)

    scheduler._schedule_default()
    seq_id_to_num_tokens_computed = (
        scheduler.block_manager._computed_blocks_tracker.
        _seq_id_to_num_tokens_computed.get(0))
    assert seq_id_to_num_tokens_computed is None

    # Cannot allocate yet (AllocStatus is LATER)
    block_size = 2
    max_seq_group = 3
    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=160,
        num_gpu_blocks=160,
        max_num_seqs=max_seq_group,
        max_model_len=320,
        enable_prefix_caching=True,
    )

    seq_length = 160
    num_seqs = 2
    seq_tokens_later: list[list[int]] = []
    for i in range(num_seqs):
        seq_tokens_later.append(list(range(seq_length)))

    seq_and_seq_groups = [
        create_dummy_prompt(f"{i}",
                            prompt_tokens=seq_tokens_later[i],
                            block_size=block_size)
        for i in range(len(seq_tokens_later))
    ]

    for _, seq_group in seq_and_seq_groups:
        scheduler.add_seq_group(seq_group)

    scheduler._schedule_default()
    seq_id_to_num_tokens_computed = (
        scheduler.block_manager._computed_blocks_tracker.
        _seq_id_to_num_tokens_computed.get(1))
    assert seq_id_to_num_tokens_computed is None

@@ -270,6 +270,10 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
        self.block_tables[seq_id].free()
        del self.block_tables[seq_id]

    def remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None:
        seq_id = seq.seq_id
        self._computed_blocks_tracker.remove_seq(seq_id)

    def free_cross(self, seq_group: SequenceGroup) -> None:
        request_id = seq_group.request_id
        if request_id not in self.cross_block_tables:
@@ -901,6 +901,8 @@ class Scheduler:
                    num_new_tokens=num_new_tokens_uncached,
                    num_new_seqs=num_new_seqs,
            ):
                self.remove_seq_from_computed_blocks_tracker(
                    seq_group, SequenceStatus.SWAPPED)
                break

            if lora_int_id > 0 and curr_loras is not None:
@@ -1024,6 +1026,9 @@ class Scheduler:
            # Put the sequence back into the waiting queue
            waiting_queue.appendleft(seq_group)

            self.remove_seq_from_computed_blocks_tracker(
                seq_group, SequenceStatus.WAITING)

        waiting_queue = deque(sorted(waiting_queue, key=self._get_priority))

        self.waiting = waiting_queue
@@ -1113,6 +1118,8 @@ class Scheduler:
                )
                for seq in waiting_seqs:
                    seq.status = SequenceStatus.FINISHED_IGNORED
                self.remove_seq_from_computed_blocks_tracker(
                    seq_group, SequenceStatus.FINISHED_IGNORED)
                ignored_seq_groups.append(seq_group)
                waiting_queue.popleft()
                continue
@@ -1126,6 +1133,8 @@ class Scheduler:
            can_allocate = self.block_manager.can_allocate(
                seq_group, num_lookahead_slots=num_lookahead_slots)
            if can_allocate == AllocStatus.LATER:
                self.remove_seq_from_computed_blocks_tracker(
                    seq_group, SequenceStatus.WAITING)
                break
            elif can_allocate == AllocStatus.NEVER:
                logger.warning(
@@ -1136,6 +1145,8 @@ class Scheduler:
                )
                for seq in waiting_seqs:
                    seq.status = SequenceStatus.FINISHED_IGNORED
                self.remove_seq_from_computed_blocks_tracker(
                    seq_group, SequenceStatus.FINISHED_IGNORED)
                ignored_seq_groups.append(seq_group)
                waiting_queue.popleft()
                continue
@@ -1145,6 +1156,8 @@ class Scheduler:
            if len(seq_groups) == 0:
                using_prompt_embeds = seq_group.uses_prompt_embeds()
            if using_prompt_embeds != seq_group.uses_prompt_embeds():
                self.remove_seq_from_computed_blocks_tracker(
                    seq_group, SequenceStatus.WAITING)
                leftover_waiting_sequences.appendleft(seq_group)
                waiting_queue.popleft()
                continue
@@ -1159,6 +1172,8 @@ class Scheduler:
                    and len(curr_loras) >= self.lora_config.max_loras):
                # We don't have a space for another LoRA, so
                # we ignore this request for now.
                self.remove_seq_from_computed_blocks_tracker(
                    seq_group, SequenceStatus.WAITING)
                leftover_waiting_sequences.appendleft(seq_group)
                waiting_queue.popleft()
                continue
@@ -1168,6 +1183,8 @@ class Scheduler:
                # We've reached the budget limit - since there might be
                # continuous prefills in the running queue, we should break
                # to avoid scheduling any new prefills.
                self.remove_seq_from_computed_blocks_tracker(
                    seq_group, SequenceStatus.WAITING)
                break

            num_new_seqs = seq_group.get_max_num_running_seqs()
@@ -1175,6 +1192,8 @@ class Scheduler:
                    num_new_tokens=num_new_tokens_uncached,
                    num_new_seqs=num_new_seqs,
            ):
                self.remove_seq_from_computed_blocks_tracker(
                    seq_group, SequenceStatus.WAITING)
                break

            # Can schedule this request.
@@ -1688,6 +1707,20 @@ class Scheduler:
        """Free a sequence from a block table."""
        self.block_manager.free(seq)

    def remove_seq_from_computed_blocks_tracker(
            self, seq_group: SequenceGroup,
            status: Optional[SequenceStatus]) -> None:
        seqs = seq_group.get_seqs(status=status)
        for seq in seqs:
            self._remove_seq_from_computed_blocks_tracker(seq)

    def _remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None:
        """
        Remove a sequence's entries from the computed blocks tracker's
        _seq_id_to_blocks_hashes and _seq_id_to_num_tokens_computed.
        """
        self.block_manager.remove_seq_from_computed_blocks_tracker(seq)

    def _free_finished_seqs(self, seq_group: SequenceGroup) -> None:
        """Free finished seqs in a sequence group."""
        for seq in seq_group.get_seqs():
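
For reference, the cleanup added here is a three-level delegation: Scheduler.remove_seq_from_computed_blocks_tracker() calls the block manager's method of the same name, which calls ComputedBlocksTracker.remove_seq(). A hedged usage sketch (assumes a scheduler and seq_group built with the test helpers above; the private attribute access mirrors the new test's assertions):

# After a branch defers or ignores `seq_group`, its cached token counts
# should no longer be present in the tracker.
scheduler.remove_seq_from_computed_blocks_tracker(seq_group,
                                                  SequenceStatus.WAITING)
tracker = scheduler.block_manager._computed_blocks_tracker
for seq in seq_group.get_seqs():
    assert tracker._seq_id_to_num_tokens_computed.get(seq.seq_id) is None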