[Misc][DP] support customized aggregated logger for dp (#24354)

Signed-off-by: Lu Fang <fanglu@fb.com>
This commit is contained in:
Lucia Fang 2025-10-13 17:45:59 -07:00 committed by GitHub
parent d8bebb008a
commit 8317f72354
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 297 additions and 85 deletions

View File

@ -1,11 +1,13 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio import asyncio
import threading
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.loggers import AggregatedLoggingStatLogger
""" """
To run this example, run the following commands simultaneously with To run this example, run the following commands simultaneously with
@ -21,37 +23,64 @@ send a request to the instance with DP rank 1.
""" """
def _do_background_logging(engine, interval, stop_event):
    """Periodically trigger engine stat logging until *stop_event* is set.

    Intended to run in a dedicated (daemon) thread: each cycle runs the
    ``engine.do_log_stats()`` coroutine to completion via ``asyncio.run``,
    then sleeps up to ``interval`` seconds, waking early if the event is set.

    Args:
        engine: Engine client exposing an async ``do_log_stats()`` method.
        interval: Seconds to wait between logging cycles.
        stop_event: ``threading.Event`` used to request shutdown.
    """
    try:
        while not stop_event.is_set():
            asyncio.run(engine.do_log_stats())
            stop_event.wait(interval)
    except Exception as e:
        # Best-effort background loop: swallow errors during shutdown so
        # the thread exits quietly instead of crashing the process.
        # (Removed a redundant `pass` that followed this print.)
        print(f"vLLM background logging shutdown: {e}")
async def main(): async def main():
engine_args = AsyncEngineArgs( engine_args = AsyncEngineArgs(
model="ibm-research/PowerMoE-3b", model="ibm-research/PowerMoE-3b",
data_parallel_size=2, data_parallel_size=2,
tensor_parallel_size=1,
dtype="auto", dtype="auto",
max_model_len=2048, max_model_len=2048,
data_parallel_address="127.0.0.1", data_parallel_address="127.0.0.1",
data_parallel_rpc_port=62300, data_parallel_rpc_port=62300,
data_parallel_size_local=1, data_parallel_size_local=1,
enforce_eager=True, enforce_eager=True,
enable_log_requests=True,
disable_custom_all_reduce=True,
) )
engine_client = AsyncLLMEngine.from_engine_args(engine_args) engine_client = AsyncLLMEngine.from_engine_args(
engine_args,
# Example: Using aggregated logger
stat_loggers=[AggregatedLoggingStatLogger],
)
stop_logging_event = threading.Event()
logging_thread = threading.Thread(
target=_do_background_logging,
args=(engine_client, 5, stop_logging_event),
daemon=True,
)
logging_thread.start()
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0.7, temperature=0.7,
top_p=0.9, top_p=0.9,
max_tokens=100, max_tokens=100,
) )
num_prompts = 10
for i in range(num_prompts):
prompt = "Who won the 2004 World Series?"
final_output: RequestOutput | None = None
async for output in engine_client.generate(
prompt=prompt,
sampling_params=sampling_params,
request_id=f"abcdef-{i}",
data_parallel_rank=1,
):
final_output = output
if final_output:
print(final_output.outputs[0].text)
prompt = "Who won the 2004 World Series?" stop_logging_event.set()
final_output: RequestOutput | None = None logging_thread.join()
async for output in engine_client.generate(
prompt=prompt,
sampling_params=sampling_params,
request_id="abcdef",
data_parallel_rank=1,
):
final_output = output
if final_output:
print(final_output.outputs[0].text)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -17,7 +17,12 @@ from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind from vllm.sampling_params import RequestOutputKind
from vllm.utils import set_default_torch_num_threads from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import LoggingStatLogger from vllm.v1.metrics.loggers import (
AggregatedLoggingStatLogger,
LoggingStatLogger,
PerEngineStatLoggerAdapter,
PrometheusStatLogger,
)
if not current_platform.is_cuda(): if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True) pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)
@ -384,6 +389,12 @@ class MockLoggingStatLogger(LoggingStatLogger):
self.log = MagicMock() self.log = MagicMock()
class MockAggregatedStatLogger(AggregatedLoggingStatLogger):
    """Aggregated stat-logger test double: behaves like the real
    AggregatedLoggingStatLogger but records ``log()`` invocations."""

    def __init__(self, vllm_config: VllmConfig, engine_indexes: list[int]):
        AggregatedLoggingStatLogger.__init__(self, vllm_config, engine_indexes)
        # Swap log() for a mock so tests can assert it was called.
        self.log = MagicMock()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_customize_loggers(monkeypatch): async def test_customize_loggers(monkeypatch):
