diff --git a/examples/online_serving/multi_instance_data_parallel.py b/examples/online_serving/multi_instance_data_parallel.py index cb230913a422f..4d7eb969a773c 100644 --- a/examples/online_serving/multi_instance_data_parallel.py +++ b/examples/online_serving/multi_instance_data_parallel.py @@ -1,12 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio +import threading from typing import Optional +from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams +from vllm.v1.metrics.loggers import GlobalStatLogger, LoggingStatLogger """ To run this example, run the following commands simultaneously with @@ -22,37 +25,72 @@ send a request to the instance with DP rank 1. """ +def _do_background_logging(engine, interval, stop_event): + try: + while not stop_event.is_set(): + asyncio.run(engine.do_log_stats()) + stop_event.wait(interval) + except Exception as e: + print(f"vLLM background logging shutdown: {e}") + pass + + async def main(): engine_args = AsyncEngineArgs( model="ibm-research/PowerMoE-3b", data_parallel_size=2, + tensor_parallel_size=1, dtype="auto", max_model_len=2048, data_parallel_address="127.0.0.1", data_parallel_rpc_port=62300, data_parallel_size_local=1, enforce_eager=True, + enable_log_requests=True, + disable_custom_all_reduce=True, ) - engine_client = AsyncLLMEngine.from_engine_args(engine_args) + def per_engine_logger_factory(config: VllmConfig, rank: int) -> LoggingStatLogger: + return LoggingStatLogger(config, rank) + def global_logger_factory( + config: VllmConfig, engine_idxs: Optional[list[int]] + ) -> GlobalStatLogger: + return GlobalStatLogger(config, engine_indexes=engine_idxs) + + engine_client = AsyncLLMEngine.from_engine_args( + engine_args, + stat_loggers=[per_engine_logger_factory], + 
stat_logger_global=global_logger_factory, + ) + stop_logging_event = threading.Event() + logging_thread = threading.Thread( + target=_do_background_logging, + args=(engine_client, 5, stop_logging_event), + daemon=True, + ) + logging_thread.start() sampling_params = SamplingParams( temperature=0.7, top_p=0.9, max_tokens=100, ) + num_prompts = 10 + for i in range(num_prompts): + prompt = "Who won the 2004 World Series?" + final_output: Optional[RequestOutput] = None + async for output in engine_client.generate( + prompt=prompt, + sampling_params=sampling_params, + request_id=f"abcdef-{i}", + data_parallel_rank=1, + ): + final_output = output + if final_output: + print(final_output.outputs[0].text) - prompt = "Who won the 2004 World Series?" - final_output: Optional[RequestOutput] = None - async for output in engine_client.generate( - prompt=prompt, - sampling_params=sampling_params, - request_id="abcdef", - data_parallel_rank=1, - ): - final_output = output - if final_output: - print(final_output.outputs[0].text) + stop_logging_event.set() + logging_thread.join() if __name__ == "__main__": diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index aca546600d0b5..868d0b8aa4567 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -18,7 +18,7 @@ from vllm.platforms import current_platform from vllm.sampling_params import RequestOutputKind from vllm.utils import set_default_torch_num_threads from vllm.v1.engine.async_llm import AsyncLLM -from vllm.v1.metrics.loggers import LoggingStatLogger +from vllm.v1.metrics.loggers import GlobalStatLogger, LoggingStatLogger if not current_platform.is_cuda(): pytest.skip(reason="V1 currently only supported on CUDA.", @@ -389,6 +389,15 @@ class MockLoggingStatLogger(LoggingStatLogger): self.log = MagicMock() +class MockGlobalStatLogger(GlobalStatLogger): + + def __init__(self, + vllm_config: VllmConfig, + engine_indexes: Optional[list[int]] = None): + 
super().__init__(vllm_config, engine_indexes) + self.log = MagicMock() + + @pytest.mark.asyncio async def test_customize_loggers(monkeypatch): """Test that we can customize the loggers. @@ -415,6 +424,36 @@ async def test_customize_loggers(monkeypatch): stat_loggers[0][0].log.assert_called_once() +@pytest.mark.asyncio +async def test_customize_global_loggers(monkeypatch): + """Test that we can customize the loggers. + If a customized logger is provided at the init, it should + be added to the default loggers. + """ + + with monkeypatch.context() as m, ExitStack() as after: + m.setenv("VLLM_USE_V1", "1") + + with set_default_torch_num_threads(1): + engine = AsyncLLM.from_engine_args( + TEXT_ENGINE_ARGS, + stat_loggers=[MockLoggingStatLogger], + stat_logger_global=MockGlobalStatLogger, + ) + after.callback(engine.shutdown) + + await engine.do_log_stats() + + stat_loggers = engine.logger_manager.per_engine_logger_dict + assert len(stat_loggers) == 1 + assert len( + stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger + global_logger = engine.logger_manager.global_logger + assert global_logger is not None + global_logger.log.assert_called_once() + stat_loggers[0][0].log.assert_called_once() + + @pytest.mark.asyncio(scope="module") async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch): with monkeypatch.context() as m, ExitStack() as after: diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 73165c7e4c0ad..5c8d49da3fe1e 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -42,7 +42,8 @@ from vllm.v1.engine.output_processor import (OutputProcessor, from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor -from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager +from vllm.v1.metrics.loggers import (GlobalStatLoggerFactory, + StatLoggerFactory, StatLoggerManager) from 
vllm.v1.metrics.prometheus import shutdown_prometheus from vllm.v1.metrics.stats import IterationStats @@ -62,6 +63,7 @@ class AsyncLLM(EngineClient): log_requests: bool = True, start_engine_loop: bool = True, stat_loggers: Optional[list[StatLoggerFactory]] = None, + stat_logger_global: Optional[GlobalStatLoggerFactory] = None, client_addresses: Optional[dict[str, str]] = None, client_count: int = 1, client_index: int = 0, @@ -101,8 +103,11 @@ class AsyncLLM(EngineClient): self.observability_config = vllm_config.observability_config self.log_requests = log_requests - self.log_stats = log_stats or (stat_loggers is not None) - if not log_stats and stat_loggers is not None: + self.log_stats = log_stats or (stat_loggers + is not None) or (stat_logger_global + is not None) + if not log_stats and (stat_loggers is not None + or stat_logger_global is not None): logger.info( "AsyncLLM created with log_stats=False and non-empty custom " "logger list; enabling logging without default stat loggers") @@ -147,6 +152,7 @@ class AsyncLLM(EngineClient): vllm_config=vllm_config, engine_idxs=self.engine_core.engine_ranks_managed, custom_stat_loggers=stat_loggers, + custom_stat_logger_global=stat_logger_global, enable_default_loggers=log_stats, client_count=client_count, ) @@ -189,6 +195,7 @@ class AsyncLLM(EngineClient): start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[list[StatLoggerFactory]] = None, + stat_logger_global: Optional[GlobalStatLoggerFactory] = None, enable_log_requests: bool = False, disable_log_stats: bool = False, client_addresses: Optional[dict[str, str]] = None, @@ -209,6 +216,7 @@ class AsyncLLM(EngineClient): executor_class=Executor.get_class(vllm_config), start_engine_loop=start_engine_loop, stat_loggers=stat_loggers, + stat_logger_global=stat_logger_global, log_requests=enable_log_requests, log_stats=not disable_log_stats, usage_context=usage_context, @@ -224,6 +232,7 @@ class 
AsyncLLM(EngineClient): start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[list[StatLoggerFactory]] = None, + stat_logger_global: Optional[GlobalStatLoggerFactory] = None, ) -> "AsyncLLM": """Create an AsyncLLM from the EngineArgs.""" @@ -240,6 +249,7 @@ class AsyncLLM(EngineClient): start_engine_loop=start_engine_loop, usage_context=usage_context, stat_loggers=stat_loggers, + stat_logger_global=stat_logger_global, ) def __del__(self): diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index b30036a6f8e80..565a0ac467df0 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -19,6 +19,8 @@ from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm logger = init_logger(__name__) StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"] +GlobalStatLoggerFactory = Callable[[VllmConfig, list[int]], + "GlobalStatLoggerBase"] class StatLoggerBase(ABC): @@ -48,6 +50,15 @@ class StatLoggerBase(ABC): pass +class GlobalStatLoggerBase(StatLoggerBase): + """Interface for logging metrics for multiple engines.""" + + @abstractmethod + def __init__(self, vllm_config: VllmConfig, + engine_indexes: Optional[list[int]]): + ... 
+ + class LoggingStatLogger(StatLoggerBase): def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): @@ -145,6 +156,63 @@ class LoggingStatLogger(StatLoggerBase): self.vllm_config.cache_config.num_gpu_blocks) +class GlobalStatLogger(LoggingStatLogger, GlobalStatLoggerBase): + + def __init__(self, + vllm_config: VllmConfig, + engine_indexes: Optional[list[int]] = None): + super().__init__(vllm_config, -1) + if engine_indexes is None: + engine_indexes = [0] + self.engine_index = -1 + self.engine_indexes = engine_indexes + self.vllm_config = vllm_config + + def log(self): + now = time.monotonic() + prompt_throughput = self._get_throughput(self.num_prompt_tokens, now) + generation_throughput = self._get_throughput( + self.num_generation_tokens, now) + + self._reset(now) + + scheduler_stats = self.last_scheduler_stats + + log_fn = logger.info + if not any( + (prompt_throughput, generation_throughput, + self.last_prompt_throughput, self.last_generation_throughput)): + # Avoid log noise on an idle production system + log_fn = logger.debug + self.last_generation_throughput = generation_throughput + self.last_prompt_throughput = prompt_throughput + + # Format and print output. 
+ log_fn( + "%s Engines Aggregated: " + "Avg prompt throughput: %.1f tokens/s, " + "Avg generation throughput: %.1f tokens/s, " + "Running: %d reqs, Waiting: %d reqs, " + "GPU KV cache usage: %.1f%%, " + "Prefix cache hit rate: %.1f%%", + len(self.engine_indexes), + prompt_throughput, + generation_throughput, + scheduler_stats.num_running_reqs, + scheduler_stats.num_waiting_reqs, + scheduler_stats.kv_cache_usage * 100, + self.prefix_caching_metrics.hit_rate * 100, + ) + self.spec_decoding_logging.log(log_fn=log_fn) + + def log_engine_initialized(self): + if self.vllm_config.cache_config.num_gpu_blocks: + logger.info( + "%d Engines: vllm cache_config_info with initialization " + "after num_gpu_blocks is: %d", len(self.engine_indexes), + self.vllm_config.cache_config.num_gpu_blocks) + + class PrometheusStatLogger(StatLoggerBase): _gauge_cls = prometheus_client.Gauge _counter_cls = prometheus_client.Counter @@ -655,6 +723,7 @@ class StatLoggerManager: vllm_config: VllmConfig, engine_idxs: Optional[list[int]] = None, custom_stat_loggers: Optional[list[StatLoggerFactory]] = None, + custom_stat_logger_global: Optional[GlobalStatLoggerFactory] = None, enable_default_loggers: bool = True, client_count: int = 1, ): @@ -688,9 +757,14 @@ class StatLoggerManager: engine_idx)) # type: ignore self.per_engine_logger_dict[engine_idx] = loggers - # For Prometheus, need to share the metrics between EngineCores. + # For Prometheus or custom global logger, + # need to share the metrics between EngineCores. # Each EngineCore's metrics are expressed as a unique label. 
self.prometheus_logger = prometheus_factory(vllm_config, engine_idxs) + self.global_logger: Optional[StatLoggerBase] = None + if custom_stat_logger_global is not None: + self.global_logger = custom_stat_logger_global( + vllm_config, self.engine_idxs) def record( self, @@ -707,15 +781,21 @@ class StatLoggerManager: self.prometheus_logger.record(scheduler_stats, iteration_stats, engine_idx) + if self.global_logger is not None: + self.global_logger.record(scheduler_stats, iteration_stats, + engine_idx) def log(self): for per_engine_loggers in self.per_engine_logger_dict.values(): for logger in per_engine_loggers: logger.log() + if self.global_logger is not None: + self.global_logger.log() def log_engine_initialized(self): self.prometheus_logger.log_engine_initialized() - + if self.global_logger is not None: + self.global_logger.log_engine_initialized() for per_engine_loggers in self.per_engine_logger_dict.values(): for logger in per_engine_loggers: logger.log_engine_initialized()