mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 04:44:59 +08:00
[Core] Optimize scheduler request removal for single completions (#21917)
Signed-off-by: chiliu <chiliu@paypal.com> Signed-off-by: chiliu <cliu_whu@yeah.net> Co-authored-by: chiliu <chiliu@paypal.com>
This commit is contained in:
parent
c32e6ad1f6
commit
0167efe20d
@ -25,7 +25,7 @@ from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
|
||||
SchedulerOutput)
|
||||
from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
|
||||
create_request_queue)
|
||||
from vllm.v1.core.sched.utils import check_stop
|
||||
from vllm.v1.core.sched.utils import check_stop, remove_all
|
||||
from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
|
||||
EngineCoreOutputs)
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
@ -872,9 +872,7 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
# Remove the stopped requests from the running and waiting queues.
|
||||
if stopped_running_reqs:
|
||||
self.running = [
|
||||
req for req in self.running if req not in stopped_running_reqs
|
||||
]
|
||||
self.running = remove_all(self.running, stopped_running_reqs)
|
||||
if stopped_preempted_reqs:
|
||||
# This is a rare case and unlikely to impact performance.
|
||||
self.waiting.remove_requests(stopped_preempted_reqs)
|
||||
@ -1000,7 +998,7 @@ class Scheduler(SchedulerInterface):
|
||||
else:
|
||||
request_ids = set(request_ids)
|
||||
|
||||
running_requests_to_remove = []
|
||||
running_requests_to_remove = set()
|
||||
waiting_requests_to_remove = []
|
||||
valid_requests = []
|
||||
|
||||
@ -1013,13 +1011,13 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
valid_requests.append(request)
|
||||
if request.status == RequestStatus.RUNNING:
|
||||
running_requests_to_remove.append(request)
|
||||
running_requests_to_remove.add(request)
|
||||
else:
|
||||
waiting_requests_to_remove.append(request)
|
||||
|
||||
# Remove all requests from queues at once for better efficiency
|
||||
for request in running_requests_to_remove:
|
||||
self.running.remove(request)
|
||||
if running_requests_to_remove:
|
||||
self.running = remove_all(self.running, running_requests_to_remove)
|
||||
if waiting_requests_to_remove:
|
||||
self.waiting.remove_requests(waiting_requests_to_remove)
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@ -7,6 +8,38 @@ import torch
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
|
||||
def remove_all(lst: list, items_to_remove: set) -> list:
|
||||
"""Remove all items from a list that are in the items_to_remove set.
|
||||
|
||||
This method optimizes for the common case of removing a single item,
|
||||
falling back to list comprehension for multiple items.
|
||||
|
||||
Args:
|
||||
lst: The list to remove items from
|
||||
items_to_remove: Set of items to remove
|
||||
|
||||
Returns:
|
||||
Either the modified original list (for single item removal) or
|
||||
a new list (for multiple item removal). Callers should use the
|
||||
returned value.
|
||||
|
||||
Note:
|
||||
For single item removal, this modifies the original list in-place
|
||||
and returns it. For multiple items, it creates and returns a new list.
|
||||
"""
|
||||
if not items_to_remove:
|
||||
return lst
|
||||
|
||||
if len(items_to_remove) == 1:
|
||||
# Fast path for single item removal (most common case)
|
||||
item = next(iter(items_to_remove))
|
||||
with contextlib.suppress(ValueError):
|
||||
lst.remove(item)
|
||||
return lst
|
||||
# For multiple items, use list comprehension
|
||||
return [item for item in lst if item not in items_to_remove]
|
||||
|
||||
|
||||
def check_stop(request: Request,
|
||||
max_model_len: int,
|
||||
pooler_output: Optional[torch.Tensor] = None) -> bool:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user