"""Test that we can customize the loggers. """Test that we can customize the loggers.
@ -401,10 +412,45 @@ async def test_customize_loggers(monkeypatch):
await engine.do_log_stats() await engine.do_log_stats()
stat_loggers = engine.logger_manager.per_engine_logger_dict stat_loggers = engine.logger_manager.stat_loggers
assert len(stat_loggers) == 1 assert (
assert len(stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger len(stat_loggers) == 3
stat_loggers[0][0].log.assert_called_once() ) # MockLoggingStatLogger + LoggingStatLogger + Prometheus Logger
print(f"{stat_loggers=}")
stat_loggers[0].per_engine_stat_loggers[0].log.assert_called_once()
assert isinstance(stat_loggers[1], PerEngineStatLoggerAdapter)
assert isinstance(stat_loggers[1].per_engine_stat_loggers[0], LoggingStatLogger)
assert isinstance(stat_loggers[2], PrometheusStatLogger)
@pytest.mark.asyncio
async def test_customize_aggregated_loggers(monkeypatch):
    """Test that we can customize the aggregated loggers.

    If a customized logger is provided at the init, it should
    be added to the default loggers.
    """
    with monkeypatch.context() as ctx, ExitStack() as cleanup:
        ctx.setenv("VLLM_USE_V1", "1")
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(
                TEXT_ENGINE_ARGS,
                stat_loggers=[MockLoggingStatLogger, MockAggregatedStatLogger],
            )
        cleanup.callback(engine.shutdown)

        await engine.do_log_stats()

        loggers = engine.logger_manager.stat_loggers
        # Expected order: MockLoggingStatLogger + MockAggregatedStatLogger
        # + default LoggingStatLogger + default PrometheusStatLogger
        assert len(loggers) == 4
        loggers[0].per_engine_stat_loggers[0].log.assert_called_once()
        loggers[1].log.assert_called_once()
        assert isinstance(loggers[2], PerEngineStatLoggerAdapter)
        assert isinstance(loggers[2].per_engine_stat_loggers[0], LoggingStatLogger)
        assert isinstance(loggers[3], PrometheusStatLogger)
@pytest.mark.asyncio(scope="module") @pytest.mark.asyncio(scope="module")

View File

@ -54,7 +54,7 @@ async def test_async_llm_replace_default_loggers(log_stats_enabled_engine_args):
engine = AsyncLLM.from_engine_args( engine = AsyncLLM.from_engine_args(
log_stats_enabled_engine_args, stat_loggers=[RayPrometheusStatLogger] log_stats_enabled_engine_args, stat_loggers=[RayPrometheusStatLogger]
) )
assert isinstance(engine.logger_manager.prometheus_logger, RayPrometheusStatLogger) assert isinstance(engine.logger_manager.stat_loggers[0], RayPrometheusStatLogger)
engine.shutdown() engine.shutdown()
@ -73,9 +73,11 @@ async def test_async_llm_add_to_default_loggers(log_stats_enabled_engine_args):
disabled_log_engine_args, stat_loggers=[DummyStatLogger] disabled_log_engine_args, stat_loggers=[DummyStatLogger]
) )
assert len(engine.logger_manager.per_engine_logger_dict[0]) == 1 assert len(engine.logger_manager.stat_loggers) == 2
assert len(engine.logger_manager.stat_loggers[0].per_engine_stat_loggers) == 1
assert isinstance( assert isinstance(
engine.logger_manager.per_engine_logger_dict[0][0], DummyStatLogger engine.logger_manager.stat_loggers[0].per_engine_stat_loggers[0],
DummyStatLogger,
) )
# log_stats is still True, since custom stat loggers are used # log_stats is still True, since custom stat loggers are used

