mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 02:55:27 +08:00
[Scheduler] Simplify stop checking for pooling models (#30591)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
b09806e28f
commit
1cec5b7ea9
@ -1117,6 +1117,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
stopped = False
|
stopped = False
|
||||||
new_logprobs = None
|
new_logprobs = None
|
||||||
new_token_ids = generated_token_ids
|
new_token_ids = generated_token_ids
|
||||||
|
pooler_output = pooler_outputs[req_index] if pooler_outputs else None
|
||||||
kv_transfer_params = None
|
kv_transfer_params = None
|
||||||
status_before_stop = request.status
|
status_before_stop = request.status
|
||||||
|
|
||||||
@ -1125,12 +1126,10 @@ class Scheduler(SchedulerInterface):
|
|||||||
new_token_ids, stopped = self._update_request_with_output(
|
new_token_ids, stopped = self._update_request_with_output(
|
||||||
request, new_token_ids
|
request, new_token_ids
|
||||||
)
|
)
|
||||||
|
elif request.pooling_params and pooler_output is not None:
|
||||||
# Stop checking for pooler models.
|
# Pooling stops as soon as there is output.
|
||||||
pooler_output = None
|
request.status = RequestStatus.FINISHED_STOPPED
|
||||||
if pooler_outputs:
|
stopped = True
|
||||||
pooler_output = pooler_outputs[req_index]
|
|
||||||
stopped = check_stop(request, self.max_model_len, pooler_output)
|
|
||||||
|
|
||||||
if stopped:
|
if stopped:
|
||||||
kv_transfer_params = self._free_request(request)
|
kv_transfer_params = self._free_request(request)
|
||||||
|
|||||||
@ -2,8 +2,6 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import contextlib
|
import contextlib
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from vllm.v1.request import Request, RequestStatus
|
from vllm.v1.request import Request, RequestStatus
|
||||||
|
|
||||||
|
|
||||||
@ -39,14 +37,8 @@ def remove_all(lst: list, items_to_remove: set) -> list:
|
|||||||
return [item for item in lst if item not in items_to_remove]
|
return [item for item in lst if item not in items_to_remove]
|
||||||
|
|
||||||
|
|
||||||
def check_stop(
|
def check_stop(request: Request, max_model_len: int) -> bool:
|
||||||
request: Request, max_model_len: int, pooler_output: torch.Tensor | None = None
|
assert not request.pooling_params
|
||||||
) -> bool:
|
|
||||||
if request.pooling_params:
|
|
||||||
if pooler_output is not None:
|
|
||||||
request.status = RequestStatus.FINISHED_STOPPED
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
sampling_params = request.sampling_params
|
sampling_params = request.sampling_params
|
||||||
assert sampling_params is not None
|
assert sampling_params is not None
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user