[V1][Metrics] Allow V1 AsyncLLM to use custom logger (#14661)

Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Zijing Liu 2025-04-25 22:05:40 -07:00 committed by GitHub
parent 54271bb766
commit 53e8cf53a4
4 changed files with 118 additions and 30 deletions
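
In V1, stat loggers are now supplied as factories: callables that take a (VllmConfig, engine_index) pair and return a StatLoggerBase, so the engine can build an independent logger set per data-parallel rank. Below is a minimal sketch of the new API, based on the interface shown in the diffs that follow; the logger class and model name are illustrative, not part of this commit:

from typing import Optional

from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import StatLoggerBase
from vllm.v1.metrics.stats import IterationStats, SchedulerStats


class PrintingStatLogger(StatLoggerBase):
    """Hypothetical custom logger implementing the (unstable) V1 interface."""

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        self.engine_index = engine_index
        self.last_scheduler_stats = SchedulerStats()

    def record(self, scheduler_stats: SchedulerStats,
               iteration_stats: Optional[IterationStats]):
        # The engine calls record() every iteration with fresh stats.
        self.last_scheduler_stats = scheduler_stats

    def log(self):
        # Called periodically, and directly by AsyncLLM.do_log_stats().
        print(f"[engine {self.engine_index}] {self.last_scheduler_stats}")


# Pass the class (a factory), not an instance:
engine = AsyncLLM.from_engine_args(
    AsyncEngineArgs(model="facebook/opt-125m"),  # model name is illustrative
    stat_loggers=[PrintingStatLogger],
)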

tests/v1/engine/test_async_llm.py

@@ -3,16 +3,19 @@
import asyncio
from contextlib import ExitStack
from typing import Optional
from unittest.mock import MagicMock
import pytest
from vllm import SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import PromptType
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import LoggingStatLogger
if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.",
@@ -216,3 +219,33 @@ async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int,
# Assert only the last output has the finished flag set
assert all(not out.finished for out in outputs[:-1])
assert outputs[-1].finished
class MockLoggingStatLogger(LoggingStatLogger):
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
super().__init__(vllm_config, engine_index)
self.log = MagicMock()
@pytest.mark.asyncio
async def test_customize_loggers(monkeypatch):
"""Test that we can customize the loggers.
If a customized logger is provided at the init, it should
be used directly.
"""
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")
engine = AsyncLLM.from_engine_args(
TEXT_ENGINE_ARGS,
stat_loggers=[MockLoggingStatLogger],
)
after.callback(engine.shutdown)
await engine.do_log_stats()
assert len(engine.stat_loggers) == 1
assert len(engine.stat_loggers[0]) == 1
engine.stat_loggers[0][0].log.assert_called_once()
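
Note that the test passes the logger class itself rather than an instance: stat_loggers takes factories so that setup_default_loggers (see the loggers.py diff below) can construct an independent logger per data-parallel rank, which is why engine.stat_loggers is a list of per-rank lists and the test indexes stat_loggers[0][0].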

vllm/v1/engine/async_llm.py

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
from collections.abc import AsyncGenerator, Mapping
from copy import copy
from typing import Optional, Union
@@ -33,8 +32,8 @@ from vllm.v1.engine.output_processor import (OutputProcessor,
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
StatLoggerBase)
from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
setup_default_loggers)
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
logger = init_logger(__name__)
@@ -52,7 +51,28 @@ class AsyncLLM(EngineClient):
use_cached_outputs: bool = False,
log_requests: bool = True,
start_engine_loop: bool = True,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
) -> None:
"""
Create an AsyncLLM.
Args:
vllm_config: global configuration.
executor_class: an Executor impl, e.g. MultiprocExecutor.
log_stats: Whether to log stats.
usage_context: Usage context of the LLM.
mm_registry: Multi-modal registry.
use_cached_outputs: Whether to use cached outputs.
log_requests: Whether to log requests.
start_engine_loop: Whether to start the engine loop.
stat_loggers: customized stat loggers for the engine.
If not provided, default stat loggers will be used.
PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE
IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE.
Returns:
None
"""
if not envs.VLLM_USE_V1:
raise ValueError(
"Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
@@ -66,15 +86,12 @@ class AsyncLLM(EngineClient):
self.log_stats = log_stats
# Set up stat loggers; independent set for each DP rank.
self.stat_loggers: list[list[StatLoggerBase]] = []
if self.log_stats:
for i in range(vllm_config.parallel_config.data_parallel_size):
loggers: list[StatLoggerBase] = []
if logger.isEnabledFor(logging.INFO):
loggers.append(LoggingStatLogger(engine_index=i))
loggers.append(
PrometheusStatLogger(vllm_config, engine_index=i))
self.stat_loggers.append(loggers)
self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
vllm_config=vllm_config,
log_stats=self.log_stats,
engine_num=vllm_config.parallel_config.data_parallel_size,
custom_stat_loggers=stat_loggers,
)
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
@@ -118,7 +135,7 @@ class AsyncLLM(EngineClient):
vllm_config: VllmConfig,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
disable_log_requests: bool = False,
disable_log_stats: bool = False,
) -> "AsyncLLM":
@@ -129,17 +146,12 @@ class AsyncLLM(EngineClient):
"AsyncLLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
# FIXME(rob): refactor VllmConfig to include the StatLoggers
# include StatLogger in the Oracle decision.
if stat_loggers is not None:
raise ValueError("Custom StatLoggers are not yet supported on V1. "
"Explicitly set VLLM_USE_V1=0 to disable V1.")
# Create the LLMEngine.
return cls(
vllm_config=vllm_config,
executor_class=Executor.get_class(vllm_config),
start_engine_loop=start_engine_loop,
stat_loggers=stat_loggers,
log_requests=not disable_log_requests,
log_stats=not disable_log_stats,
usage_context=usage_context,
@@ -151,6 +163,7 @@ class AsyncLLM(EngineClient):
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
) -> "AsyncLLM":
"""Create an AsyncLLM from the EngineArgs."""
@@ -166,6 +179,7 @@ class AsyncLLM(EngineClient):
log_stats=not engine_args.disable_log_stats,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
stat_loggers=stat_loggers,
)
def __del__(self):

vllm/v1/engine/llm_engine.py

@@ -10,7 +10,6 @@ import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -28,6 +27,7 @@ from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import StatLoggerFactory
logger = init_logger(__name__)
@@ -43,7 +43,7 @@ class LLMEngine:
executor_class: type[Executor],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False,
multiprocess_mode: bool = False,
@@ -55,6 +55,11 @@ class LLMEngine:
"LLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
if stat_loggers is not None:
raise NotImplementedError(
"Passing StatLoggers to LLMEngine in V1 is not yet supported. "
"Set VLLM_USE_V1=0 and file and issue on Github.")
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
@@ -101,14 +106,9 @@
cls,
vllm_config: VllmConfig,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
disable_log_stats: bool = False,
) -> "LLMEngine":
if stat_loggers is not None:
raise NotImplementedError(
"Passing StatLoggers to V1 is not yet supported. "
"Set VLLM_USE_V1=0 and file and issue on Github.")
return cls(vllm_config=vllm_config,
executor_class=Executor.get_class(vllm_config),
log_stats=(not disable_log_stats),
@@ -121,7 +121,7 @@
cls,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
enable_multiprocessing: bool = False,
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""

vllm/v1/metrics/loggers.py

@@ -1,8 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
import logging
import time
from abc import ABC, abstractmethod
from typing import Optional
from typing import Callable, Optional
import numpy as np
import prometheus_client
@@ -18,8 +19,20 @@ logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5.0
StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
class StatLoggerBase(ABC):
"""Interface for logging metrics.
API users may define custom loggers that implement this interface.
However, note that the `SchedulerStats` and `IterationStats` classes
are not considered stable interfaces and may change in future versions.
"""
@abstractmethod
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
...
@abstractmethod
def record(self, scheduler_stats: SchedulerStats,
@@ -32,7 +45,7 @@ class StatLoggerBase(ABC):
class LoggingStatLogger(StatLoggerBase):
def __init__(self, engine_index: int = 0):
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
self.engine_index = engine_index
self._reset(time.monotonic())
self.last_scheduler_stats = SchedulerStats()
@@ -462,3 +475,31 @@ def build_cudagraph_buckets(vllm_config: VllmConfig) -> list[int]:
return buckets
else:
return [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096]
def setup_default_loggers(
vllm_config: VllmConfig,
log_stats: bool,
engine_num: int,
custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
) -> list[list[StatLoggerBase]]:
"""Setup logging and prometheus metrics."""
if not log_stats:
return []
factories: list[StatLoggerFactory]
if custom_stat_loggers is not None:
factories = custom_stat_loggers
else:
factories = [PrometheusStatLogger]
if logger.isEnabledFor(logging.INFO):
factories.append(LoggingStatLogger)
stat_loggers: list[list[StatLoggerBase]] = []
for i in range(engine_num):
per_engine_stat_loggers: list[StatLoggerBase] = []
for logger_factory in factories:
per_engine_stat_loggers.append(logger_factory(vllm_config, i))
stat_loggers.append(per_engine_stat_loggers)
return stat_loggers
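
One consequence of this factory selection worth noting: when custom_stat_loggers is provided it replaces the defaults outright, so Prometheus metrics are not set up alongside it. A caller who wants both would list the default factory explicitly; an illustrative sketch (PrintingStatLogger is the hypothetical logger from the example at the top, and engine_args is assumed to be defined):

from vllm.v1.metrics.loggers import PrometheusStatLogger

engine = AsyncLLM.from_engine_args(
    engine_args,
    stat_loggers=[PrometheusStatLogger, PrintingStatLogger],
)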