vllm/vllm/v1/core/sched/utils.py
Commit 0167efe20d by 633WHU, 2025-08-19 18:25:59 -07:00
[Core] Optimize scheduler request removal for single completions (#21917)
Signed-off-by: chiliu <chiliu@paypal.com>
Signed-off-by: chiliu <cliu_whu@yeah.net>
Co-authored-by: chiliu <chiliu@paypal.com>


# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from typing import Optional

import torch

from vllm.v1.request import Request, RequestStatus

def remove_all(lst: list, items_to_remove: set) -> list:
    """Remove all items from a list that are in the items_to_remove set.

    This method optimizes for the common case of removing a single item,
    falling back to list comprehension for multiple items.

    Args:
        lst: The list to remove items from
        items_to_remove: Set of items to remove

    Returns:
        Either the modified original list (for single item removal) or
        a new list (for multiple item removal). Callers should use the
        returned value.

    Note:
        For single item removal, this modifies the original list in-place
        and returns it. For multiple items, it creates and returns a new list.
    """
    if not items_to_remove:
        return lst

    if len(items_to_remove) == 1:
        # Fast path for single item removal (most common case)
        item = next(iter(items_to_remove))
        with contextlib.suppress(ValueError):
            lst.remove(item)
        return lst

    # For multiple items, use list comprehension
    return [item for item in lst if item not in items_to_remove]
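
# Illustrative usage sketch (added commentary, not part of the original
# module): it demonstrates the two return behaviors documented in the
# docstring above. The request IDs are made up for the example.
#
#     running = ["req-1", "req-2", "req-3"]
#     result = remove_all(running, {"req-2"})  # single item: fast path
#     assert result is running                 # same list, mutated in place
#     assert running == ["req-1", "req-3"]
#
#     result = remove_all(running, {"req-1", "req-3"})  # multiple items
#     assert result == []                               # a new list is built
#     assert result is not running
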
def check_stop(request: Request,
               max_model_len: int,
               pooler_output: Optional[torch.Tensor] = None) -> bool:
    """Check whether a request should stop, updating its status if so."""
    # Length cap: the sequence reached the model's context limit, or the
    # request produced its configured maximum number of output tokens.
    if (request.num_tokens >= max_model_len
            or request.num_output_tokens >= request.max_tokens):
        request.status = RequestStatus.FINISHED_LENGTH_CAPPED
        return True

    # Pooling requests finish once the pooler has produced an output;
    # the sampling-based stop conditions below do not apply to them.
    if request.pooling_params:
        if pooler_output is not None:
            request.status = RequestStatus.FINISHED_STOPPED
            return True
        return False

    sampling_params = request.sampling_params
    assert sampling_params is not None
    last_token_id = request.output_token_ids[-1]

    # EOS stops generation unless the request opted out via ignore_eos.
    if (not sampling_params.ignore_eos
            and last_token_id == request.eos_token_id):
        request.status = RequestStatus.FINISHED_STOPPED
        return True

    # A user-provided stop token stops generation and is recorded as the
    # stop reason.
    if last_token_id in (sampling_params.stop_token_ids or ()):
        request.status = RequestStatus.FINISHED_STOPPED
        request.stop_reason = last_token_id
        return True

    return False
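
# Illustrative caller sketch (an assumption about usage, not code from this
# file): a scheduler step might pair check_stop() with remove_all() roughly
# like this after a new token is appended to a request. The names `running`,
# `finished`, and `request` are hypothetical.
#
#     if check_stop(request, max_model_len):
#         finished.append(request)
#         running = remove_all(running, {request})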