add aggregator interface and abstract common logic

Signed-off-by: Lu Fang <fanglu@fb.com>
This commit is contained in:
Lu Fang 2025-09-22 12:50:38 -07:00
parent a46e279909
commit 98d535eb4f
4 changed files with 90 additions and 109 deletions

View File

@ -9,7 +9,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.loggers import GlobalStatLogger, LoggingStatLogger
from vllm.v1.metrics.loggers import AggregatedStatLogger, LoggingStatLogger
"""
To run this example, run the following commands simultaneously with
@ -53,15 +53,10 @@ async def main():
def per_engine_logger_factory(config: VllmConfig, rank: int) -> LoggingStatLogger:
return LoggingStatLogger(config, rank)
def global_logger_factory(
config: VllmConfig, engine_idxs: Optional[list[int]]
) -> GlobalStatLogger:
return GlobalStatLogger(config, engine_indexes=engine_idxs)
engine_client = AsyncLLMEngine.from_engine_args(
engine_args,
stat_loggers=[per_engine_logger_factory],
stat_logger_global=global_logger_factory,
# Example: Using both regular loggers and aggregated logger
stat_loggers=[per_engine_logger_factory, AggregatedStatLogger],
)
stop_logging_event = threading.Event()
logging_thread = threading.Thread(

View File

@ -18,7 +18,7 @@ from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import GlobalStatLogger, LoggingStatLogger
from vllm.v1.metrics.loggers import AggregatedStatLogger, LoggingStatLogger
if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.",
@ -389,7 +389,7 @@ class MockLoggingStatLogger(LoggingStatLogger):
self.log = MagicMock()
class MockGlobalStatLogger(GlobalStatLogger):
class MockAggregatedStatLogger(AggregatedStatLogger):
def __init__(self,
vllm_config: VllmConfig,
@ -425,8 +425,8 @@ async def test_customize_loggers(monkeypatch):
@pytest.mark.asyncio
async def test_customize_global_loggers(monkeypatch):
"""Test that we can customize the loggers.
async def test_customize_aggregated_loggers(monkeypatch):
"""Test that we can customize the aggregated loggers.
If a customized logger is provided at the init, it should
be added to the default loggers.
"""
@ -437,8 +437,7 @@ async def test_customize_global_loggers(monkeypatch):
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(
TEXT_ENGINE_ARGS,
stat_loggers=[MockLoggingStatLogger],
stat_logger_global=MockGlobalStatLogger,
stat_loggers=[MockLoggingStatLogger, MockAggregatedStatLogger],
)
after.callback(engine.shutdown)
@ -448,9 +447,9 @@ async def test_customize_global_loggers(monkeypatch):
assert len(stat_loggers) == 1
assert len(
stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger
global_logger = engine.logger_manager.global_logger
assert global_logger is not None
global_logger.log.assert_called_once()
aggregated_loggers = engine.logger_manager.aggregated_loggers
assert len(aggregated_loggers) == 1
aggregated_loggers[0].log.assert_called_once()
stat_loggers[0][0].log.assert_called_once()

View File

@ -42,8 +42,7 @@ from vllm.v1.engine.output_processor import (OutputProcessor,
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (GlobalStatLoggerFactory,
StatLoggerFactory, StatLoggerManager)
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
from vllm.v1.metrics.prometheus import shutdown_prometheus
from vllm.v1.metrics.stats import IterationStats
@ -63,7 +62,6 @@ class AsyncLLM(EngineClient):
log_requests: bool = True,
start_engine_loop: bool = True,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
stat_logger_global: Optional[GlobalStatLoggerFactory] = None,
client_addresses: Optional[dict[str, str]] = None,
client_count: int = 1,
client_index: int = 0,
@ -103,11 +101,8 @@ class AsyncLLM(EngineClient):
self.observability_config = vllm_config.observability_config
self.log_requests = log_requests
self.log_stats = log_stats or (stat_loggers
is not None) or (stat_logger_global
is not None)
if (not log_stats and stat_loggers is not None
or stat_logger_global is not None):
self.log_stats = log_stats or (stat_loggers is not None)
if not log_stats and stat_loggers is not None:
logger.info(
"AsyncLLM created with log_stats=False and non-empty custom "
"logger list; enabling logging without default stat loggers")
@ -152,7 +147,6 @@ class AsyncLLM(EngineClient):
vllm_config=vllm_config,
engine_idxs=self.engine_core.engine_ranks_managed,
custom_stat_loggers=stat_loggers,
custom_stat_logger_global=stat_logger_global,
enable_default_loggers=log_stats,
client_count=client_count,
)
@ -195,7 +189,6 @@ class AsyncLLM(EngineClient):
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
stat_logger_global: Optional[GlobalStatLoggerFactory] = None,
enable_log_requests: bool = False,
disable_log_stats: bool = False,
client_addresses: Optional[dict[str, str]] = None,
@ -216,7 +209,6 @@ class AsyncLLM(EngineClient):
executor_class=Executor.get_class(vllm_config),
start_engine_loop=start_engine_loop,
stat_loggers=stat_loggers,
stat_logger_global=stat_logger_global,
log_requests=enable_log_requests,
log_stats=not disable_log_stats,
usage_context=usage_context,
@ -232,7 +224,6 @@ class AsyncLLM(EngineClient):
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
stat_logger_global: Optional[GlobalStatLoggerFactory] = None,
) -> "AsyncLLM":
"""Create an AsyncLLM from the EngineArgs."""
@ -249,7 +240,6 @@ class AsyncLLM(EngineClient):
start_engine_loop=start_engine_loop,
usage_context=usage_context,
stat_loggers=stat_loggers,
stat_logger_global=stat_logger_global,
)
def __del__(self):

