[Misc][DP] support customized global logger for dp

Signed-off-by: Lu Fang <fanglu@fb.com>

fix the test

Signed-off-by: Lu Fang <fanglu@fb.com>

address comments

Signed-off-by: Lu Fang <fanglu@fb.com>
This commit is contained in:
Lu Fang 2025-09-05 17:32:21 -07:00
parent 064cac7bb7
commit a46e279909
4 changed files with 185 additions and 18 deletions

View File

@ -1,12 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import threading
from typing import Optional
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.loggers import GlobalStatLogger, LoggingStatLogger
"""
To run this example, run the following commands simultaneously with
@ -22,37 +25,72 @@ send a request to the instance with DP rank 1.
"""
def _do_background_logging(engine, interval, stop_event):
try:
while not stop_event.is_set():
asyncio.run(engine.do_log_stats())
stop_event.wait(interval)
except Exception as e:
print(f"vLLM background logging shutdown: {e}")
pass
async def main():
engine_args = AsyncEngineArgs(
model="ibm-research/PowerMoE-3b",
data_parallel_size=2,
tensor_parallel_size=1,
dtype="auto",
max_model_len=2048,
data_parallel_address="127.0.0.1",
data_parallel_rpc_port=62300,
data_parallel_size_local=1,
enforce_eager=True,
enable_log_requests=True,
disable_custom_all_reduce=True,
)
engine_client = AsyncLLMEngine.from_engine_args(engine_args)
def per_engine_logger_factory(config: VllmConfig, rank: int) -> LoggingStatLogger:
return LoggingStatLogger(config, rank)
def global_logger_factory(
config: VllmConfig, engine_idxs: Optional[list[int]]
) -> GlobalStatLogger:
return GlobalStatLogger(config, engine_indexes=engine_idxs)
engine_client = AsyncLLMEngine.from_engine_args(
engine_args,
stat_loggers=[per_engine_logger_factory],
stat_logger_global=global_logger_factory,
)
stop_logging_event = threading.Event()
logging_thread = threading.Thread(
target=_do_background_logging,
args=(engine_client, 5, stop_logging_event),
daemon=True,
)
logging_thread.start()
sampling_params = SamplingParams(
temperature=0.7,
top_p=0.9,
max_tokens=100,
)
num_prompts = 10
for i in range(num_prompts):
prompt = "Who won the 2004 World Series?"
final_output: Optional[RequestOutput] = None
async for output in engine_client.generate(
prompt=prompt,
sampling_params=sampling_params,
request_id=f"abcdef-{i}",
data_parallel_rank=1,
):
final_output = output
if final_output:
print(final_output.outputs[0].text)
prompt = "Who won the 2004 World Series?"
final_output: Optional[RequestOutput] = None
async for output in engine_client.generate(
prompt=prompt,
sampling_params=sampling_params,
request_id="abcdef",
data_parallel_rank=1,
):
final_output = output
if final_output:
print(final_output.outputs[0].text)
stop_logging_event.set()
logging_thread.join()
if __name__ == "__main__":

View File

