[Bugfix][Core] Prefix caching causes incorrect outputs due to outdated ComputedBlocksTracker (#18957)

Signed-off-by: 刘全 <quan.liu2@dbappsecurity.com.cn>
Co-authored-by: 刘全 <quan.liu2@dbappsecurity.com.cn>
quanliu 2025-06-16 12:56:37 +08:00 committed by GitHub
parent c6703d1e0d
commit 92183b41f3
3 changed files with 331 additions and 0 deletions
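Before the file diffs, a condensed sketch of the fix, distilled from the hunks below: every scheduling branch that decides not to schedule a sequence group (swapped budget exhausted, priority preemption, and the seven skip paths in prefill scheduling) now removes that group's sequences from the ComputedBlocksTracker, so a stale cached count of computed (prefix-cached) tokens cannot be reused on a later pass. The stubs below only illustrate the delegation chain; the real signatures are in the diffs that follow, and the tracker internals are sketched after the block-manager hunk.

# Illustrative stubs only (not the vLLM source): the delegation chain added by
# this commit, Scheduler -> BlockSpaceManager -> ComputedBlocksTracker.
class _TrackerStub:
    def remove_seq(self, seq_id: int) -> None:
        # Real internals are sketched after the block-manager hunk below.
        ...

class _BlockManagerStub:
    def __init__(self) -> None:
        self._computed_blocks_tracker = _TrackerStub()

    def remove_seq_from_computed_blocks_tracker(self, seq) -> None:
        self._computed_blocks_tracker.remove_seq(seq.seq_id)

class _SchedulerStub:
    def __init__(self) -> None:
        self.block_manager = _BlockManagerStub()

    def remove_seq_from_computed_blocks_tracker(self, seq_group, status) -> None:
        # Called on every branch that skips seq_group during scheduling.
        for seq in seq_group.get_seqs(status=status):
            self.block_manager.remove_seq_from_computed_blocks_tracker(seq)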


@@ -1041,3 +1041,297 @@ def test_no_batches_mixed_with_prompt_tokens_and_prompt_embeds():
for seq in scheduled_seq_group.seq_group.seqs:
seq.status = SequenceStatus.FINISHED_STOPPED
scheduler.free_finished_seq_groups()
def test_remove_seq_from_computed_blocks_tracker():
"""
Test that stale sequences are removed from the ComputedBlocksTracker
during scheduling.
The test covers 9 scheduling branches where stale seqs are removed:
- 1 in _schedule_swapped
- 1 in _schedule_priority_preemption
- 7 in _schedule_prefills
Each branch is tested to ensure proper cleanup of
_seq_id_to_num_tokens_computed.
"""
# Budget cannot schedule in swapped
block_size = 2
max_seq_group = 3
seq_tokens_with_swapped: list[list[int]] = []
blocks_to_swap_out: list[tuple[int, int]] = []
curr_loras: set[int] = set()
scheduler = initialize_scheduler(
block_size=block_size,
num_cpu_blocks=64,
num_gpu_blocks=16,
max_num_seqs=max_seq_group,
enable_prefix_caching=True,
)
budget = create_token_budget(token_budget=15)
seq_length = 16
num_seqs = 3
for i in range(num_seqs):
seq_tokens_with_swapped.append([i] * seq_length)
seq_and_seq_groups = [
create_dummy_prompt(f"{i}",
prompt_tokens=seq_tokens_with_swapped[i],
block_size=block_size)
for i in range(len(seq_tokens_with_swapped))
]
for _, seq_group in seq_and_seq_groups:
scheduler._allocate_and_set_running(seq_group)
scheduler._swap_out(seq_group, blocks_to_swap_out)
scheduler._add_seq_group_to_swapped(seq_group)
scheduler._schedule_swapped(budget, curr_loras)
seq_id_to_num_tokens_computed = (
scheduler.block_manager._computed_blocks_tracker.
_seq_id_to_num_tokens_computed.get(1))
assert seq_id_to_num_tokens_computed is None
# Prefill schedule doesn't have space for another LoRA, so
# we ignore this request for now.
block_size = 4
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
scheduler = initialize_scheduler(lora_config=lora_config,
block_size=block_size,
num_cpu_blocks=64,
num_gpu_blocks=64,
enable_prefix_caching=True)
budget = create_token_budget(token_budget=120)
num_seqs = 2
for i in range(num_seqs):
_, seq_group = create_dummy_prompt(str(i),
prompt_length=seq_length,
block_size=block_size,
lora_request=LoRARequest(
lora_name=str(i),
lora_int_id=i + 1,
lora_path="abc"))
scheduler.add_seq_group(seq_group)
scheduler._schedule_prefills(budget, curr_loras)
seq_id_to_num_tokens_computed = (
scheduler.block_manager._computed_blocks_tracker.
_seq_id_to_num_tokens_computed.get(1))
assert seq_id_to_num_tokens_computed is None
# Priority preemption schedule
scheduler._schedule_priority_preemption(budget)
seq_id_to_num_tokens_computed = (
scheduler.block_manager._computed_blocks_tracker.
_seq_id_to_num_tokens_computed.get(1))
assert seq_id_to_num_tokens_computed is None
# Prefill scheduler does not schedule batches with prompt tokens and
# prompt embeddings co-mingled.
block_size = 2
max_seq_group = 3
scheduler = initialize_scheduler(
block_size=block_size,
num_cpu_blocks=16,
num_gpu_blocks=16,
max_num_seqs=max_seq_group,
max_model_len=100,
enable_prefix_caching=True,
)
seq_length = 7
embedding_size = 5
seq_tokens_with_embedding: list[list[int]] = []
seq_embeds: list[Optional[torch.Tensor]] = []
seq_tokens_with_embedding.append(list(range(seq_length)))
seq_embeds.append(None)
seq_tokens_with_embedding.append([0] * seq_length)
seq_embeds.append(torch.rand(embedding_size))
seq_and_seq_groups = [
create_dummy_prompt(f"{i}",
prompt_tokens=seq_tokens_with_embedding[i],
prompt_embeds=seq_embeds[i],
block_size=block_size)
for i in range(len(seq_tokens_with_embedding))
]
for _, seq_group in seq_and_seq_groups:
scheduler.add_seq_group(seq_group)
scheduler._schedule_default()
seq_id_to_num_tokens_computed = (
scheduler.block_manager._computed_blocks_tracker.
_seq_id_to_num_tokens_computed.get(1))
assert seq_id_to_num_tokens_computed is None
# Prefill schedule: budget num_batched_tokens
# >= scheduler_config max_num_batched_tokens
block_size = 2
max_seq_group = 3
seq_tokens_prefill_budget: list[list[int]] = []
scheduler = initialize_scheduler(
block_size=block_size,
max_token_budget=8,
num_cpu_blocks=16,
num_gpu_blocks=16,
max_num_seqs=max_seq_group,
max_model_len=5,
enable_prefix_caching=True,
)
seq_length = 4
num_seqs = 3
for i in range(num_seqs):
seq_tokens_prefill_budget.append([i] * seq_length)
seq_and_seq_groups = [
create_dummy_prompt(f"{i}",
prompt_tokens=seq_tokens_prefill_budget[i],
block_size=block_size)
for i in range(len(seq_tokens_prefill_budget))
]
for _, seq_group in seq_and_seq_groups:
scheduler.add_seq_group(seq_group)
scheduler._schedule_default()
seq_id_to_num_tokens_computed = (
scheduler.block_manager._computed_blocks_tracker.
_seq_id_to_num_tokens_computed.get(2))
assert seq_id_to_num_tokens_computed is None
# Budget cannot schedule in waiting
block_size = 2
max_seq_group = 3
scheduler = initialize_scheduler(
block_size=block_size,
max_token_budget=30,
num_cpu_blocks=16,
num_gpu_blocks=16,
max_num_seqs=max_seq_group,
max_model_len=30,
enable_prefix_caching=True,
)
seq_length = 16
num_seqs = 3
seq_tokens_prefill_budget_waiting: list[list[int]] = []
for i in range(num_seqs):
seq_tokens_prefill_budget_waiting.append(list(range(seq_length)))
seq_and_seq_groups = [
create_dummy_prompt(f"{i}",
prompt_tokens=seq_tokens_prefill_budget_waiting[i],
block_size=block_size)
for i in range(len(seq_tokens_prefill_budget_waiting))
]
for _, seq_group in seq_and_seq_groups:
scheduler.add_seq_group(seq_group)
scheduler._schedule_default()
seq_id_to_num_tokens_computed = (
scheduler.block_manager._computed_blocks_tracker.
_seq_id_to_num_tokens_computed.get(1))
assert seq_id_to_num_tokens_computed is None
# Sequence with num_new_tokens > prompt_limit is marked FINISHED_IGNORED
block_size = 2
max_seq_group = 3
scheduler = initialize_scheduler(
block_size=block_size,
num_cpu_blocks=16,
num_gpu_blocks=16,
max_num_seqs=max_seq_group,
max_model_len=30,
enable_prefix_caching=True,
)
seq_length = 31
seq_tokens_prompt_limit: list[list[int]] = []
seq_tokens_prompt_limit.append(list(range(seq_length)))
seq_and_seq_groups = [
create_dummy_prompt("0",
prompt_tokens=seq_tokens_prompt_limit[0],
block_size=block_size)
]
for _, seq_group in seq_and_seq_groups:
scheduler.add_seq_group(seq_group)
scheduler._schedule_default()
seq_id_to_num_tokens_computed = (
scheduler.block_manager._computed_blocks_tracker.
_seq_id_to_num_tokens_computed.get(0))
assert seq_id_to_num_tokens_computed is None
# Block manager cannot allocate: AllocStatus NEVER, seq marked FINISHED_IGNORED
block_size = 2
max_seq_group = 3
scheduler = initialize_scheduler(
block_size=block_size,
num_cpu_blocks=160,
num_gpu_blocks=160,
max_num_seqs=max_seq_group,
max_model_len=320,
enable_prefix_caching=True,
)
seq_length = 320
num_seqs = 1
seq_tokens_never: list[list[int]] = []
for i in range(num_seqs):
seq_tokens_never.append(list(range(seq_length)))
seq_and_seq_groups = [
create_dummy_prompt(f"{i}",
prompt_tokens=seq_tokens_never[i],
block_size=block_size)
for i in range(len(seq_tokens_never))
]
for _, seq_group in seq_and_seq_groups:
scheduler.add_seq_group(seq_group)
scheduler._schedule_default()
seq_id_to_num_tokens_computed = (
scheduler.block_manager._computed_blocks_tracker.
_seq_id_to_num_tokens_computed.get(0))
assert seq_id_to_num_tokens_computed is None
# Block manager cannot allocate: AllocStatus is LATER
block_size = 2
max_seq_group = 3
scheduler = initialize_scheduler(
block_size=block_size,
num_cpu_blocks=160,
num_gpu_blocks=160,
max_num_seqs=max_seq_group,
max_model_len=320,
enable_prefix_caching=True,
)
seq_length = 160
num_seqs = 2
seq_tokens_later: list[list[int]] = []
for i in range(num_seqs):
seq_tokens_later.append(list(range(seq_length)))
seq_and_seq_groups = [
create_dummy_prompt(f"{i}",
prompt_tokens=seq_tokens_later[i],
block_size=block_size)
for i in range(len(seq_tokens_later))
]
for _, seq_group in seq_and_seq_groups:
scheduler.add_seq_group(seq_group)
scheduler._schedule_default()
seq_id_to_num_tokens_computed = (
scheduler.block_manager._computed_blocks_tracker.
_seq_id_to_num_tokens_computed.get(1))
assert seq_id_to_num_tokens_computed is None


@@ -270,6 +270,10 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
self.block_tables[seq_id].free()
del self.block_tables[seq_id]
def remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None:
seq_id = seq.seq_id
self._computed_blocks_tracker.remove_seq(seq_id)
def free_cross(self, seq_group: SequenceGroup) -> None:
request_id = seq_group.request_id
if request_id not in self.cross_block_tables:
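ComputedBlocksTracker.remove_seq, called by the new block-manager method above, is not part of this diff. Based on the attribute names used in the test and in the scheduler docstring (_seq_id_to_blocks_hashes and _seq_id_to_num_tokens_computed), it is assumed to drop the per-sequence entries, roughly:

# Assumed behavior of ComputedBlocksTracker.remove_seq (not shown in this diff);
# a minimal sketch based on the attribute names used elsewhere in this commit.
class ComputedBlocksTrackerSketch:
    def __init__(self) -> None:
        self._seq_id_to_blocks_hashes: dict[int, list[int]] = {}
        self._seq_id_to_num_tokens_computed: dict[int, int] = {}

    def remove_seq(self, seq_id: int) -> None:
        # Forget the cached block hashes and computed-token count so the next
        # scheduling pass recomputes them instead of reusing stale values.
        self._seq_id_to_blocks_hashes.pop(seq_id, None)
        self._seq_id_to_num_tokens_computed.pop(seq_id, None)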


@@ -901,6 +901,8 @@ class Scheduler:
num_new_tokens=num_new_tokens_uncached,
num_new_seqs=num_new_seqs,
):
self.remove_seq_from_computed_blocks_tracker(
seq_group, SequenceStatus.SWAPPED)
break
if lora_int_id > 0 and curr_loras is not None:
@@ -1024,6 +1026,9 @@
# Put the sequence back into the waiting queue
waiting_queue.appendleft(seq_group)
self.remove_seq_from_computed_blocks_tracker(
seq_group, SequenceStatus.WAITING)
waiting_queue = deque(sorted(waiting_queue, key=self._get_priority))
self.waiting = waiting_queue
@@ -1113,6 +1118,8 @@
)
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
self.remove_seq_from_computed_blocks_tracker(
seq_group, SequenceStatus.FINISHED_IGNORED)
ignored_seq_groups.append(seq_group)
waiting_queue.popleft()
continue
@@ -1126,6 +1133,8 @@
can_allocate = self.block_manager.can_allocate(
seq_group, num_lookahead_slots=num_lookahead_slots)
if can_allocate == AllocStatus.LATER:
self.remove_seq_from_computed_blocks_tracker(
seq_group, SequenceStatus.WAITING)
break
elif can_allocate == AllocStatus.NEVER:
logger.warning(
@@ -1136,6 +1145,8 @@
)
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
self.remove_seq_from_computed_blocks_tracker(
seq_group, SequenceStatus.FINISHED_IGNORED)
ignored_seq_groups.append(seq_group)
waiting_queue.popleft()
continue
@@ -1145,6 +1156,8 @@
if len(seq_groups) == 0:
using_prompt_embeds = seq_group.uses_prompt_embeds()
if using_prompt_embeds != seq_group.uses_prompt_embeds():
self.remove_seq_from_computed_blocks_tracker(
seq_group, SequenceStatus.WAITING)
leftover_waiting_sequences.appendleft(seq_group)
waiting_queue.popleft()
continue
@@ -1159,6 +1172,8 @@
and len(curr_loras) >= self.lora_config.max_loras):
# We don't have a space for another LoRA, so
# we ignore this request for now.
self.remove_seq_from_computed_blocks_tracker(
seq_group, SequenceStatus.WAITING)
leftover_waiting_sequences.appendleft(seq_group)
waiting_queue.popleft()
continue
@@ -1168,6 +1183,8 @@
# We've reached the budget limit - since there might be
# continuous prefills in the running queue, we should break
# to avoid scheduling any new prefills.
self.remove_seq_from_computed_blocks_tracker(
seq_group, SequenceStatus.WAITING)
break
num_new_seqs = seq_group.get_max_num_running_seqs()
@@ -1175,6 +1192,8 @@
num_new_tokens=num_new_tokens_uncached,
num_new_seqs=num_new_seqs,
):
self.remove_seq_from_computed_blocks_tracker(
seq_group, SequenceStatus.WAITING)
break
# Can schedule this request.
@@ -1688,6 +1707,20 @@
"""Free a sequence from a block table."""
self.block_manager.free(seq)
def remove_seq_from_computed_blocks_tracker(
self, seq_group: SequenceGroup,
status: Optional[SequenceStatus]) -> None:
seqs = seq_group.get_seqs(status=status)
for seq in seqs:
self._remove_seq_from_computed_blocks_tracker(seq)
def _remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None:
"""
Remove a sequence from the computed blocks tracker, freeing its entries
in _seq_id_to_blocks_hashes and _seq_id_to_num_tokens_computed.
"""
self.block_manager.remove_seq_from_computed_blocks_tracker(seq)
def _free_finished_seqs(self, seq_group: SequenceGroup) -> None:
"""Free finished seqs in a sequence group."""
for seq in seq_group.get_seqs():