diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index d129c65f10131..5aa34040624f6 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -15,6 +15,7 @@ from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
 from vllm.v1.core.kv_cache_manager import KVCacheManager
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                        SchedulerOutput)
+from vllm.v1.core.sched.utils import check_stop
 from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
                             EngineCoreOutput, EngineCoreOutputs)
 from vllm.v1.metrics.stats import SchedulerStats
@@ -601,7 +602,7 @@ class Scheduler:
 
             # Check for stop and update request state.
             # This must be called before we make the EngineCoreOutput.
-            stopped = self._check_stop(request)
+            stopped = check_stop(request, self.max_model_len)
             if stopped:
                 self._free_request(request)
                 break
@@ -645,25 +646,6 @@ class Scheduler:
             scheduler_stats=self.make_stats(),
         )
 
-    def _check_stop(self, request: Request) -> bool:
-        if (request.num_tokens >= self.max_model_len
-                or request.num_output_tokens >= request.max_tokens):
-            request.status = RequestStatus.FINISHED_LENGTH_CAPPED
-            return True
-
-        sampling_params = request.sampling_params
-        last_token_id = request.output_token_ids[-1]
-        if (not sampling_params.ignore_eos
-                and last_token_id == request.eos_token_id):
-            request.status = RequestStatus.FINISHED_STOPPED
-            return True
-
-        if last_token_id in (sampling_params.stop_token_ids or ()):
-            request.status = RequestStatus.FINISHED_STOPPED
-            request.stop_reason = last_token_id
-            return True
-        return False
-
     def add_request(self, request: Request) -> None:
         self.waiting.append(request)
         self.requests[request.request_id] = request
diff --git a/vllm/v1/core/sched/utils.py b/vllm/v1/core/sched/utils.py
new file mode 100644
index 0000000000000..3a0028a59016e
--- /dev/null
+++ b/vllm/v1/core/sched/utils.py
@@ -0,0 +1,22 @@
+# SPDX-License-Identifier: Apache-2.0
+from vllm.v1.request import Request, RequestStatus
+
+
+def check_stop(request: Request, max_model_len: int) -> bool:
+    if (request.num_tokens >= max_model_len
+            or request.num_output_tokens >= request.max_tokens):
+        request.status = RequestStatus.FINISHED_LENGTH_CAPPED
+        return True
+
+    sampling_params = request.sampling_params
+    last_token_id = request.output_token_ids[-1]
+    if (not sampling_params.ignore_eos
+            and last_token_id == request.eos_token_id):
+        request.status = RequestStatus.FINISHED_STOPPED
+        return True
+
+    if last_token_id in (sampling_params.stop_token_ids or ()):
+        request.status = RequestStatus.FINISHED_STOPPED
+        request.stop_reason = last_token_id
+        return True
+    return False
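
For reference, a minimal sketch of how the extracted check_stop helper behaves once it no longer depends on Scheduler state. The _fake_request stub below is hypothetical (not part of the diff); it models only the attributes check_stop actually reads, using SimpleNamespace in place of the real Request and SamplingParams classes.

# Illustrative sketch only: exercises check_stop from the new
# vllm/v1/core/sched/utils.py with a hand-rolled stand-in for Request.
# The stub and its attribute values are hypothetical.
from types import SimpleNamespace

from vllm.v1.core.sched.utils import check_stop
from vllm.v1.request import RequestStatus


def _fake_request(output_token_ids, eos_token_id=2, max_tokens=16,
                  ignore_eos=False, stop_token_ids=None):
    # Mirrors just the Request/SamplingParams attributes check_stop reads.
    sampling_params = SimpleNamespace(ignore_eos=ignore_eos,
                                      stop_token_ids=stop_token_ids)
    return SimpleNamespace(
        num_tokens=8 + len(output_token_ids),  # prompt + generated tokens
        num_output_tokens=len(output_token_ids),
        max_tokens=max_tokens,
        output_token_ids=output_token_ids,
        eos_token_id=eos_token_id,
        sampling_params=sampling_params,
        status=None,
        stop_reason=None,
    )


# An EOS token finishes the request unless ignore_eos is set.
req = _fake_request([5, 7, 2])
assert check_stop(req, max_model_len=2048)
assert req.status == RequestStatus.FINISHED_STOPPED

# A configured stop token also finishes it and records the stop reason.
req = _fake_request([5, 7, 9], stop_token_ids=[9])
assert check_stop(req, max_model_len=2048)
assert req.stop_reason == 9

# Hitting max_tokens (or max_model_len) caps the request by length.
req = _fake_request([1] * 16, max_tokens=16)
assert check_stop(req, max_model_len=2048)
assert req.status == RequestStatus.FINISHED_LENGTH_CAPPED

# Otherwise generation continues.
req = _fake_request([5, 7, 11])
assert not check_stop(req, max_model_len=2048)

Passing max_model_len explicitly is what lets the function live in utils.py: it now has no self, so the same stop logic can be reused by other schedulers and unit-tested without constructing a Scheduler.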