@ -18,7 +18,7 @@ from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import LoggingStatLogger
from vllm.v1.metrics.loggers import GlobalStatLogger, LoggingStatLogger
if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.",
@ -389,6 +389,15 @@ class MockLoggingStatLogger(LoggingStatLogger):
self.log = MagicMock()
class MockGlobalStatLogger(GlobalStatLogger):
def __init__(self,
vllm_config: VllmConfig,
engine_indexes: Optional[list[int]] = None):
super().__init__(vllm_config, engine_indexes)
self.log = MagicMock()
@pytest.mark.asyncio
async def test_customize_loggers(monkeypatch):
"""Test that we can customize the loggers.
@ -415,6 +424,36 @@ async def test_customize_loggers(monkeypatch):
stat_loggers[0][0].log.assert_called_once()
@pytest.mark.asyncio
async def test_customize_global_loggers(monkeypatch):
"""Test that we can customize the loggers.
If a customized logger is provided at the init, it should
be added to the default loggers.
"""
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(
TEXT_ENGINE_ARGS,
stat_loggers=[MockLoggingStatLogger],
stat_logger_global=MockGlobalStatLogger,
)
after.callback(engine.shutdown)
await engine.do_log_stats()
stat_loggers = engine.logger_manager.per_engine_logger_dict
assert len(stat_loggers) == 1
assert len(
stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger
global_logger = engine.logger_manager.global_logger
assert global_logger is not None
global_logger.log.assert_called_once()
stat_loggers[0][0].log.assert_called_once()
@pytest.mark.asyncio(scope="module")
async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m, ExitStack() as after:

View File

@ -42,7 +42,8 @@ from vllm.v1.engine.output_processor import (OutputProcessor,
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
from vllm.v1.metrics.loggers import (GlobalStatLoggerFactory,
StatLoggerFactory, StatLoggerManager)
from vllm.v1.metrics.prometheus import shutdown_prometheus
from vllm.v1.metrics.stats import IterationStats
@ -62,6 +63,7 @@ class AsyncLLM(EngineClient):
log_requests: bool = True,
start_engine_loop: bool = True,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
stat_logger_global: Optional[GlobalStatLoggerFactory] = None,
client_addresses: Optional[dict[str, str]] = None,
client_count: int = 1,
client_index: int = 0,
@ -101,8 +103,11 @@ class AsyncLLM(EngineClient):
self.observability_config = vllm_config.observability_config
self.log_requests = log_requests
self.log_stats = log_stats or (stat_loggers is not None)
if not log_stats and stat_loggers is not None:
self.log_stats = log_stats or (stat_loggers
is not None) or (stat_logger_global
is not None)
if (not log_stats and stat_loggers is not None
or stat_logger_global is not None):
logger.info(
"AsyncLLM created with log_stats=False and non-empty custom "
"logger list; enabling logging without default stat loggers")
@ -147,6 +152,7 @@ class AsyncLLM(EngineClient):
vllm_config=vllm_config,
engine_idxs=self.engine_core.engine_ranks_managed,
custom_stat_loggers=stat_loggers,
custom_stat_logger_global=stat_logger_global,
enable_default_loggers=log_stats,
client_count=client_count,
)
@ -189,6 +195,7 @@ class AsyncLLM(EngineClient):
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
stat_logger_global: Optional[GlobalStatLoggerFactory] = None,
enable_log_requests: bool = False,
disable_log_stats: bool = False,
client_addresses: Optional[dict[str, str]] = None,
@ -209,6 +216,7 @@ class AsyncLLM(EngineClient):
executor_class=Executor.get_class(vllm_config),
start_engine_loop=start_engine_loop,
stat_loggers=stat_loggers,
stat_logger_global=stat_logger_global,
log_requests=enable_log_requests,
log_stats=not disable_log_stats,
usage_context=usage_context,
@ -224,6 +232,7 @@ class AsyncLLM(EngineClient):
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
stat_logger_global: Optional[GlobalStatLoggerFactory] = None,
) -> "AsyncLLM":
"""Create an AsyncLLM from the EngineArgs."""
@ -240,6 +249,7 @@ class AsyncLLM(EngineClient):
start_engine_loop=start_engine_loop,
usage_context=usage_context,
stat_loggers=stat_loggers,
stat_logger_global=stat_logger_global,
)
def __del__(self):

View File

@ -19,6 +19,8 @@ from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
logger = init_logger(__name__)
StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
GlobalStatLoggerFactory = Callable[[VllmConfig, list[int]],
"GlobalStatLoggerBase"]
class StatLoggerBase(ABC):
@ -48,6 +50,15 @@ class StatLoggerBase(ABC):
pass
class GlobalStatLoggerBase(StatLoggerBase):
"""Interface for logging metrics for multiple engines."""
@abstractmethod
def __init__(self, vllm_config: VllmConfig,
engine_indexes: Optional[list[int]]):
...
class LoggingStatLogger(StatLoggerBase):
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
@ -145,6 +156,63 @@ class LoggingStatLogger(StatLoggerBase):
self.vllm_config.cache_config.num_gpu_blocks)
class GlobalStatLogger(LoggingStatLogger, GlobalStatLoggerBase):
def __init__(self,
vllm_config: VllmConfig,
engine_indexes: Optional[list[int]] = None):
super().__init__(vllm_config, -1)
if engine_indexes is None:
engine_indexes = [0]
self.engine_index = -1
self.engine_indexes = engine_indexes
self.vllm_config = vllm_config
def log(self):
now = time.monotonic()
prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
generation_throughput = self._get_throughput(
self.num_generation_tokens, now)
self._reset(now)
scheduler_stats = self.last_scheduler_stats
log_fn = logger.info
if not any(
(prompt_throughput, generation_throughput,
self.last_prompt_throughput, self.last_generation_throughput)):
# Avoid log noise on an idle production system
log_fn = logger.debug
self.last_generation_throughput = generation_throughput
self.last_prompt_throughput = prompt_throughput
# Format and print output.
log_fn(
"%s Engines Aggregated: "
"Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Waiting: %d reqs, "
"GPU KV cache usage: %.1f%%, "
"Prefix cache hit rate: %.1f%%",
len(self.engine_indexes),
prompt_throughput,
generation_throughput,
scheduler_stats.num_running_reqs,
scheduler_stats.num_waiting_reqs,
scheduler_stats.kv_cache_usage * 100,
self.prefix_caching_metrics.hit_rate * 100,
)
self.spec_decoding_logging.log(log_fn=log_fn)
def log_engine_initialized(self):
if self.vllm_config.cache_config.num_gpu_blocks:
logger.info(
"%d Engines: vllm cache_config_info with initialization "
"after num_gpu_blocks is: %d", len(self.engine_indexes),
self.vllm_config.cache_config.num_gpu_blocks)
class PrometheusStatLogger(StatLoggerBase):
_gauge_cls = prometheus_client.Gauge
_counter_cls = prometheus_client.Counter
@ -655,6 +723,7 @@ class StatLoggerManager:
vllm_config: VllmConfig,
engine_idxs: Optional[list[int]] = None,
custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
custom_stat_logger_global: Optional[GlobalStatLoggerFactory] = None,
enable_default_loggers: bool = True,
client_count: int = 1,
):
@ -688,9 +757,14 @@ class StatLoggerManager:
engine_idx)) # type: ignore
self.per_engine_logger_dict[engine_idx] = loggers
# For Prometheus, need to share the metrics between EngineCores.
# For Prometheus or custom global logger,
# need to share the metrics between EngineCores.
# Each EngineCore's metrics are expressed as a unique label.
self.prometheus_logger = prometheus_factory(vllm_config, engine_idxs)
self.global_logger: Optional[StatLoggerBase] = None
if custom_stat_logger_global is not None:
self.global_logger = custom_stat_logger_global(
vllm_config, self.engine_idxs)
def record(
self,
@ -707,15 +781,21 @@ class StatLoggerManager:
self.prometheus_logger.record(scheduler_stats, iteration_stats,
engine_idx)
if self.global_logger is not None:
self.global_logger.record(scheduler_stats, iteration_stats,
engine_idx)
def log(self):
for per_engine_loggers in self.per_engine_logger_dict.values():
for logger in per_engine_loggers:
logger.log()
if self.global_logger is not None:
self.global_logger.log()
def log_engine_initialized(self):
self.prometheus_logger.log_engine_initialized()
if self.global_logger is not None:
self.global_logger.log_engine_initialized()
for per_engine_loggers in self.per_engine_logger_dict.values():
for logger in per_engine_loggers:
logger.log_engine_initialized()