View File

@ -410,6 +410,7 @@ class EngineArgs:
max_logprobs: int = ModelConfig.max_logprobs max_logprobs: int = ModelConfig.max_logprobs
logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
disable_log_stats: bool = False disable_log_stats: bool = False
aggregate_engine_logging: bool = False
revision: str | None = ModelConfig.revision revision: str | None = ModelConfig.revision
code_revision: str | None = ModelConfig.code_revision code_revision: str | None = ModelConfig.code_revision
rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling") rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
@ -1043,6 +1044,12 @@ class EngineArgs:
help="Disable logging statistics.", help="Disable logging statistics.",
) )
parser.add_argument(
"--aggregate-engine-logging",
action="store_true",
help="Log aggregate rather than per-engine statistics "
"when using data parallelism.",
)
return parser return parser
@classmethod @classmethod

View File

@ -239,6 +239,7 @@ async def build_async_engine_client_from_engine_args(
vllm_config=vllm_config, vllm_config=vllm_config,
usage_context=usage_context, usage_context=usage_context,
enable_log_requests=engine_args.enable_log_requests, enable_log_requests=engine_args.enable_log_requests,
aggregate_engine_logging=engine_args.aggregate_engine_logging,
disable_log_stats=engine_args.disable_log_stats, disable_log_stats=engine_args.disable_log_stats,
client_addresses=client_config, client_addresses=client_config,
client_count=client_count, client_count=client_count,

View File

@ -57,6 +57,7 @@ class AsyncLLM(EngineClient):
log_requests: bool = True, log_requests: bool = True,
start_engine_loop: bool = True, start_engine_loop: bool = True,
stat_loggers: list[StatLoggerFactory] | None = None, stat_loggers: list[StatLoggerFactory] | None = None,
aggregate_engine_logging: bool = False,
client_addresses: dict[str, str] | None = None, client_addresses: dict[str, str] | None = None,
client_count: int = 1, client_count: int = 1,
client_index: int = 0, client_index: int = 0,
@ -144,6 +145,7 @@ class AsyncLLM(EngineClient):
custom_stat_loggers=stat_loggers, custom_stat_loggers=stat_loggers,
enable_default_loggers=log_stats, enable_default_loggers=log_stats,
client_count=client_count, client_count=client_count,
aggregate_engine_logging=aggregate_engine_logging,
) )
self.logger_manager.log_engine_initialized() self.logger_manager.log_engine_initialized()
@ -187,6 +189,7 @@ class AsyncLLM(EngineClient):
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: list[StatLoggerFactory] | None = None, stat_loggers: list[StatLoggerFactory] | None = None,
enable_log_requests: bool = False, enable_log_requests: bool = False,
aggregate_engine_logging: bool = False,
disable_log_stats: bool = False, disable_log_stats: bool = False,
client_addresses: dict[str, str] | None = None, client_addresses: dict[str, str] | None = None,
client_count: int = 1, client_count: int = 1,
@ -209,6 +212,7 @@ class AsyncLLM(EngineClient):
stat_loggers=stat_loggers, stat_loggers=stat_loggers,
log_requests=enable_log_requests, log_requests=enable_log_requests,
log_stats=not disable_log_stats, log_stats=not disable_log_stats,
aggregate_engine_logging=aggregate_engine_logging,
usage_context=usage_context, usage_context=usage_context,
client_addresses=client_addresses, client_addresses=client_addresses,
client_count=client_count, client_count=client_count,

