mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-10 06:18:44 +08:00
[Bugfix][Core] fix abort_seq_group and memory leak when n>1 (#14326)
Signed-off-by: courage17340 <courage17340@163.com>
This commit is contained in:
parent
6bd1dd9d26
commit
caac5c2e59
@ -16,8 +16,9 @@ from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
|
||||
SequenceGroupMetadata, SequenceGroupMetadataDelta,
|
||||
SequenceStage, SequenceStatus)
|
||||
SequenceGroupBase, SequenceGroupMetadata,
|
||||
SequenceGroupMetadataDelta, SequenceStage,
|
||||
SequenceStatus)
|
||||
from vllm.utils import Device, PyObjectCache
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -561,7 +562,11 @@ class Scheduler:
|
||||
# Only for testing purposes.
|
||||
self.swapped.append(seq_group)
|
||||
|
||||
def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
|
||||
def abort_seq_group(
|
||||
self,
|
||||
request_id: Union[str, Iterable[str]],
|
||||
seq_id_to_seq_group: Optional[Dict[str, SequenceGroupBase]] = None,
|
||||
) -> None:
|
||||
"""Aborts a sequence group with the given ID.
|
||||
|
||||
Check if the sequence group with the given ID
|
||||
@ -573,21 +578,29 @@ class Scheduler:
|
||||
|
||||
Args:
|
||||
request_id: The ID(s) of the sequence group to abort.
|
||||
seq_id_to_seq_group: optional mapping from a sequence group's request ID to its SequenceGroupBase; used when n>1 to resolve the parent group_id.
|
||||
"""
|
||||
if isinstance(request_id, str):
|
||||
request_id = (request_id, )
|
||||
request_ids = set(request_id)
|
||||
seq_id_to_seq_group = seq_id_to_seq_group or {}
|
||||
for state_queue in [self.waiting, self.running, self.swapped]:
|
||||
aborted_groups: List[SequenceGroup] = []
|
||||
for seq_group in state_queue:
|
||||
if not request_ids:
|
||||
# Using 'break' here may add two extra iterations,
|
||||
# but is acceptable to reduce complexity.
|
||||
break
|
||||
if seq_group.request_id in request_ids:
|
||||
# When n>1, seq_group.request_id looks like
|
||||
# foo_parallel_sample_0, while request_ids is just foo, and we
|
||||
# should resolve it as real_request_id to match.
|
||||
if seq_group.request_id in seq_id_to_seq_group:
|
||||
real_request_id = seq_id_to_seq_group[
|
||||
seq_group.request_id].group_id
|
||||
else:
|
||||
real_request_id = seq_group.request_id
|
||||
if real_request_id in request_ids:
|
||||
# Appending aborted group into pending list.
|
||||
aborted_groups.append(seq_group)
|
||||
request_ids.remove(seq_group.request_id)
|
||||
# We can't remove real_request_id from request_ids here,
|
||||
# because there may be other seq groups sharing the same
|
||||
# real_request_id.
|
||||
for aborted_group in aborted_groups:
|
||||
# Remove the sequence group from the state queue.
|
||||
state_queue.remove(aborted_group)
|
||||
@ -598,6 +611,8 @@ class Scheduler:
|
||||
continue
|
||||
seq.status = SequenceStatus.FINISHED_ABORTED
|
||||
self.free_seq(seq)
|
||||
if aborted_group.request_id in seq_id_to_seq_group:
|
||||
del seq_id_to_seq_group[aborted_group.request_id]
|
||||
|
||||
self._free_seq_group_cross_attn_blocks(aborted_group)
|
||||
|
||||
|
||||
@ -887,7 +887,8 @@ class LLMEngine:
|
||||
>>> engine.abort_request(request_id)
|
||||
"""
|
||||
for scheduler in self.scheduler:
|
||||
scheduler.abort_seq_group(request_id)
|
||||
scheduler.abort_seq_group(
|
||||
request_id, seq_id_to_seq_group=self.seq_id_to_seq_group)
|
||||
|
||||
def get_model_config(self) -> ModelConfig:
|
||||
"""Gets the model configuration."""
|
||||
@ -1354,6 +1355,11 @@ class LLMEngine:
|
||||
|
||||
finished_requests_ids = self.scheduler[
|
||||
virtual_engine].get_and_reset_finished_requests_ids()
|
||||
# When n>1, entries in self.seq_id_to_seq_group must be deleted
|
||||
# here once their requests finish; otherwise they leak memory.
|
||||
for finished_request_id in finished_requests_ids:
|
||||
if finished_request_id in self.seq_id_to_seq_group:
|
||||
del self.seq_id_to_seq_group[finished_request_id]
|
||||
|
||||
# Maybe switch from async mode to sync mode
|
||||
if not allow_async_output_proc and len(ctx.output_queue) > 0:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user