View File

@ -18,9 +18,9 @@ from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
logger = init_logger(__name__)
StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
GlobalStatLoggerFactory = Callable[[VllmConfig, list[int]],
"GlobalStatLoggerBase"]
PerEngineStatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
StatLoggerFactory = Union[PerEngineStatLoggerFactory,
type["AggregatedStatLoggerBase"]]
class StatLoggerBase(ABC):
@ -50,8 +50,9 @@ class StatLoggerBase(ABC):
pass
class GlobalStatLoggerBase(StatLoggerBase):
"""Interface for logging metrics for multiple engines."""
class AggregatedStatLoggerBase(StatLoggerBase):
"""Abstract base class for loggers that
aggregates statistics across multiple engines."""
@abstractmethod
def __init__(self, vllm_config: VllmConfig,
@ -72,6 +73,7 @@ class LoggingStatLogger(StatLoggerBase):
self.spec_decoding_logging = SpecDecodingLogging()
self.last_prompt_throughput: float = 0.0
self.last_generation_throughput: float = 0.0
self.engine_is_idle = False
def _reset(self, now):
self.last_log_time = now
@ -111,25 +113,25 @@ class LoggingStatLogger(StatLoggerBase):
self.last_scheduler_stats = scheduler_stats
def log(self):
def get_log_stats(self):
now = time.monotonic()
prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
generation_throughput = self._get_throughput(
self.num_generation_tokens, now)
self._reset(now)
scheduler_stats = self.last_scheduler_stats
log_fn = logger.info
if not any(
(prompt_throughput, generation_throughput,
self.last_prompt_throughput, self.last_generation_throughput)):
# Avoid log noise on an idle production system
log_fn = logger.debug
self.last_generation_throughput = generation_throughput
self.last_prompt_throughput = prompt_throughput
self.engine_is_idle = not any(
(prompt_throughput, generation_throughput,
self.last_prompt_throughput, self.last_generation_throughput))
def log(self):
self.get_log_stats()
log_fn = logger.info
if self.engine_is_idle:
# Avoid log noise on an idle production system
log_fn = logger.debug
# Format and print output.
log_fn(
"Engine %03d: "
@ -139,11 +141,11 @@ class LoggingStatLogger(StatLoggerBase):
"GPU KV cache usage: %.1f%%, "
"Prefix cache hit rate: %.1f%%",
self.engine_index,
prompt_throughput,
generation_throughput,
scheduler_stats.num_running_reqs,
scheduler_stats.num_waiting_reqs,
scheduler_stats.kv_cache_usage * 100,
self.last_prompt_throughput,
self.last_generation_throughput,
self.last_scheduler_stats.num_running_reqs,
self.last_scheduler_stats.num_waiting_reqs,
self.last_scheduler_stats.kv_cache_usage * 100,
self.prefix_caching_metrics.hit_rate * 100,
)
self.spec_decoding_logging.log(log_fn=log_fn)
@ -156,37 +158,34 @@ class LoggingStatLogger(StatLoggerBase):
self.vllm_config.cache_config.num_gpu_blocks)
class GlobalStatLogger(LoggingStatLogger, GlobalStatLoggerBase):
class AggregatedStatLogger(LoggingStatLogger, AggregatedStatLoggerBase):
def __init__(self,
vllm_config: VllmConfig,
engine_indexes: Optional[list[int]] = None):
super().__init__(vllm_config, -1)
if engine_indexes is None:
engine_indexes = [0]
self.engine_index = -1
self.engine_indexes = engine_indexes
self.vllm_config = vllm_config
engine_idxs: Optional[list[int]] = None):
if engine_idxs is None:
engine_idxs = [0]
self.engine_idxs = engine_idxs
LoggingStatLogger.__init__(self, vllm_config, engine_index=-1)
def record(
self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
engine_idx: int = 0,
):
if engine_idx not in self.engine_idxs:
logger.warning("Unexpected engine_idx: %d", engine_idx)
return
LoggingStatLogger.record(self, scheduler_stats, iteration_stats,
engine_idx)
def log(self):
now = time.monotonic()
prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
generation_throughput = self._get_throughput(
self.num_generation_tokens, now)
self._reset(now)
scheduler_stats = self.last_scheduler_stats
self.get_log_stats()
log_fn = logger.info
if not any(
(prompt_throughput, generation_throughput,
self.last_prompt_throughput, self.last_generation_throughput)):
if self.engine_is_idle:
# Avoid log noise on an idle production system
log_fn = logger.debug
self.last_generation_throughput = generation_throughput
self.last_prompt_throughput = prompt_throughput
# Format and print output.
log_fn(
"%s Engines Aggregated: "
@ -195,12 +194,12 @@ class GlobalStatLogger(LoggingStatLogger, GlobalStatLoggerBase):
"Running: %d reqs, Waiting: %d reqs, "
"GPU KV cache usage: %.1f%%, "
"Prefix cache hit rate: %.1f%%",
len(self.engine_indexes),
prompt_throughput,
generation_throughput,
scheduler_stats.num_running_reqs,
scheduler_stats.num_waiting_reqs,
scheduler_stats.kv_cache_usage * 100,
len(self.engine_idxs),
self.last_prompt_throughput,
self.last_generation_throughput,
self.last_scheduler_stats.num_running_reqs,
self.last_scheduler_stats.num_waiting_reqs,
self.last_scheduler_stats.kv_cache_usage * 100,
self.prefix_caching_metrics.hit_rate * 100,
)
self.spec_decoding_logging.log(log_fn=log_fn)
@ -209,11 +208,11 @@ class GlobalStatLogger(LoggingStatLogger, GlobalStatLoggerBase):
if self.vllm_config.cache_config.num_gpu_blocks:
logger.info(
"%d Engines: vllm cache_config_info with initialization "
"after num_gpu_blocks is: %d", len(self.engine_indexes),
"after num_gpu_blocks is: %d", len(self.engine_idxs),
self.vllm_config.cache_config.num_gpu_blocks)
class PrometheusStatLogger(StatLoggerBase):
class PrometheusStatLogger(AggregatedStatLoggerBase):
_gauge_cls = prometheus_client.Gauge
_counter_cls = prometheus_client.Counter
_histogram_cls = prometheus_client.Histogram
@ -723,7 +722,6 @@ class StatLoggerManager:
vllm_config: VllmConfig,
engine_idxs: Optional[list[int]] = None,
custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
custom_stat_logger_global: Optional[GlobalStatLoggerFactory] = None,
enable_default_loggers: bool = True,
client_count: int = 1,
):
@ -743,28 +741,32 @@ class StatLoggerManager:
# engine_idx: StatLogger
self.per_engine_logger_dict: dict[int, list[StatLoggerBase]] = {}
prometheus_factory = PrometheusStatLogger
self.aggregated_loggers: list[AggregatedStatLoggerBase] = []
aggregated_loggers_factories = set()
for engine_idx in self.engine_idxs:
loggers: list[StatLoggerBase] = []
for logger_factory in factories:
# If we get a custom prometheus logger, use that
# instead. This is typically used for the ray case.
if (isinstance(logger_factory, type)
and issubclass(logger_factory, PrometheusStatLogger)):
prometheus_factory = logger_factory
continue
loggers.append(logger_factory(vllm_config,
engine_idx)) # type: ignore
# If we get a custom prometheus logger or aggregated logger,
# We initialize it separately with all engine idxs.
# A custom prometheus logger is typically used for the ray.
if (isinstance(logger_factory, type) and issubclass(
logger_factory, AggregatedStatLoggerBase)):
aggregated_loggers_factories.add(logger_factory)
else:
loggers.append(logger_factory(vllm_config,
engine_idx)) # type: ignore
self.per_engine_logger_dict[engine_idx] = loggers
# For Prometheus or custom global logger,
# If no custom aggregated logger is provide,
# we by default use PrometheusStatLogger
if not aggregated_loggers_factories:
aggregated_loggers_factories.add(PrometheusStatLogger)
# For custom aggregated logger(or default Prometheus Logger)
# need to share the metrics between EngineCores.
# Each EngineCore's metrics are expressed as a unique label.
self.prometheus_logger = prometheus_factory(vllm_config, engine_idxs)
self.global_logger: Optional[StatLoggerBase] = None
if custom_stat_logger_global is not None:
self.global_logger = custom_stat_logger_global(
vllm_config, self.engine_idxs)
for aggregated_loggers_factory in aggregated_loggers_factories:
self.aggregated_loggers.append(
aggregated_loggers_factory(vllm_config, engine_idxs))
def record(
self,
@ -778,24 +780,19 @@ class StatLoggerManager:
per_engine_loggers = self.per_engine_logger_dict[engine_idx]
for logger in per_engine_loggers:
logger.record(scheduler_stats, iteration_stats, engine_idx)
self.prometheus_logger.record(scheduler_stats, iteration_stats,
engine_idx)
if self.global_logger is not None:
self.global_logger.record(scheduler_stats, iteration_stats,
engine_idx)
for logger in self.aggregated_loggers:
logger.record(scheduler_stats, iteration_stats, engine_idx)
def log(self):
for per_engine_loggers in self.per_engine_logger_dict.values():
for logger in per_engine_loggers:
logger.log()
if self.global_logger is not None:
self.global_logger.log()
for logger in self.aggregated_loggers:
logger.log()
def log_engine_initialized(self):
self.prometheus_logger.log_engine_initialized()
if self.global_logger is not None:
self.global_logger.log_engine_initialized()
for agg_logger in self.aggregated_loggers:
agg_logger.log_engine_initialized()
for per_engine_loggers in self.per_engine_logger_dict.values():
for logger in per_engine_loggers:
logger.log_engine_initialized()