Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-03-11 23:36:42 -07:00
parent 0bf6e97493
commit f198d7d07a
2 changed files with 24 additions and 20 deletions

View File

@ -15,6 +15,7 @@ from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput) SchedulerOutput)
from vllm.v1.core.sched.utils import check_stop
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
EngineCoreOutput, EngineCoreOutputs) EngineCoreOutput, EngineCoreOutputs)
from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.metrics.stats import SchedulerStats
@ -601,7 +602,7 @@ class Scheduler:
# Check for stop and update request state. # Check for stop and update request state.
# This must be called before we make the EngineCoreOutput. # This must be called before we make the EngineCoreOutput.
stopped = self._check_stop(request) stopped = check_stop(request, self.max_model_len)
if stopped: if stopped:
self._free_request(request) self._free_request(request)
break break
@ -645,25 +646,6 @@ class Scheduler:
scheduler_stats=self.make_stats(), scheduler_stats=self.make_stats(),
) )
def _check_stop(self, request: Request) -> bool:
if (request.num_tokens >= self.max_model_len
or request.num_output_tokens >= request.max_tokens):
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
return True
sampling_params = request.sampling_params
last_token_id = request.output_token_ids[-1]
if (not sampling_params.ignore_eos
and last_token_id == request.eos_token_id):
request.status = RequestStatus.FINISHED_STOPPED
return True
if last_token_id in (sampling_params.stop_token_ids or ()):
request.status = RequestStatus.FINISHED_STOPPED
request.stop_reason = last_token_id
return True
return False
def add_request(self, request: Request) -> None: def add_request(self, request: Request) -> None:
self.waiting.append(request) self.waiting.append(request)
self.requests[request.request_id] = request self.requests[request.request_id] = request

View File

@ -0,0 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.v1.request import Request, RequestStatus
def check_stop(request: Request, max_model_len: int) -> bool:
if (request.num_tokens >= max_model_len
or request.num_output_tokens >= request.max_tokens):
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
return True
sampling_params = request.sampling_params
last_token_id = request.output_token_ids[-1]
if (not sampling_params.ignore_eos
and last_token_id == request.eos_token_id):
request.status = RequestStatus.FINISHED_STOPPED
return True
if last_token_id in (sampling_params.stop_token_ids or ()):
request.status = RequestStatus.FINISHED_STOPPED
request.stop_reason = last_token_id
return True
return False