[Core] Enable command line logging for LLMEngine (#25610)

Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
Zhuohan Li 2025-09-25 15:31:17 -07:00 committed by GitHub
parent e71b8e210d
commit 8c435c9bce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 9 deletions

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from collections.abc import Mapping
from copy import copy
from typing import Any, Callable, Optional, Union
@@ -31,8 +32,7 @@ from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase, from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
StatLoggerFactory)
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats from vllm.v1.metrics.stats import IterationStats
from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.worker_base import WorkerBase
@@ -74,9 +74,6 @@ class LLMEngine:
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
self.log_stats = log_stats self.log_stats = log_stats
self.stat_logger: Optional[StatLoggerBase] = None
if self.log_stats:
self.stat_logger = PrometheusStatLogger(vllm_config)
executor_backend = ( executor_backend = (
self.vllm_config.parallel_config.distributed_executor_backend) self.vllm_config.parallel_config.distributed_executor_backend)
@@ -122,6 +119,15 @@
log_stats=self.log_stats, log_stats=self.log_stats,
) )
self.logger_manager: Optional[StatLoggerManager] = None
if self.log_stats:
self.logger_manager = StatLoggerManager(
vllm_config=vllm_config,
custom_stat_loggers=stat_loggers,
enable_default_loggers=log_stats,
)
self.logger_manager.log_engine_initialized()
if not multiprocess_mode: if not multiprocess_mode:
# for v0 compatibility # for v0 compatibility
self.model_executor = self.engine_core.engine_core.model_executor # type: ignore self.model_executor = self.engine_core.engine_core.model_executor # type: ignore
@@ -269,10 +275,13 @@
self.engine_core.abort_requests(processed_outputs.reqs_to_abort) self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
# 4) Record stats # 4) Record stats
if self.stat_logger is not None: if self.logger_manager is not None:
assert outputs.scheduler_stats is not None assert outputs.scheduler_stats is not None
self.stat_logger.record(scheduler_stats=outputs.scheduler_stats, self.logger_manager.record(
iteration_stats=iteration_stats) scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
)
self.do_log_stats_with_interval()
return processed_outputs.request_outputs return processed_outputs.request_outputs
@@ -315,6 +324,20 @@
return self.tokenizer return self.tokenizer
def do_log_stats(self) -> None:
    """Emit one stats log line now, if stat logging is enabled."""
    manager = self.logger_manager
    if not manager:
        # Stat logging disabled (log_stats=False) — nothing to do.
        return
    manager.log()
def do_log_stats_with_interval(self) -> None:
    """Log stats at most once per ``VLLM_LOG_STATS_INTERVAL`` seconds.

    Intended to be called on the hot path (once per engine step); it is a
    cheap no-op between intervals. Uses a monotonic clock rather than
    ``time.time()`` so the interval is immune to wall-clock adjustments
    (NTP corrections, DST changes), which could otherwise suppress or
    spam log output.
    """
    now = time.monotonic()
    if not hasattr(self, "_last_log_time"):
        # Lazily start the interval clock on the first call; with a
        # positive interval, this first call intentionally does not log.
        self._last_log_time = now
    if now - self._last_log_time >= envs.VLLM_LOG_STATS_INTERVAL:
        self.do_log_stats()
        self._last_log_time = now
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
"""Load a new LoRA adapter into the engine for future requests.""" """Load a new LoRA adapter into the engine for future requests."""
return self.engine_core.add_lora(lora_request) return self.engine_core.add_lora(lora_request)

View File

@@ -90,7 +90,6 @@ class LoggingStatLogger(StatLoggerBase):
iteration_stats: Optional[IterationStats], iteration_stats: Optional[IterationStats],
engine_idx: int = 0): engine_idx: int = 0):
"""Log Stats to standard output.""" """Log Stats to standard output."""
if iteration_stats: if iteration_stats:
self._track_iteration_stats(iteration_stats) self._track_iteration_stats(iteration_stats)