[Bugfix][Core] fix abort_seq_group and memory leak when n>1 (#14326)

Signed-off-by: courage17340 <courage17340@163.com>
Author: courage17340 · 2025-03-06 23:59:32 +08:00 · committed by GitHub
Commit: caac5c2e59 (parent: 6bd1dd9d26)
2 changed files with 31 additions and 10 deletions
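
Context for the fix (a reconstruction from the code comments in this diff, not part of the commit message): with SamplingParams(n>1), the v0 engine fans a request out into n child sequence groups whose ids are derived from the user-facing id (e.g. foo becomes foo_parallel_sample_0) and records them in LLMEngine.seq_id_to_seq_group. abort_seq_group used to compare those child ids directly against the user-facing id, so aborting an n>1 request matched nothing, and the map entries were never deleted. A minimal sketch of the id fan-out, with SimpleNamespace standing in for vLLM's SequenceGroupBase:

    from types import SimpleNamespace

    parent_id = "foo"   # the id the user submits and later aborts with
    n = 2               # SamplingParams(n=2)
    child_ids = [f"{parent_id}_parallel_sample_{i}" for i in range(n)]
    # seq_id_to_seq_group maps each child id to its group object; group_id
    # recovers the user-facing id, which is what abort_seq_group now resolves.
    seq_id_to_seq_group = {
        cid: SimpleNamespace(group_id=parent_id) for cid in child_ids
    }
    assert seq_id_to_seq_group["foo_parallel_sample_1"].group_id == "foo"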

--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py

@@ -16,8 +16,9 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
-                           SequenceGroupMetadata, SequenceGroupMetadataDelta,
-                           SequenceStage, SequenceStatus)
+                           SequenceGroupBase, SequenceGroupMetadata,
+                           SequenceGroupMetadataDelta, SequenceStage,
+                           SequenceStatus)
 from vllm.utils import Device, PyObjectCache
 
 logger = init_logger(__name__)
@@ -561,7 +562,11 @@ class Scheduler:
         # Only for testing purposes.
         self.swapped.append(seq_group)
 
-    def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
+    def abort_seq_group(
+        self,
+        request_id: Union[str, Iterable[str]],
+        seq_id_to_seq_group: Optional[Dict[str, SequenceGroupBase]] = None,
+    ) -> None:
         """Aborts a sequence group with the given ID.
 
         Check if the sequence group with the given ID
@@ -573,21 +578,29 @@ class Scheduler:
         Args:
             request_id: The ID(s) of the sequence group to abort.
+            seq_id_to_seq_group: helper for groups with n>1
         """
         if isinstance(request_id, str):
             request_id = (request_id, )
         request_ids = set(request_id)
+        seq_id_to_seq_group = seq_id_to_seq_group or {}
         for state_queue in [self.waiting, self.running, self.swapped]:
             aborted_groups: List[SequenceGroup] = []
             for seq_group in state_queue:
-                if not request_ids:
-                    # Using 'break' here may add two extra iterations,
-                    # but is acceptable to reduce complexity.
-                    break
-                if seq_group.request_id in request_ids:
+                # When n>1, seq_group.request_id looks like
+                # foo_parallel_sample_0, while request_ids is just foo, and we
+                # should resolve it as real_request_id to match.
+                if seq_group.request_id in seq_id_to_seq_group:
+                    real_request_id = seq_id_to_seq_group[
+                        seq_group.request_id].group_id
+                else:
+                    real_request_id = seq_group.request_id
+                if real_request_id in request_ids:
                     # Appending aborted group into pending list.
                     aborted_groups.append(seq_group)
-                    request_ids.remove(seq_group.request_id)
+                    # We can't remove real_request_id in request_ids here,
+                    # because there may be other seq groups sharing the same
+                    # real_request_id
 
             for aborted_group in aborted_groups:
                 # Remove the sequence group from the state queue.
                 state_queue.remove(aborted_group)
@@ -598,6 +611,8 @@ class Scheduler:
                     continue
                 seq.status = SequenceStatus.FINISHED_ABORTED
                 self.free_seq(seq)
+                if aborted_group.request_id in seq_id_to_seq_group:
+                    del seq_id_to_seq_group[aborted_group.request_id]
 
                 self._free_seq_group_cross_attn_blocks(aborted_group)
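
To make the new matching rule concrete, here is a self-contained sketch of the resolution logic the hunk above introduces; GroupInfo and matches_abort are hypothetical stand-ins for vLLM's SequenceGroupBase and the body of Scheduler.abort_seq_group:

    from dataclasses import dataclass
    from typing import Dict, Set

    @dataclass
    class GroupInfo:                # hypothetical stand-in for SequenceGroupBase
        group_id: str               # user-facing id shared by all child groups

    def matches_abort(child_request_id: str,
                      request_ids: Set[str],
                      seq_id_to_seq_group: Dict[str, GroupInfo]) -> bool:
        # Resolve foo_parallel_sample_0 -> foo before comparing; n=1 requests
        # are not in the map and keep matching by their own id.
        info = seq_id_to_seq_group.get(child_request_id)
        real_request_id = info.group_id if info is not None else child_request_id
        return real_request_id in request_ids

    mapping = {f"foo_parallel_sample_{i}": GroupInfo("foo") for i in range(2)}
    assert matches_abort("foo_parallel_sample_0", {"foo"}, mapping)
    assert matches_abort("foo_parallel_sample_1", {"foo"}, mapping)
    assert matches_abort("bar", {"bar"}, mapping)   # n=1 path unchanged

Both children resolve to the same real_request_id, which is why the old request_ids.remove(...) call had to go: pruning the set after the first hit would have left foo_parallel_sample_1 running.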

--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py

@@ -887,7 +887,8 @@ class LLMEngine:
         >>> engine.abort_request(request_id)
         """
         for scheduler in self.scheduler:
-            scheduler.abort_seq_group(request_id)
+            scheduler.abort_seq_group(
+                request_id, seq_id_to_seq_group=self.seq_id_to_seq_group)
 
     def get_model_config(self) -> ModelConfig:
         """Gets the model configuration."""
@@ -1354,6 +1355,11 @@ class LLMEngine:
             finished_requests_ids = self.scheduler[
                 virtual_engine].get_and_reset_finished_requests_ids()
 
+            # When n>1, elements in self.seq_id_to_seq_group should be deleted
+            # here, otherwise memory leaks.
+            for finished_request_id in finished_requests_ids:
+                if finished_request_id in self.seq_id_to_seq_group:
+                    del self.seq_id_to_seq_group[finished_request_id]
+
             # Maybe switch from async mode to sync mode
             if not allow_async_output_proc and len(ctx.output_queue) > 0:
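
The scheduler-side deletion only covers aborted groups; groups that finish normally surface through get_and_reset_finished_requests_ids, so the engine step has to drop them from the map itself, as the hunk above does. A minimal sketch of that cleanup pattern, with a plain dict standing in for LLMEngine.seq_id_to_seq_group:

    # Plain dict standing in for LLMEngine.seq_id_to_seq_group; in the real
    # engine, finished_requests_ids comes from
    # scheduler.get_and_reset_finished_requests_ids().
    seq_id_to_seq_group = {
        "foo_parallel_sample_0": object(),
        "foo_parallel_sample_1": object(),
    }
    finished_requests_ids = ["foo_parallel_sample_0", "bar"]  # "bar" was n=1

    for finished_request_id in finished_requests_ids:
        # n=1 requests were never inserted into the map, so guard the delete;
        # without this loop the map only grows and its groups are never freed.
        if finished_request_id in seq_id_to_seq_group:
            del seq_id_to_seq_group[finished_request_id]

    assert "foo_parallel_sample_0" not in seq_id_to_seq_group
    assert "foo_parallel_sample_1" in seq_id_to_seq_group   # still running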