View File

@ -51,6 +51,7 @@ class LLMEngine:
vllm_config: VllmConfig, vllm_config: VllmConfig,
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
aggregate_engine_logging: bool = False,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: list[StatLoggerFactory] | None = None, stat_loggers: list[StatLoggerFactory] | None = None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
@ -132,6 +133,7 @@ class LLMEngine:
vllm_config=vllm_config, vllm_config=vllm_config,
custom_stat_loggers=stat_loggers, custom_stat_loggers=stat_loggers,
enable_default_loggers=log_stats, enable_default_loggers=log_stats,
aggregate_engine_logging=aggregate_engine_logging,
) )
self.logger_manager.log_engine_initialized() self.logger_manager.log_engine_initialized()

View File

@ -24,7 +24,9 @@ from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
logger = init_logger(__name__) logger = init_logger(__name__)
StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"] PerEngineStatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
AggregateStatLoggerFactory = type["AggregateStatLoggerBase"]
StatLoggerFactory = AggregateStatLoggerFactory | PerEngineStatLoggerFactory
class StatLoggerBase(ABC): class StatLoggerBase(ABC):
@ -54,6 +56,14 @@ class StatLoggerBase(ABC):
pass pass
class AggregateStatLoggerBase(StatLoggerBase):
    """Abstract base class for loggers that
    aggregate across multiple DP engines.

    Unlike per-engine loggers (built with a single engine index),
    subclasses are constructed with the full list of engine indexes
    and receive stats for all of them.
    """

    # The constructor is abstract so factory code can instantiate any
    # subclass uniformly as cls(vllm_config, engine_indexes).
    @abstractmethod
    def __init__(self, vllm_config: VllmConfig, engine_indexes: list[int]): ...
