add aggregator interface and abstract common logic

Signed-off-by: Lu Fang <fanglu@fb.com>
This commit is contained in:
Lu Fang 2025-09-22 12:50:38 -07:00
parent a46e279909
commit 98d535eb4f
4 changed files with 90 additions and 109 deletions

View File

@ -9,7 +9,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.loggers import GlobalStatLogger, LoggingStatLogger from vllm.v1.metrics.loggers import AggregatedStatLogger, LoggingStatLogger
""" """
To run this example, run the following commands simultaneously with To run this example, run the following commands simultaneously with
@ -53,15 +53,10 @@ async def main():
def per_engine_logger_factory(config: VllmConfig, rank: int) -> LoggingStatLogger: def per_engine_logger_factory(config: VllmConfig, rank: int) -> LoggingStatLogger:
return LoggingStatLogger(config, rank) return LoggingStatLogger(config, rank)
def global_logger_factory(
config: VllmConfig, engine_idxs: Optional[list[int]]
) -> GlobalStatLogger:
return GlobalStatLogger(config, engine_indexes=engine_idxs)
engine_client = AsyncLLMEngine.from_engine_args( engine_client = AsyncLLMEngine.from_engine_args(
engine_args, engine_args,
stat_loggers=[per_engine_logger_factory], # Example: Using both regular loggers and aggregated logger
stat_logger_global=global_logger_factory, stat_loggers=[per_engine_logger_factory, AggregatedStatLogger],
) )
stop_logging_event = threading.Event() stop_logging_event = threading.Event()
logging_thread = threading.Thread( logging_thread = threading.Thread(

View File

@ -18,7 +18,7 @@ from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind from vllm.sampling_params import RequestOutputKind
from vllm.utils import set_default_torch_num_threads from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import GlobalStatLogger, LoggingStatLogger from vllm.v1.metrics.loggers import AggregatedStatLogger, LoggingStatLogger
if not current_platform.is_cuda(): if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.", pytest.skip(reason="V1 currently only supported on CUDA.",
@ -389,7 +389,7 @@ class MockLoggingStatLogger(LoggingStatLogger):
self.log = MagicMock() self.log = MagicMock()
class MockGlobalStatLogger(GlobalStatLogger): class MockAggregatedStatLogger(AggregatedStatLogger):
def __init__(self, def __init__(self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
@ -425,8 +425,8 @@ async def test_customize_loggers(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_customize_global_loggers(monkeypatch): async def test_customize_aggregated_loggers(monkeypatch):
"""Test that we can customize the loggers. """Test that we can customize the aggregated loggers.
If a customized logger is provided at the init, it should If a customized logger is provided at the init, it should
be added to the default loggers. be added to the default loggers.
""" """
@ -437,8 +437,7 @@ async def test_customize_global_loggers(monkeypatch):
with set_default_torch_num_threads(1): with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args( engine = AsyncLLM.from_engine_args(
TEXT_ENGINE_ARGS, TEXT_ENGINE_ARGS,
stat_loggers=[MockLoggingStatLogger], stat_loggers=[MockLoggingStatLogger, MockAggregatedStatLogger],
stat_logger_global=MockGlobalStatLogger,
) )
after.callback(engine.shutdown) after.callback(engine.shutdown)
@ -448,9 +447,9 @@ async def test_customize_global_loggers(monkeypatch):
assert len(stat_loggers) == 1 assert len(stat_loggers) == 1
assert len( assert len(
stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger
global_logger = engine.logger_manager.global_logger aggregated_loggers = engine.logger_manager.aggregated_loggers
assert global_logger is not None assert len(aggregated_loggers) == 1
global_logger.log.assert_called_once() aggregated_loggers[0].log.assert_called_once()
stat_loggers[0][0].log.assert_called_once() stat_loggers[0][0].log.assert_called_once()

View File

@ -42,8 +42,7 @@ from vllm.v1.engine.output_processor import (OutputProcessor,
from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (GlobalStatLoggerFactory, from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
StatLoggerFactory, StatLoggerManager)
from vllm.v1.metrics.prometheus import shutdown_prometheus from vllm.v1.metrics.prometheus import shutdown_prometheus
from vllm.v1.metrics.stats import IterationStats from vllm.v1.metrics.stats import IterationStats
@ -63,7 +62,6 @@ class AsyncLLM(EngineClient):
log_requests: bool = True, log_requests: bool = True,
start_engine_loop: bool = True, start_engine_loop: bool = True,
stat_loggers: Optional[list[StatLoggerFactory]] = None, stat_loggers: Optional[list[StatLoggerFactory]] = None,
stat_logger_global: Optional[GlobalStatLoggerFactory] = None,
client_addresses: Optional[dict[str, str]] = None, client_addresses: Optional[dict[str, str]] = None,
client_count: int = 1, client_count: int = 1,
client_index: int = 0, client_index: int = 0,
@ -103,11 +101,8 @@ class AsyncLLM(EngineClient):
self.observability_config = vllm_config.observability_config self.observability_config = vllm_config.observability_config
self.log_requests = log_requests self.log_requests = log_requests
self.log_stats = log_stats or (stat_loggers self.log_stats = log_stats or (stat_loggers is not None)
is not None) or (stat_logger_global if not log_stats and stat_loggers is not None:
is not None)
if (not log_stats and stat_loggers is not None
or stat_logger_global is not None):
logger.info( logger.info(
"AsyncLLM created with log_stats=False and non-empty custom " "AsyncLLM created with log_stats=False and non-empty custom "
"logger list; enabling logging without default stat loggers") "logger list; enabling logging without default stat loggers")
@ -152,7 +147,6 @@ class AsyncLLM(EngineClient):
vllm_config=vllm_config, vllm_config=vllm_config,
engine_idxs=self.engine_core.engine_ranks_managed, engine_idxs=self.engine_core.engine_ranks_managed,
custom_stat_loggers=stat_loggers, custom_stat_loggers=stat_loggers,
custom_stat_logger_global=stat_logger_global,
enable_default_loggers=log_stats, enable_default_loggers=log_stats,
client_count=client_count, client_count=client_count,
) )
@ -195,7 +189,6 @@ class AsyncLLM(EngineClient):
start_engine_loop: bool = True, start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None, stat_loggers: Optional[list[StatLoggerFactory]] = None,
stat_logger_global: Optional[GlobalStatLoggerFactory] = None,
enable_log_requests: bool = False, enable_log_requests: bool = False,
disable_log_stats: bool = False, disable_log_stats: bool = False,
client_addresses: Optional[dict[str, str]] = None, client_addresses: Optional[dict[str, str]] = None,
@ -216,7 +209,6 @@ class AsyncLLM(EngineClient):
executor_class=Executor.get_class(vllm_config), executor_class=Executor.get_class(vllm_config),
start_engine_loop=start_engine_loop, start_engine_loop=start_engine_loop,
stat_loggers=stat_loggers, stat_loggers=stat_loggers,
stat_logger_global=stat_logger_global,
log_requests=enable_log_requests, log_requests=enable_log_requests,
log_stats=not disable_log_stats, log_stats=not disable_log_stats,
usage_context=usage_context, usage_context=usage_context,
@ -232,7 +224,6 @@ class AsyncLLM(EngineClient):
start_engine_loop: bool = True, start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None, stat_loggers: Optional[list[StatLoggerFactory]] = None,
stat_logger_global: Optional[GlobalStatLoggerFactory] = None,
) -> "AsyncLLM": ) -> "AsyncLLM":
"""Create an AsyncLLM from the EngineArgs.""" """Create an AsyncLLM from the EngineArgs."""
@ -249,7 +240,6 @@ class AsyncLLM(EngineClient):
start_engine_loop=start_engine_loop, start_engine_loop=start_engine_loop,
usage_context=usage_context, usage_context=usage_context,
stat_loggers=stat_loggers, stat_loggers=stat_loggers,
stat_logger_global=stat_logger_global,
) )
def __del__(self): def __del__(self):

View File

@ -18,9 +18,9 @@ from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
logger = init_logger(__name__) logger = init_logger(__name__)
StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"] PerEngineStatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
GlobalStatLoggerFactory = Callable[[VllmConfig, list[int]], StatLoggerFactory = Union[PerEngineStatLoggerFactory,
"GlobalStatLoggerBase"] type["AggregatedStatLoggerBase"]]
class StatLoggerBase(ABC): class StatLoggerBase(ABC):
@ -50,8 +50,9 @@ class StatLoggerBase(ABC):
pass pass
class GlobalStatLoggerBase(StatLoggerBase): class AggregatedStatLoggerBase(StatLoggerBase):
"""Interface for logging metrics for multiple engines.""" """Abstract base class for loggers that
aggregates statistics across multiple engines."""
@abstractmethod @abstractmethod
def __init__(self, vllm_config: VllmConfig, def __init__(self, vllm_config: VllmConfig,
@ -72,6 +73,7 @@ class LoggingStatLogger(StatLoggerBase):
self.spec_decoding_logging = SpecDecodingLogging() self.spec_decoding_logging = SpecDecodingLogging()
self.last_prompt_throughput: float = 0.0 self.last_prompt_throughput: float = 0.0
self.last_generation_throughput: float = 0.0 self.last_generation_throughput: float = 0.0
self.engine_is_idle = False
def _reset(self, now): def _reset(self, now):
self.last_log_time = now self.last_log_time = now
@ -111,25 +113,25 @@ class LoggingStatLogger(StatLoggerBase):
self.last_scheduler_stats = scheduler_stats self.last_scheduler_stats = scheduler_stats
def log(self): def get_log_stats(self):
now = time.monotonic() now = time.monotonic()
prompt_throughput = self._get_throughput(self.num_prompt_tokens, now) prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
generation_throughput = self._get_throughput( generation_throughput = self._get_throughput(
self.num_generation_tokens, now) self.num_generation_tokens, now)
self._reset(now) self._reset(now)
scheduler_stats = self.last_scheduler_stats
log_fn = logger.info
if not any(
(prompt_throughput, generation_throughput,
self.last_prompt_throughput, self.last_generation_throughput)):
# Avoid log noise on an idle production system
log_fn = logger.debug
self.last_generation_throughput = generation_throughput self.last_generation_throughput = generation_throughput
self.last_prompt_throughput = prompt_throughput self.last_prompt_throughput = prompt_throughput
self.engine_is_idle = not any(
(prompt_throughput, generation_throughput,
self.last_prompt_throughput, self.last_generation_throughput))
def log(self):
self.get_log_stats()
log_fn = logger.info
if self.engine_is_idle:
# Avoid log noise on an idle production system
log_fn = logger.debug
# Format and print output. # Format and print output.
log_fn( log_fn(
"Engine %03d: " "Engine %03d: "
@ -139,11 +141,11 @@ class LoggingStatLogger(StatLoggerBase):
"GPU KV cache usage: %.1f%%, " "GPU KV cache usage: %.1f%%, "
"Prefix cache hit rate: %.1f%%", "Prefix cache hit rate: %.1f%%",
self.engine_index, self.engine_index,
prompt_throughput, self.last_prompt_throughput,
generation_throughput, self.last_generation_throughput,
scheduler_stats.num_running_reqs, self.last_scheduler_stats.num_running_reqs,
scheduler_stats.num_waiting_reqs, self.last_scheduler_stats.num_waiting_reqs,
scheduler_stats.kv_cache_usage * 100, self.last_scheduler_stats.kv_cache_usage * 100,
self.prefix_caching_metrics.hit_rate * 100, self.prefix_caching_metrics.hit_rate * 100,
) )
self.spec_decoding_logging.log(log_fn=log_fn) self.spec_decoding_logging.log(log_fn=log_fn)
@ -156,37 +158,34 @@ class LoggingStatLogger(StatLoggerBase):
self.vllm_config.cache_config.num_gpu_blocks) self.vllm_config.cache_config.num_gpu_blocks)
class GlobalStatLogger(LoggingStatLogger, GlobalStatLoggerBase): class AggregatedStatLogger(LoggingStatLogger, AggregatedStatLoggerBase):
def __init__(self, def __init__(self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
engine_indexes: Optional[list[int]] = None): engine_idxs: Optional[list[int]] = None):
super().__init__(vllm_config, -1) if engine_idxs is None:
if engine_indexes is None: engine_idxs = [0]
engine_indexes = [0] self.engine_idxs = engine_idxs
self.engine_index = -1 LoggingStatLogger.__init__(self, vllm_config, engine_index=-1)
self.engine_indexes = engine_indexes
self.vllm_config = vllm_config def record(
self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
engine_idx: int = 0,
):
if engine_idx not in self.engine_idxs:
logger.warning("Unexpected engine_idx: %d", engine_idx)
return
LoggingStatLogger.record(self, scheduler_stats, iteration_stats,
engine_idx)
def log(self): def log(self):
now = time.monotonic() self.get_log_stats()
prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
generation_throughput = self._get_throughput(
self.num_generation_tokens, now)
self._reset(now)
scheduler_stats = self.last_scheduler_stats
log_fn = logger.info log_fn = logger.info
if not any( if self.engine_is_idle:
(prompt_throughput, generation_throughput,
self.last_prompt_throughput, self.last_generation_throughput)):
# Avoid log noise on an idle production system # Avoid log noise on an idle production system
log_fn = logger.debug log_fn = logger.debug
self.last_generation_throughput = generation_throughput
self.last_prompt_throughput = prompt_throughput
# Format and print output. # Format and print output.
log_fn( log_fn(
"%s Engines Aggregated: " "%s Engines Aggregated: "
@ -195,12 +194,12 @@ class GlobalStatLogger(LoggingStatLogger, GlobalStatLoggerBase):
"Running: %d reqs, Waiting: %d reqs, " "Running: %d reqs, Waiting: %d reqs, "
"GPU KV cache usage: %.1f%%, " "GPU KV cache usage: %.1f%%, "
"Prefix cache hit rate: %.1f%%", "Prefix cache hit rate: %.1f%%",
len(self.engine_indexes), len(self.engine_idxs),
prompt_throughput, self.last_prompt_throughput,
generation_throughput, self.last_generation_throughput,
scheduler_stats.num_running_reqs, self.last_scheduler_stats.num_running_reqs,
scheduler_stats.num_waiting_reqs, self.last_scheduler_stats.num_waiting_reqs,
scheduler_stats.kv_cache_usage * 100, self.last_scheduler_stats.kv_cache_usage * 100,
self.prefix_caching_metrics.hit_rate * 100, self.prefix_caching_metrics.hit_rate * 100,
) )
self.spec_decoding_logging.log(log_fn=log_fn) self.spec_decoding_logging.log(log_fn=log_fn)
@ -209,11 +208,11 @@ class GlobalStatLogger(LoggingStatLogger, GlobalStatLoggerBase):
if self.vllm_config.cache_config.num_gpu_blocks: if self.vllm_config.cache_config.num_gpu_blocks:
logger.info( logger.info(
"%d Engines: vllm cache_config_info with initialization " "%d Engines: vllm cache_config_info with initialization "
"after num_gpu_blocks is: %d", len(self.engine_indexes), "after num_gpu_blocks is: %d", len(self.engine_idxs),
self.vllm_config.cache_config.num_gpu_blocks) self.vllm_config.cache_config.num_gpu_blocks)
class PrometheusStatLogger(StatLoggerBase): class PrometheusStatLogger(AggregatedStatLoggerBase):
_gauge_cls = prometheus_client.Gauge _gauge_cls = prometheus_client.Gauge
_counter_cls = prometheus_client.Counter _counter_cls = prometheus_client.Counter
_histogram_cls = prometheus_client.Histogram _histogram_cls = prometheus_client.Histogram
@ -723,7 +722,6 @@ class StatLoggerManager:
vllm_config: VllmConfig, vllm_config: VllmConfig,
engine_idxs: Optional[list[int]] = None, engine_idxs: Optional[list[int]] = None,
custom_stat_loggers: Optional[list[StatLoggerFactory]] = None, custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
custom_stat_logger_global: Optional[GlobalStatLoggerFactory] = None,
enable_default_loggers: bool = True, enable_default_loggers: bool = True,
client_count: int = 1, client_count: int = 1,
): ):
@ -743,28 +741,32 @@ class StatLoggerManager:
# engine_idx: StatLogger # engine_idx: StatLogger
self.per_engine_logger_dict: dict[int, list[StatLoggerBase]] = {} self.per_engine_logger_dict: dict[int, list[StatLoggerBase]] = {}
prometheus_factory = PrometheusStatLogger self.aggregated_loggers: list[AggregatedStatLoggerBase] = []
aggregated_loggers_factories = set()
for engine_idx in self.engine_idxs: for engine_idx in self.engine_idxs:
loggers: list[StatLoggerBase] = [] loggers: list[StatLoggerBase] = []
for logger_factory in factories: for logger_factory in factories:
# If we get a custom prometheus logger, use that # If we get a custom prometheus logger or aggregated logger,
# instead. This is typically used for the ray case. # We initialize it separately with all engine idxs.
if (isinstance(logger_factory, type) # A custom prometheus logger is typically used for the ray.
and issubclass(logger_factory, PrometheusStatLogger)): if (isinstance(logger_factory, type) and issubclass(
prometheus_factory = logger_factory logger_factory, AggregatedStatLoggerBase)):
continue aggregated_loggers_factories.add(logger_factory)
loggers.append(logger_factory(vllm_config, else:
engine_idx)) # type: ignore loggers.append(logger_factory(vllm_config,
engine_idx)) # type: ignore
self.per_engine_logger_dict[engine_idx] = loggers self.per_engine_logger_dict[engine_idx] = loggers
# If no custom aggregated logger is provide,
# For Prometheus or custom global logger, # we by default use PrometheusStatLogger
if not aggregated_loggers_factories:
aggregated_loggers_factories.add(PrometheusStatLogger)
# For custom aggregated logger(or default Prometheus Logger)
# need to share the metrics between EngineCores. # need to share the metrics between EngineCores.
# Each EngineCore's metrics are expressed as a unique label. # Each EngineCore's metrics are expressed as a unique label.
self.prometheus_logger = prometheus_factory(vllm_config, engine_idxs) for aggregated_loggers_factory in aggregated_loggers_factories:
self.global_logger: Optional[StatLoggerBase] = None self.aggregated_loggers.append(
if custom_stat_logger_global is not None: aggregated_loggers_factory(vllm_config, engine_idxs))
self.global_logger = custom_stat_logger_global(
vllm_config, self.engine_idxs)
def record( def record(
self, self,
@ -778,24 +780,19 @@ class StatLoggerManager:
per_engine_loggers = self.per_engine_logger_dict[engine_idx] per_engine_loggers = self.per_engine_logger_dict[engine_idx]
for logger in per_engine_loggers: for logger in per_engine_loggers:
logger.record(scheduler_stats, iteration_stats, engine_idx) logger.record(scheduler_stats, iteration_stats, engine_idx)
for logger in self.aggregated_loggers:
self.prometheus_logger.record(scheduler_stats, iteration_stats, logger.record(scheduler_stats, iteration_stats, engine_idx)
engine_idx)
if self.global_logger is not None:
self.global_logger.record(scheduler_stats, iteration_stats,
engine_idx)
def log(self): def log(self):
for per_engine_loggers in self.per_engine_logger_dict.values(): for per_engine_loggers in self.per_engine_logger_dict.values():
for logger in per_engine_loggers: for logger in per_engine_loggers:
logger.log() logger.log()
if self.global_logger is not None: for logger in self.aggregated_loggers:
self.global_logger.log() logger.log()
def log_engine_initialized(self): def log_engine_initialized(self):
self.prometheus_logger.log_engine_initialized() for agg_logger in self.aggregated_loggers:
if self.global_logger is not None: agg_logger.log_engine_initialized()
self.global_logger.log_engine_initialized()
for per_engine_loggers in self.per_engine_logger_dict.values(): for per_engine_loggers in self.per_engine_logger_dict.values():
for logger in per_engine_loggers: for logger in per_engine_loggers:
logger.log_engine_initialized() logger.log_engine_initialized()