Merge pull request #19 from robertgshaw2-redhat/fix-prometheus-logging

Improve code structure
This commit is contained in:
Robert Shaw 2025-07-20 12:53:23 -04:00 committed by GitHub
commit 5e6114df5d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 74 additions and 63 deletions

View File

@ -36,10 +36,9 @@ from vllm.v1.engine.output_processor import (OutputProcessor,
from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase, from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
StatLoggerFactory, setup_default_loggers)
from vllm.v1.metrics.prometheus import shutdown_prometheus from vllm.v1.metrics.prometheus import shutdown_prometheus
from vllm.v1.metrics.stats import IterationStats, SchedulerStats from vllm.v1.metrics.stats import IterationStats
logger = init_logger(__name__) logger = init_logger(__name__)
@ -103,12 +102,11 @@ class AsyncLLM(EngineClient):
engine_idxs = [ engine_idxs = [
idx for idx in range(start_idx, start_idx + local_engines) idx for idx in range(start_idx, start_idx + local_engines)
] ]
self.stat_loggers = setup_default_loggers( self.logger_manager = StatLoggerManager(
vllm_config=vllm_config, vllm_config=vllm_config,
log_stats=self.log_stats,
engine_idxs=engine_idxs, engine_idxs=engine_idxs,
custom_stat_loggers=stat_loggers, custom_stat_loggers=stat_loggers,
) ) if self.log_stats else None
# Tokenizer (+ ensure liveness if running in another process). # Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs( self.tokenizer = init_tokenizer_from_configs(
@ -136,12 +134,8 @@ class AsyncLLM(EngineClient):
client_addresses=client_addresses, client_addresses=client_addresses,
client_index=client_index, client_index=client_index,
) )
if self.stat_loggers: if self.logger_manager:
# loggers, prom_logger self.logger_manager.log_engine_initialized()
loggers, _ = self.stat_loggers
for per_engine_loggers in loggers.values():
for logger in per_engine_loggers:
logger.log_engine_initialized()
self.output_handler: Optional[asyncio.Task] = None self.output_handler: Optional[asyncio.Task] = None
try: try:
# Start output handler eagerly if we are in the asyncio eventloop. # Start output handler eagerly if we are in the asyncio eventloop.
@ -380,7 +374,7 @@ class AsyncLLM(EngineClient):
engine_core = self.engine_core engine_core = self.engine_core
output_processor = self.output_processor output_processor = self.output_processor
log_stats = self.log_stats log_stats = self.log_stats
stat_loggers = self.stat_loggers if log_stats else None logger_manager = self.logger_manager
async def output_handler(): async def output_handler():
try: try:
@ -420,12 +414,12 @@ class AsyncLLM(EngineClient):
# 4) Logging. # 4) Logging.
# TODO(rob): make into a coroutine and launch it in # TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial. # background thread once Prometheus overhead is non-trivial.
if stat_loggers: # NOTE: we do not use self.log
AsyncLLM._record_stats( if logger_manager:
stat_loggers, logger_manager.record(
outputs.engine_index,
scheduler_stats=outputs.scheduler_stats, scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats, iteration_stats=iteration_stats,
engine_idx=outputs.engine_index,
) )
except Exception as e: except Exception as e:
logger.exception("AsyncLLM output_handler failed.") logger.exception("AsyncLLM output_handler failed.")
@ -442,26 +436,6 @@ class AsyncLLM(EngineClient):
if self.log_requests: if self.log_requests:
logger.info("Aborted request %s.", request_id) logger.info("Aborted request %s.", request_id)
@staticmethod
def _record_stats(
stat_loggers: tuple[dict[int, list[StatLoggerBase]],
PrometheusStatLogger],
engine_idx: int,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
):
"""static so that it can be used from the output_handler task
without a circular ref to AsyncLLM."""
per_engine_loggers, prom_logger = stat_loggers
for stat_logger in per_engine_loggers[engine_idx]:
stat_logger.record(engine_idx=engine_idx,
scheduler_stats=scheduler_stats,
iteration_stats=iteration_stats)
prom_logger.record(engine_idx=engine_idx,
scheduler_stats=scheduler_stats,
iteration_stats=iteration_stats)
async def encode( async def encode(
self, self,
prompt: PromptType, prompt: PromptType,

View File

@ -600,32 +600,69 @@ def build_1_2_5_buckets(max_value: int) -> list[int]:
return build_buckets([1, 2, 5], max_value) return build_buckets([1, 2, 5], max_value)
def setup_default_loggers( class StatLoggerManager:
vllm_config: VllmConfig, """
log_stats: bool, StatLoggerManager:
engine_idxs: list[int], Logging happens at the level of the EngineCore (per scheduler).
custom_stat_loggers: Optional[list[StatLoggerFactory]] = None, * DP: >1 EngineCore per AsyncLLM - loggers for each EngineCore.
) -> Optional[tuple[dict[int, list[StatLoggerBase]], PrometheusStatLogger]]: * With Local Logger, just make N copies for N EngineCores.
"""Setup logging and prometheus metrics.""" * With Prometheus, we need a single logger with N "labels"
if not log_stats:
return None
factories: list[StatLoggerFactory] This class abstracts away this implementation detail from
if custom_stat_loggers is not None: the AsyncLLM, allowing the AsyncLLM to just call .record()
factories = custom_stat_loggers and .log() to a simple interface.
else: """
factories = []
if logger.isEnabledFor(logging.INFO):
factories.append(LoggingStatLogger)
# engine_idx: Logger def __init__(
stat_loggers: dict[int, list[StatLoggerBase]] = {} self,
for engine_idx in engine_idxs: vllm_config: VllmConfig,
per_engine_stat_loggers: list[StatLoggerBase] = [] engine_idxs: Optional[list[int]] = None,
for logger_factory in factories: custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
per_engine_stat_loggers.append( ):
logger_factory(vllm_config, engine_idx)) self.engine_idxs = engine_idxs if engine_idxs else [0]
stat_loggers[engine_idx] = per_engine_stat_loggers
prom_stat_logger = PrometheusStatLogger(vllm_config, engine_idxs) factories: list[StatLoggerFactory]
return stat_loggers, prom_stat_logger if custom_stat_loggers is not None:
factories = custom_stat_loggers
else:
factories = []
if logger.isEnabledFor(logging.INFO):
factories.append(LoggingStatLogger)
# engine_idx: StatLogger
self.per_engine_logger_dict: dict[int, list[StatLoggerBase]] = {}
for engine_idx in self.engine_idxs:
loggers: list[StatLoggerBase] = []
for logger_factory in factories:
loggers.append(logger_factory(vllm_config, engine_idx))
self.per_engine_logger_dict[engine_idx] = loggers
# For Prometheus, need to share the metrics between EngineCores.
# Each EngineCore's metrics are expressed as a unique label.
self.prometheus_logger = PrometheusStatLogger(vllm_config, engine_idxs)
def record(
    self,
    scheduler_stats: Optional[SchedulerStats],
    iteration_stats: Optional[IterationStats],
    engine_idx: Optional[int] = None,
):
    """Forward one set of stats to the loggers of a single engine.

    The stats are passed first to every logger registered for the engine
    in ``self.per_engine_logger_dict`` and then to the shared
    ``self.prometheus_logger``.

    Args:
        scheduler_stats: Scheduler stats to record, passed through as-is.
        iteration_stats: Iteration stats to record, passed through as-is.
        engine_idx: Index of the engine the stats belong to. ``None`` is
            treated as engine ``0``.

    Raises:
        KeyError: If no loggers are registered for the resolved index.
    """
    target_idx = 0 if engine_idx is None else engine_idx

    # Look up the per-engine loggers before touching Prometheus so an
    # unknown index fails before anything has been recorded.
    for stat_logger in self.per_engine_logger_dict[target_idx]:
        stat_logger.record(scheduler_stats, iteration_stats, target_idx)

    self.prometheus_logger.record(scheduler_stats, iteration_stats,
                                  target_idx)
def log(self):
    """Call ``log()`` on every per-engine stat logger.

    Engines are visited in the iteration order of
    ``self.per_engine_logger_dict`` and, within an engine, loggers are
    visited in list order. The Prometheus logger is not called here.
    """
    for engine_loggers in self.per_engine_logger_dict.values():
        for stat_logger in engine_loggers:
            stat_logger.log()
def log_engine_initialized(self):
    """Call ``log_engine_initialized()`` on every per-engine stat logger.

    Engines are visited in the iteration order of
    ``self.per_engine_logger_dict`` and, within an engine, loggers are
    visited in list order. The Prometheus logger is not called here.
    """
    for engine_loggers in self.per_engine_logger_dict.values():
        for stat_logger in engine_loggers:
            stat_logger.log_engine_initialized()