class LoggingStatLogger(StatLoggerBase): class LoggingStatLogger(StatLoggerBase):
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
self.engine_index = engine_index self.engine_index = engine_index
@ -72,6 +82,8 @@ class LoggingStatLogger(StatLoggerBase):
self.kv_connector_logging = KVConnectorLogging(kv_tranfer_config) self.kv_connector_logging = KVConnectorLogging(kv_tranfer_config)
self.last_prompt_throughput: float = 0.0 self.last_prompt_throughput: float = 0.0
self.last_generation_throughput: float = 0.0 self.last_generation_throughput: float = 0.0
self.engine_is_idle = False
self.aggregated = False
def _reset(self, now): def _reset(self, now):
self.last_log_time = now self.last_log_time = now
@ -92,6 +104,10 @@ class LoggingStatLogger(StatLoggerBase):
return 0.0 return 0.0
return float(tracked_stats / delta_time) return float(tracked_stats / delta_time)
@property
def log_prefix(self):
return "Engine {:03d}: ".format(self.engine_index)
def record( def record(
self, self,
scheduler_stats: SchedulerStats | None, scheduler_stats: SchedulerStats | None,
@ -110,34 +126,37 @@ class LoggingStatLogger(StatLoggerBase):
self.spec_decoding_logging.observe(scheduler_stats.spec_decoding_stats) self.spec_decoding_logging.observe(scheduler_stats.spec_decoding_stats)
if kv_connector_stats := scheduler_stats.kv_connector_stats: if kv_connector_stats := scheduler_stats.kv_connector_stats:
self.kv_connector_logging.observe(kv_connector_stats) self.kv_connector_logging.observe(kv_connector_stats)
self.last_scheduler_stats = scheduler_stats if not self.aggregated:
self.last_scheduler_stats = scheduler_stats
if mm_cache_stats: if mm_cache_stats:
self.mm_caching_metrics.observe(mm_cache_stats) self.mm_caching_metrics.observe(mm_cache_stats)
def log(self): def _update_stats(self):
now = time.monotonic() now = time.monotonic()
prompt_throughput = self._get_throughput(self.num_prompt_tokens, now) prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
generation_throughput = self._get_throughput(self.num_generation_tokens, now) generation_throughput = self._get_throughput(self.num_generation_tokens, now)
self._reset(now) self._reset(now)
self.engine_is_idle = not any(
scheduler_stats = self.last_scheduler_stats
log_fn = logger.info
if not any(
( (
prompt_throughput, prompt_throughput,
generation_throughput, generation_throughput,
self.last_prompt_throughput, self.last_prompt_throughput,
self.last_generation_throughput, self.last_generation_throughput,
) )
): )
# Avoid log noise on an idle production system
log_fn = logger.debug
self.last_generation_throughput = generation_throughput self.last_generation_throughput = generation_throughput
self.last_prompt_throughput = prompt_throughput self.last_prompt_throughput = prompt_throughput
def aggregate_scheduler_stats(self):
# noop for per engine loggers
return
def log(self):
self._update_stats()
self.aggregate_scheduler_stats()
# Avoid log noise on an idle production system
log_fn = logger.debug if self.engine_is_idle else logger.info
# Format and print output. # Format and print output.
log_parts = [ log_parts = [
"Avg prompt throughput: %.1f tokens/s", "Avg prompt throughput: %.1f tokens/s",
@ -148,11 +167,11 @@ class LoggingStatLogger(StatLoggerBase):
"Prefix cache hit rate: %.1f%%", "Prefix cache hit rate: %.1f%%",
] ]
log_args = [ log_args = [
prompt_throughput, self.last_prompt_throughput,
generation_throughput, self.last_generation_throughput,
scheduler_stats.num_running_reqs, self.last_scheduler_stats.num_running_reqs,
scheduler_stats.num_waiting_reqs, self.last_scheduler_stats.num_waiting_reqs,
scheduler_stats.kv_cache_usage * 100, self.last_scheduler_stats.kv_cache_usage * 100,
self.prefix_caching_metrics.hit_rate * 100, self.prefix_caching_metrics.hit_rate * 100,
] ]
if not self.mm_caching_metrics.empty: if not self.mm_caching_metrics.empty:
@ -160,8 +179,7 @@ class LoggingStatLogger(StatLoggerBase):
log_args.append(self.mm_caching_metrics.hit_rate * 100) log_args.append(self.mm_caching_metrics.hit_rate * 100)
log_fn( log_fn(
"Engine %03d: " + ", ".join(log_parts), self.log_prefix + ", ".join(log_parts),
self.engine_index,
*log_args, *log_args,
) )
@ -178,7 +196,114 @@ class LoggingStatLogger(StatLoggerBase):
) )
class PrometheusStatLogger(StatLoggerBase): class AggregatedLoggingStatLogger(LoggingStatLogger, AggregateStatLoggerBase):
def __init__(
self,
vllm_config: VllmConfig,
engine_indexes: list[int],
):
self.engine_indexes = engine_indexes
self.last_scheduler_stats_dict: dict[int, SchedulerStats] = {
idx: SchedulerStats() for idx in self.engine_indexes
}
LoggingStatLogger.__init__(self, vllm_config, engine_index=-1)
self.aggregated = True
@property
def log_prefix(self):
return "{} Engines Aggregated: ".format(len(self.engine_indexes))
def record(
self,
scheduler_stats: SchedulerStats | None,
iteration_stats: IterationStats | None,
mm_cache_stats: MultiModalCacheStats | None = None,
engine_idx: int = 0,
):
if engine_idx not in self.engine_indexes:
logger.warning("Unexpected engine_idx: %d", engine_idx)
return
LoggingStatLogger.record(
self,
scheduler_stats,
iteration_stats,
mm_cache_stats=mm_cache_stats,
engine_idx=engine_idx,
)
if scheduler_stats is not None:
self.last_scheduler_stats_dict[engine_idx] = scheduler_stats
def aggregate_scheduler_stats(self):
self.last_scheduler_stats = SchedulerStats()
for last_scheduler_stats in self.last_scheduler_stats_dict.values():
self.last_scheduler_stats.num_waiting_reqs += (
last_scheduler_stats.num_waiting_reqs
)
self.last_scheduler_stats.num_running_reqs += (
last_scheduler_stats.num_running_reqs
)
self.last_scheduler_stats.num_corrupted_reqs += (
last_scheduler_stats.num_corrupted_reqs
)
self.last_scheduler_stats.kv_cache_usage += (
last_scheduler_stats.kv_cache_usage
)
self.last_scheduler_stats.kv_cache_usage /= len(self.last_scheduler_stats_dict)
def log(self):
LoggingStatLogger.log(self)
def log_engine_initialized(self):
if self.vllm_config.cache_config.num_gpu_blocks:
logger.info(
"%d Engines: vllm cache_config_info with initialization "
"after num_gpu_blocks is: %d",
len(self.engine_indexes),
self.vllm_config.cache_config.num_gpu_blocks,
)
class PerEngineStatLoggerAdapter(AggregateStatLoggerBase):
    """Adapts a per-engine stat-logger factory to the aggregate-logger
    interface by holding one underlying logger per DP engine index and
    fanning calls out to them."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        engine_indexes: list[int],
        per_engine_stat_logger_factory: PerEngineStatLoggerFactory,
    ) -> None:
        self.engine_indexes = engine_indexes
        # Build one underlying per-engine logger for each engine index.
        self.per_engine_stat_loggers = {
            idx: per_engine_stat_logger_factory(vllm_config, idx)
            for idx in engine_indexes
        }

    def record(
        self,
        scheduler_stats: SchedulerStats | None,
        iteration_stats: IterationStats | None,
        mm_cache_stats: MultiModalCacheStats | None = None,
        engine_idx: int = 0,
    ):
        # Dispatch to the logger that owns this engine index, if any.
        delegate = self.per_engine_stat_loggers.get(engine_idx)
        if delegate is None:
            logger.warning("Unexpected engine_idx: %d", engine_idx)
            return
        delegate.record(
            scheduler_stats,
            iteration_stats,
            mm_cache_stats=mm_cache_stats,
            engine_idx=engine_idx,
        )

    def log(self):
        for delegate in self.per_engine_stat_loggers.values():
            delegate.log()

    def log_engine_initialized(self):
        for delegate in self.per_engine_stat_loggers.values():
            delegate.log_engine_initialized()
class PrometheusStatLogger(AggregateStatLoggerBase):
_gauge_cls = Gauge _gauge_cls = Gauge
_counter_cls = Counter _counter_cls = Counter
_histogram_cls = Histogram _histogram_cls = Histogram
@ -189,6 +314,7 @@ class PrometheusStatLogger(StatLoggerBase):
): ):
if engine_indexes is None: if engine_indexes is None:
engine_indexes = [0] engine_indexes = [0]
self.engine_indexes = engine_indexes self.engine_indexes = engine_indexes
unregister_vllm_metrics() unregister_vllm_metrics()
@ -880,14 +1006,14 @@ class StatLoggerManager:
engine_idxs: list[int] | None = None, engine_idxs: list[int] | None = None,
custom_stat_loggers: list[StatLoggerFactory] | None = None, custom_stat_loggers: list[StatLoggerFactory] | None = None,
enable_default_loggers: bool = True, enable_default_loggers: bool = True,
aggregate_engine_logging: bool = False,
client_count: int = 1, client_count: int = 1,
): ):
self.engine_idxs = engine_idxs if engine_idxs else [0] self.engine_indexes = engine_idxs if engine_idxs else [0]
self.stat_loggers: list[AggregateStatLoggerBase] = []
factories: list[StatLoggerFactory] = [] stat_logger_factories: list[StatLoggerFactory] = []
if custom_stat_loggers is not None: if custom_stat_loggers is not None:
factories.extend(custom_stat_loggers) stat_logger_factories.extend(custom_stat_loggers)
if enable_default_loggers and logger.isEnabledFor(logging.INFO): if enable_default_loggers and logger.isEnabledFor(logging.INFO):
if client_count > 1: if client_count > 1:
logger.warning( logger.warning(
@ -895,27 +1021,35 @@ class StatLoggerManager:
"disabling stats logging to avoid incomplete stats." "disabling stats logging to avoid incomplete stats."
) )
else: else:
factories.append(LoggingStatLogger) default_logger_factory = (
AggregatedLoggingStatLogger
# engine_idx: StatLogger if aggregate_engine_logging
self.per_engine_logger_dict: dict[int, list[StatLoggerBase]] = {} else LoggingStatLogger
prometheus_factory = PrometheusStatLogger )
for engine_idx in self.engine_idxs: stat_logger_factories.append(default_logger_factory)
loggers: list[StatLoggerBase] = [] custom_prometheus_logger: bool = False
for logger_factory in factories: for stat_logger_factory in stat_logger_factories:
# If we get a custom prometheus logger, use that if isinstance(stat_logger_factory, type) and issubclass(
# instead. This is typically used for the ray case. stat_logger_factory, AggregateStatLoggerBase
if isinstance(logger_factory, type) and issubclass( ):
logger_factory, PrometheusStatLogger global_stat_logger = stat_logger_factory(
): vllm_config=vllm_config,
prometheus_factory = logger_factory engine_indexes=self.engine_indexes,
continue )
loggers.append(logger_factory(vllm_config, engine_idx)) # type: ignore if isinstance(global_stat_logger, PrometheusStatLogger):
self.per_engine_logger_dict[engine_idx] = loggers custom_prometheus_logger = True
else:
# For Prometheus, need to share the metrics between EngineCores. # per engine logger
# Each EngineCore's metrics are expressed as a unique label. global_stat_logger = PerEngineStatLoggerAdapter(
self.prometheus_logger = prometheus_factory(vllm_config, engine_idxs) vllm_config=vllm_config,
engine_indexes=self.engine_indexes,
per_engine_stat_logger_factory=stat_logger_factory, # type: ignore[arg-type]
)
self.stat_loggers.append(global_stat_logger)
if not custom_prometheus_logger:
self.stat_loggers.append(
PrometheusStatLogger(vllm_config, self.engine_indexes)
)
def record( def record(
self, self,
@ -926,9 +1060,7 @@ class StatLoggerManager:
): ):
if engine_idx is None: if engine_idx is None:
engine_idx = 0 engine_idx = 0
for logger in self.stat_loggers:
per_engine_loggers = self.per_engine_logger_dict[engine_idx]
for logger in per_engine_loggers:
logger.record( logger.record(
scheduler_stats, scheduler_stats,
iteration_stats, iteration_stats,
@ -936,21 +1068,10 @@ class StatLoggerManager:
engine_idx=engine_idx, engine_idx=engine_idx,
) )
self.prometheus_logger.record(
scheduler_stats,
iteration_stats,
mm_cache_stats=mm_cache_stats,
engine_idx=engine_idx,
)
def log(self): def log(self):
for per_engine_loggers in self.per_engine_logger_dict.values(): for logger in self.stat_loggers:
for logger in per_engine_loggers: logger.log()
logger.log()
def log_engine_initialized(self): def log_engine_initialized(self):
self.prometheus_logger.log_engine_initialized() for agg_logger in self.stat_loggers:
agg_logger.log_engine_initialized()
for per_engine_loggers in self.per_engine_logger_dict.values():
for logger in per_engine_loggers:
logger.log_engine_initialized()