[V1][Core][1/n] Logging and Metrics (#11962)

Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Robert Shaw 2025-01-12 16:02:02 -05:00 committed by GitHub
parent 263a870ee1
commit 9597a095f2
11 changed files with 129 additions and 84 deletions

View File

@@ -80,7 +80,7 @@ def test_engine_core(monkeypatch):
    assert len(engine_core.scheduler.running) == 4

    # Loop through until they are all done.
-   while len(engine_core.step()) > 0:
+   while len(engine_core.step().outputs) > 0:
        pass

    assert len(engine_core.scheduler.waiting) == 0

@@ -170,7 +170,7 @@ def test_engine_core_advanced_sampling(monkeypatch):
    assert len(engine_core.scheduler.waiting) == 1
    assert len(engine_core.scheduler.running) == 0

    # Loop through until they are all done.
-   while len(engine_core.step()) > 0:
+   while len(engine_core.step().outputs) > 0:
        pass

    assert len(engine_core.scheduler.waiting) == 0

View File

@@ -43,7 +43,7 @@ def make_request(params: SamplingParams) -> EngineCoreRequest:
def loop_until_done(client: EngineCoreClient, outputs: Dict):
    while True:
-       engine_core_outputs = client.get_output()
+       engine_core_outputs = client.get_output().outputs
        if len(engine_core_outputs) == 0:
            break
@@ -61,7 +61,7 @@ def loop_until_done(client: EngineCoreClient, outputs: Dict):
async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
    while True:
-       engine_core_outputs = await client.get_output_async()
+       engine_core_outputs = (await client.get_output_async()).outputs
        if len(engine_core_outputs) == 0:
            break

View File

@@ -8,7 +8,8 @@ from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
from vllm.v1.core.kv_cache_manager import KVCacheManager
-from vllm.v1.engine import EngineCoreOutput
+from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
+from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
@@ -394,12 +395,12 @@ class Scheduler:
        self,
        scheduler_output: "SchedulerOutput",
        model_runner_output: "ModelRunnerOutput",
-   ) -> List[EngineCoreOutput]:
+   ) -> EngineCoreOutputs:
        # NOTE(woosuk): This method doesn't consider speculative decoding.
        sampled_token_ids = model_runner_output.sampled_token_ids
        num_scheduled_tokens = scheduler_output.num_scheduled_tokens
        new_running: List[Request] = []
-       engine_core_outputs: List[EngineCoreOutput] = []
+       outputs: List[EngineCoreOutput] = []
        for request in self.running:
            req_id = request.request_id
            request.num_computed_tokens += num_scheduled_tokens[req_id]

@@ -438,7 +439,7 @@
                    finished=request.is_finished(),
                    finish_reason=request.get_finished_reason(),
                    stop_reason=request.stop_reason)
-               engine_core_outputs.append(output)
+               outputs.append(output)

                # Breakout of the loop.
                if stopped:

@@ -446,7 +447,10 @@
            new_running.append(request)

        self.running = new_running
-       return engine_core_outputs
+       return EngineCoreOutputs(
+           outputs=outputs,
+           scheduler_stats=self.make_stats(),
+       )

    def _check_stop(self, request: Request) -> bool:
        if (request.num_tokens >= self.max_model_len

@@ -515,6 +519,12 @@
    def has_unfinished_requests(self) -> bool:
        return self.get_num_unfinished_requests() > 0

+   def make_stats(self) -> SchedulerStats:
+       return SchedulerStats(
+           num_running_reqs=len(self.running),
+           num_waiting_reqs=len(self.waiting),
+       )

@dataclass
class NewRequestData:

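For context, a minimal sketch (not part of the diff) of how a caller consumes the new return type of Scheduler.update_from_output(): per-request outputs and the scheduler snapshot now travel together in one EngineCoreOutputs object. The scheduler, scheduler_output, and model_runner_output names are assumed to be set up elsewhere.

# Sketch only: consuming the EngineCoreOutputs returned above, assuming one
# scheduling step has already produced scheduler_output and model_runner_output.
engine_core_outputs = scheduler.update_from_output(scheduler_output,
                                                   model_runner_output)
for engine_core_output in engine_core_outputs.outputs:
    # Per-request results, same EngineCoreOutput objects as before.
    print(engine_core_output.finished, engine_core_output.stop_reason)

stats = engine_core_outputs.scheduler_stats
print(f"running={stats.num_running_reqs} waiting={stats.num_waiting_reqs}")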
View File

@@ -4,6 +4,8 @@ from typing import TYPE_CHECKING, List, Optional, Union
import msgspec

+from vllm.v1.metrics.stats import SchedulerStats
+
if TYPE_CHECKING:
    from vllm.lora.request import LoRARequest
    from vllm.multimodal import MultiModalKwargs

@@ -56,6 +58,7 @@ class EngineCoreOutputs(
    # [num_reqs]
    outputs: List[EngineCoreOutput]
+   scheduler_stats: SchedulerStats

@dataclass

View File

@@ -4,7 +4,6 @@ from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union
from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
-from vllm.engine.metrics_types import StatLoggerBase
from vllm.engine.protocol import EngineClient
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
from vllm.inputs.preprocess import InputPreprocessor

@@ -22,6 +21,8 @@ from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.detokenizer import Detokenizer
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
+from vllm.v1.metrics.loggers import LoggingStatLogger, StatLoggerBase
+from vllm.v1.metrics.stats import SchedulerStats

logger = init_logger(__name__)

@@ -34,7 +35,6 @@ class AsyncLLM(EngineClient):
        executor_class: Type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-       stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,

@@ -45,7 +45,10 @@
        self.log_requests = log_requests
        self.log_stats = log_stats
-       self.stat_loggers = stat_loggers
+       self.stat_loggers: List[StatLoggerBase] = [
+           LoggingStatLogger(),
+           # TODO(rob): PrometheusStatLogger(),
+       ]

        self.model_config = vllm_config.model_config

        # Tokenizer (+ ensure liveness if running in another process).

@@ -82,7 +85,6 @@
            asyncio_mode=True,
            vllm_config=vllm_config,
            executor_class=executor_class,
-           log_stats=self.log_stats,
        )

        self.output_handler: Optional[asyncio.Task] = None

@@ -94,7 +96,6 @@
        engine_config: Optional[VllmConfig] = None,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-       stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from the EngineArgs."""

@@ -114,7 +115,6 @@
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
-           stat_loggers=stat_loggers,
        )

    def shutdown(self):

@@ -254,7 +254,8 @@
                outputs = await self.engine_core.get_output_async()

                # 2) Detokenize based on the output.
-               request_outputs, reqs_to_abort = self.detokenizer.step(outputs)
+               request_outputs, reqs_to_abort = self.detokenizer.step(
+                   outputs.outputs)

                # 3) Put the RequestOutputs into the per-request queues.
                self._process_request_outputs(request_outputs)

@@ -262,6 +263,9 @@
                # 4) Abort any requests that finished due to stop strings.
                await self.engine_core.abort_requests_async(reqs_to_abort)

+               # 5) Log any stats.
+               await self._log_stats(scheduler_stats=outputs.scheduler_stats)

        except Exception as e:
            logger.exception("EngineCore output handler hit an error: %s", e)
            kill_process_tree(os.getpid())

@@ -278,6 +282,14 @@
        if request_id in self.rid_to_queue:
            del self.rid_to_queue[request_id]

+   async def _log_stats(self, scheduler_stats: SchedulerStats):
+       """Log stats to the stat loggers."""
+       if not self.log_stats:
+           return
+
+       for logger in self.stat_loggers:
+           logger.log(scheduler_stats=scheduler_stats)

    def encode(
        self,
        prompt: PromptType,

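Since stat_loggers is no longer a constructor argument, callers only toggle stats via the engine args; AsyncLLM builds its own logger list internally. A hedged usage sketch, with a placeholder model name (not from the diff):

# Sketch: stats logging is now controlled solely by disable_log_stats; there is
# no stat_loggers argument to pass through anymore.
engine_args = AsyncEngineArgs(model="<model-name>", disable_log_stats=False)
engine = AsyncLLM.from_engine_args(engine_args)
# A LoggingStatLogger is installed internally and the output-handler loop will
# periodically emit "Running: N reqs, Waiting: M reqs" lines.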
View File

@@ -17,9 +17,9 @@ from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
from vllm.utils import get_exception_traceback, zmq_socket_ctx
from vllm.v1.core.scheduler import Scheduler
-from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
-                            EngineCoreProfile, EngineCoreRequest,
-                            EngineCoreRequestType, EngineCoreRequestUnion)
+from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
+                            EngineCoreRequest, EngineCoreRequestType,
+                            EngineCoreRequestUnion)
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.request import Request, RequestStatus

@@ -28,9 +28,7 @@ from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)

-POLLING_TIMEOUT_MS = 5000
-POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
-LOGGING_TIME_S = 5
+POLLING_TIMEOUT_S = 2.5

class EngineCore:

@@ -40,10 +38,8 @@
        self,
        vllm_config: VllmConfig,
        executor_class: Type[Executor],
-       log_stats: bool = False,
    ):
        assert vllm_config.model_config.runner_type != "pooling"
-       self.log_stats = log_stats

        logger.info("Initializing an LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

@@ -62,8 +58,6 @@
            vllm_config.cache_config,
            vllm_config.lora_config)

-       self._last_logging_time = time.time()

        self.mm_input_mapper_server = MMInputMapperServer(
            vllm_config.model_config)

@@ -114,11 +108,12 @@
        self.scheduler.finish_requests(request_ids,
                                       RequestStatus.FINISHED_ABORTED)

-   def step(self) -> List[EngineCoreOutput]:
+   def step(self) -> EngineCoreOutputs:
        """Schedule, execute, and make output."""

        if not self.scheduler.has_unfinished_requests():
-           return []
+           return EngineCoreOutputs(
+               outputs=[], scheduler_stats=self.scheduler.make_stats())

        scheduler_output = self.scheduler.schedule()
        output = self.model_executor.execute_model(scheduler_output)
@@ -145,7 +140,9 @@ class EngineCoreProc(EngineCore):
        executor_class: Type[Executor],
        log_stats: bool = False,
    ):
-       super().__init__(vllm_config, executor_class, log_stats)
+       super().__init__(vllm_config, executor_class)
+       self.log_stats = log_stats

        # Background Threads and Queues for IO. These enable us to
        # overlap ZMQ socket IO with GPU since they release the GIL,

@@ -153,7 +150,7 @@
        # model forward pass.
        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
        self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue()
-       self.output_queue: queue.Queue[List[EngineCoreOutput]] = queue.Queue()
+       self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
        threading.Thread(target=self.process_input_socket,
                         args=(input_path, ),
                         daemon=True).start()

@@ -217,8 +214,10 @@
                    self._handle_client_request(req)
                    break
                except queue.Empty:
-                   self._log_stats()
                    logger.debug("EngineCore busy loop waiting.")
+                   # Break out the loop so we can log_stats in step().
+                   if self.log_stats:
+                       break
                except BaseException:
                    raise

@@ -230,28 +229,9 @@
            # 3) Step the engine core.
            outputs = self.step()

-           # 4) Put EngineCoreOutputs into the output queue.
+           # 5) Put EngineCoreOutputs into the output queue.
            self.output_queue.put_nowait(outputs)

-           self._log_stats()
-
-   def _log_stats(self):
-       """Log basic stats every LOGGING_TIME_S"""
-
-       if not self.log_stats:
-           return
-
-       now = time.time()
-
-       if now - self._last_logging_time > LOGGING_TIME_S:
-           logger.info(
-               "RUNNING: %s | WAITING: %s",
-               len(self.scheduler.running),
-               len(self.scheduler.waiting),
-           )
-           self._last_logging_time = now

    def _handle_client_request(self, request: EngineCoreRequestUnion) -> None:
        """Handle EngineCoreRequest or EngineCoreABORT from Client."""

@@ -301,7 +281,6 @@
        with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
            while True:
-               engine_core_outputs = self.output_queue.get()
-               outputs = EngineCoreOutputs(outputs=engine_core_outputs)
+               outputs = self.output_queue.get()
                encoder.encode_into(outputs, buffer)
                socket.send_multipart((buffer, ), copy=False)

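One consequence of the new step() signature above: even an idle engine returns an EngineCoreOutputs carrying a fresh SchedulerStats, so a polling client always sees queue depths. A minimal sketch (not from the diff), assuming engine_core is a freshly constructed EngineCore with no requests added:

# Sketch: the idle-engine path. With no unfinished requests, step() still
# returns an EngineCoreOutputs with an up-to-date SchedulerStats snapshot.
outputs = engine_core.step()
assert outputs.outputs == []
assert outputs.scheduler_stats.num_running_reqs == 0
assert outputs.scheduler_stats.num_waiting_reqs == 0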
View File

@@ -12,9 +12,9 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
                        make_zmq_socket)
-from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
-                            EngineCoreProfile, EngineCoreRequest,
-                            EngineCoreRequestType, EngineCoreRequestUnion)
+from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
+                            EngineCoreRequest, EngineCoreRequestType,
+                            EngineCoreRequestUnion)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import PickleEncoder

@@ -40,7 +40,6 @@ class EngineCoreClient(ABC):
        asyncio_mode: bool,
        vllm_config: VllmConfig,
        executor_class: Type[Executor],
-       log_stats: bool = False,
    ) -> "EngineCoreClient":

        # TODO: support this for debugging purposes.

@@ -50,18 +49,18 @@
                "is not currently supported.")

        if multiprocess_mode and asyncio_mode:
-           return AsyncMPClient(vllm_config, executor_class, log_stats)
+           return AsyncMPClient(vllm_config, executor_class)

        if multiprocess_mode and not asyncio_mode:
-           return SyncMPClient(vllm_config, executor_class, log_stats)
+           return SyncMPClient(vllm_config, executor_class)

-       return InprocClient(vllm_config, executor_class, log_stats)
+       return InprocClient(vllm_config, executor_class)

    @abstractmethod
    def shutdown(self):
        ...

-   def get_output(self) -> List[EngineCoreOutput]:
+   def get_output(self) -> EngineCoreOutputs:
        raise NotImplementedError

    def add_request(self, request: EngineCoreRequest) -> None:

@@ -73,7 +72,7 @@
    def abort_requests(self, request_ids: List[str]) -> None:
        raise NotImplementedError

-   async def get_output_async(self) -> List[EngineCoreOutput]:
+   async def get_output_async(self) -> EngineCoreOutputs:
        raise NotImplementedError

    async def add_request_async(self, request: EngineCoreRequest) -> None:

@@ -99,7 +98,7 @@ class InprocClient(EngineCoreClient):
    def __init__(self, *args, **kwargs):
        self.engine_core = EngineCore(*args, **kwargs)

-   def get_output(self) -> List[EngineCoreOutput]:
+   def get_output(self) -> EngineCoreOutputs:
        return self.engine_core.step()

    def add_request(self, request: EngineCoreRequest) -> None:
@@ -133,7 +132,7 @@ class MPClient(EngineCoreClient):
        asyncio_mode: bool,
        vllm_config: VllmConfig,
        executor_class: Type[Executor],
-       log_stats: bool = False,
+       log_stats: bool,
    ):
        # The child processes will send SIGUSR1 when unrecoverable
        # errors happen. We kill the process tree here so that the

@@ -194,22 +193,19 @@
class SyncMPClient(MPClient):
    """Synchronous client for multi-proc EngineCore."""

-   def __init__(self,
-                vllm_config: VllmConfig,
-                executor_class: Type[Executor],
-                log_stats: bool = False):
+   def __init__(self, vllm_config: VllmConfig,
+                executor_class: Type[Executor]):
        super().__init__(
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
-           log_stats=log_stats,
+           log_stats=False,
        )

-   def get_output(self) -> List[EngineCoreOutput]:
+   def get_output(self) -> EngineCoreOutputs:
        (frame, ) = self.output_socket.recv_multipart(copy=False)
-       engine_core_outputs = self.decoder.decode(frame.buffer).outputs
-       return engine_core_outputs
+       return self.decoder.decode(frame.buffer)

    def _send_input(self, request_type: EngineCoreRequestType,
                    request: EngineCoreRequestUnion) -> None:

@@ -235,23 +231,19 @@
class AsyncMPClient(MPClient):
    """Asyncio-compatible client for multi-proc EngineCore."""

-   def __init__(self,
-                vllm_config: VllmConfig,
-                executor_class: Type[Executor],
-                log_stats: bool = False):
+   def __init__(self, vllm_config: VllmConfig,
+                executor_class: Type[Executor]):
        super().__init__(
            asyncio_mode=True,
            vllm_config=vllm_config,
            executor_class=executor_class,
-           log_stats=log_stats,
+           log_stats=True,
        )

-   async def get_output_async(self) -> List[EngineCoreOutput]:
+   async def get_output_async(self) -> EngineCoreOutputs:
        frames = await self.output_socket.recv_multipart(copy=False)
-       engine_core_outputs = self.decoder.decode(frames[0].buffer).outputs
-       return engine_core_outputs
+       return self.decoder.decode(frames[0].buffer)

    async def _send_input(self, request_type: EngineCoreRequestType,
                          request: EngineCoreRequestUnion) -> None:

View File

@@ -74,7 +74,6 @@ class LLMEngine:
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
-           log_stats=False,
        )

    @classmethod

@@ -147,11 +146,11 @@
    def step(self) -> List[RequestOutput]:

        # 1) Get EngineCoreOutput from the EngineCore.
-       engine_core_outputs = self.engine_core.get_output()
+       outputs = self.engine_core.get_output()

        # 2) Detokenizer the EngineCoreOutput.
        request_outputs, requests_to_abort = self.detokenizer.step(
-           engine_core_outputs)
+           outputs.outputs)

        # 3) Abort requests that finished due to stopping criteria.
        if requests_to_abort:

View File

View File

@@ -0,0 +1,38 @@
+import time
+from abc import ABC, abstractmethod
+
+from vllm.logger import init_logger
+from vllm.v1.metrics.stats import SchedulerStats
+
+logger = init_logger(__name__)
+
+_LOCAL_LOGGING_INTERVAL_SEC = 5.0
+
+
+class StatLoggerBase(ABC):
+
+    @abstractmethod
+    def log(self, scheduler_stats: SchedulerStats):
+        ...
+
+
+class LoggingStatLogger(StatLoggerBase):
+
+    def __init__(self):
+        self.last_log_time = time.monotonic()
+
+    def log(self, scheduler_stats: SchedulerStats):
+        """Log Stats to standard output."""
+
+        # Log every _LOCAL_LOGGING_INTERVAL_SEC.
+        now = time.monotonic()
+        if now - self.last_log_time < _LOCAL_LOGGING_INTERVAL_SEC:
+            return
+        self.last_log_time = now
+
+        # Format and print output.
+        logger.info(
+            "Running: %d reqs, Waiting: %d reqs ",
+            scheduler_stats.num_running_reqs,
+            scheduler_stats.num_waiting_reqs,
+        )

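A quick usage sketch of the new logger (standalone illustration, not from the commit; assumes LoggingStatLogger, SchedulerStats, _LOCAL_LOGGING_INTERVAL_SEC, and time are in scope):

# Sketch: LoggingStatLogger rate-limits itself, so the engine can call log()
# on every step and still emit at most one line per _LOCAL_LOGGING_INTERVAL_SEC.
stat_logger = LoggingStatLogger()
stats = SchedulerStats(num_running_reqs=3, num_waiting_reqs=7)
stat_logger.log(scheduler_stats=stats)   # silent: the interval has not elapsed yet
time.sleep(_LOCAL_LOGGING_INTERVAL_SEC)
stat_logger.log(scheduler_stats=stats)   # emits "Running: 3 reqs, Waiting: 7 reqs"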
vllm/v1/metrics/stats.py (new file, 12 lines)
View File

@@ -0,0 +1,12 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class SchedulerStats:
+    """Stats associated with the scheduler."""
+
+    num_running_reqs: int = 0
+    num_waiting_reqs: int = 0
+
+    # gpu_cache_usage: float = 0.0
+    # gpu_prefix_cache_hit_rate: float = 0.0