Add logging

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-03-12 22:09:30 -07:00
parent da07067215
commit e484ecb947
2 changed files with 26 additions and 25 deletions

View File

@@ -0,0 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.v1.engine import EngineCoreEvent, EngineCoreEventType
from vllm.v1.request import Request
def record_queued(request: Request) -> None:
    """Log a QUEUED event on the request's event timeline."""
    event = EngineCoreEvent.new_event(EngineCoreEventType.QUEUED)
    request.events.append(event)
def record_scheduled(request: Request, timestamp: float) -> None:
    """Log a SCHEDULED event on the request's event timeline.

    `timestamp` is the scheduling time recorded on the event.
    """
    event = EngineCoreEvent.new_event(EngineCoreEventType.SCHEDULED, timestamp)
    request.events.append(event)
def record_preempted(request: Request, timestamp: float) -> None:
    """Log a PREEMPTED event on the request's event timeline.

    `timestamp` is the preemption time recorded on the event.
    """
    event = EngineCoreEvent.new_event(EngineCoreEventType.PREEMPTED, timestamp)
    request.events.append(event)

View File

@@ -15,10 +15,11 @@ from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.core.sched.common import CommonSchedulerStates
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.logging import (record_preempted, record_queued,
record_scheduled)
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
from vllm.v1.core.sched.utils import check_stop
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
EngineCoreOutput, EngineCoreOutputs)
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
@@ -172,7 +173,8 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
self.request_preempted(preempted_req, scheduled_timestamp)
if self.log_stats:
record_preempted(preempted_req, scheduled_timestamp)
self.waiting.appendleft(preempted_req)
preempted_reqs.append(preempted_req)
@@ -314,7 +316,8 @@ class Scheduler(SchedulerInterface):
req_index += 1
self.running.append(request)
self.scheduled_req_ids.add(request.request_id)
self.request_scheduled(request, scheduled_timestamp)
if self.log_stats:
record_scheduled(request, scheduled_timestamp)
if request.status == RequestStatus.WAITING:
scheduled_new_reqs.append(request)
elif request.status == RequestStatus.PREEMPTED:
@@ -614,7 +617,7 @@ class Scheduler(SchedulerInterface):
def add_request(self, request: Request) -> None:
self.waiting.append(request)
self.requests[request.request_id] = request
self.request_queued(request)
record_queued(request)
def finish_requests(
self,
@@ -668,26 +671,6 @@ class Scheduler(SchedulerInterface):
def reset_prefix_cache(self) -> bool:
return self.kv_cache_manager.reset_prefix_cache()
def request_queued(self, request: Request):
if not self.log_stats:
return
request.events.append(
EngineCoreEvent.new_event(EngineCoreEventType.QUEUED))
def request_scheduled(self, request: Request, timestamp: float):
if not self.log_stats:
return
request.events.append(
EngineCoreEvent.new_event(EngineCoreEventType.SCHEDULED,
timestamp))
def request_preempted(self, request: Request, timestamp: float):
if not self.log_stats:
return
request.events.append(
EngineCoreEvent.new_event(EngineCoreEventType.PREEMPTED,
timestamp))
def make_stats(self) -> Optional[SchedulerStats]:
if not self.log_stats:
return None