diff --git a/examples/online_serving/multi_instance_data_parallel.py b/examples/online_serving/multi_instance_data_parallel.py index b46cea5619671..04d21e0489402 100644 --- a/examples/online_serving/multi_instance_data_parallel.py +++ b/examples/online_serving/multi_instance_data_parallel.py @@ -1,11 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio +import threading from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams +from vllm.v1.metrics.loggers import AggregatedLoggingStatLogger """ To run this example, run the following commands simultaneously with @@ -21,37 +23,64 @@ send a request to the instance with DP rank 1. """ +def _do_background_logging(engine, interval, stop_event): + try: + while not stop_event.is_set(): + asyncio.run(engine.do_log_stats()) + stop_event.wait(interval) + except Exception as e: + print(f"vLLM background logging shutdown: {e}") + pass + + async def main(): engine_args = AsyncEngineArgs( model="ibm-research/PowerMoE-3b", data_parallel_size=2, + tensor_parallel_size=1, dtype="auto", max_model_len=2048, data_parallel_address="127.0.0.1", data_parallel_rpc_port=62300, data_parallel_size_local=1, enforce_eager=True, + enable_log_requests=True, + disable_custom_all_reduce=True, ) - engine_client = AsyncLLMEngine.from_engine_args(engine_args) - + engine_client = AsyncLLMEngine.from_engine_args( + engine_args, + # Example: Using aggregated logger + stat_loggers=[AggregatedLoggingStatLogger], + ) + stop_logging_event = threading.Event() + logging_thread = threading.Thread( + target=_do_background_logging, + args=(engine_client, 5, stop_logging_event), + daemon=True, + ) + logging_thread.start() sampling_params = SamplingParams( temperature=0.7, top_p=0.9, max_tokens=100, ) + num_prompts = 10 + for i in range(num_prompts): + 
prompt = "Who won the 2004 World Series?" + final_output: RequestOutput | None = None + async for output in engine_client.generate( + prompt=prompt, + sampling_params=sampling_params, + request_id=f"abcdef-{i}", + data_parallel_rank=1, + ): + final_output = output + if final_output: + print(final_output.outputs[0].text) - prompt = "Who won the 2004 World Series?" - final_output: RequestOutput | None = None - async for output in engine_client.generate( - prompt=prompt, - sampling_params=sampling_params, - request_id="abcdef", - data_parallel_rank=1, - ): - final_output = output - if final_output: - print(final_output.outputs[0].text) + stop_logging_event.set() + logging_thread.join() if __name__ == "__main__": diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index 8f715c085b5d1..b9fa553142781 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -17,7 +17,12 @@ from vllm.platforms import current_platform from vllm.sampling_params import RequestOutputKind from vllm.utils import set_default_torch_num_threads from vllm.v1.engine.async_llm import AsyncLLM -from vllm.v1.metrics.loggers import LoggingStatLogger +from vllm.v1.metrics.loggers import ( + AggregatedLoggingStatLogger, + LoggingStatLogger, + PerEngineStatLoggerAdapter, + PrometheusStatLogger, +) if not current_platform.is_cuda(): pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True) @@ -384,6 +389,12 @@ class MockLoggingStatLogger(LoggingStatLogger): self.log = MagicMock() +class MockAggregatedStatLogger(AggregatedLoggingStatLogger): + def __init__(self, vllm_config: VllmConfig, engine_indexes: list[int]): + super().__init__(vllm_config, engine_indexes) + self.log = MagicMock() + + @pytest.mark.asyncio async def test_customize_loggers(monkeypatch): """Test that we can customize the loggers. 
@@ -401,10 +412,45 @@ async def test_customize_loggers(monkeypatch): await engine.do_log_stats() - stat_loggers = engine.logger_manager.per_engine_logger_dict - assert len(stat_loggers) == 1 - assert len(stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger - stat_loggers[0][0].log.assert_called_once() + stat_loggers = engine.logger_manager.stat_loggers + assert ( + len(stat_loggers) == 3 + ) # MockLoggingStatLogger + LoggingStatLogger + Prometheus Logger + print(f"{stat_loggers=}") + stat_loggers[0].per_engine_stat_loggers[0].log.assert_called_once() + assert isinstance(stat_loggers[1], PerEngineStatLoggerAdapter) + assert isinstance(stat_loggers[1].per_engine_stat_loggers[0], LoggingStatLogger) + assert isinstance(stat_loggers[2], PrometheusStatLogger) + + +@pytest.mark.asyncio +async def test_customize_aggregated_loggers(monkeypatch): + """Test that we can customize the aggregated loggers. + If a customized logger is provided at the init, it should + be added to the default loggers. 
+ """ + + with monkeypatch.context() as m, ExitStack() as after: + m.setenv("VLLM_USE_V1", "1") + + with set_default_torch_num_threads(1): + engine = AsyncLLM.from_engine_args( + TEXT_ENGINE_ARGS, + stat_loggers=[MockLoggingStatLogger, MockAggregatedStatLogger], + ) + after.callback(engine.shutdown) + + await engine.do_log_stats() + + stat_loggers = engine.logger_manager.stat_loggers + assert len(stat_loggers) == 4 + # MockLoggingStatLogger + MockAggregatedStatLogger + # + LoggingStatLogger + PrometheusStatLogger + stat_loggers[0].per_engine_stat_loggers[0].log.assert_called_once() + stat_loggers[1].log.assert_called_once() + assert isinstance(stat_loggers[2], PerEngineStatLoggerAdapter) + assert isinstance(stat_loggers[2].per_engine_stat_loggers[0], LoggingStatLogger) + assert isinstance(stat_loggers[3], PrometheusStatLogger) @pytest.mark.asyncio(scope="module") diff --git a/tests/v1/metrics/test_engine_logger_apis.py b/tests/v1/metrics/test_engine_logger_apis.py index bf780b1f36adf..6dd5b2b069c09 100644 --- a/tests/v1/metrics/test_engine_logger_apis.py +++ b/tests/v1/metrics/test_engine_logger_apis.py @@ -54,7 +54,7 @@ async def test_async_llm_replace_default_loggers(log_stats_enabled_engine_args): engine = AsyncLLM.from_engine_args( log_stats_enabled_engine_args, stat_loggers=[RayPrometheusStatLogger] ) - assert isinstance(engine.logger_manager.prometheus_logger, RayPrometheusStatLogger) + assert isinstance(engine.logger_manager.stat_loggers[0], RayPrometheusStatLogger) engine.shutdown() @@ -73,9 +73,11 @@ async def test_async_llm_add_to_default_loggers(log_stats_enabled_engine_args): disabled_log_engine_args, stat_loggers=[DummyStatLogger] ) - assert len(engine.logger_manager.per_engine_logger_dict[0]) == 1 + assert len(engine.logger_manager.stat_loggers) == 2 + assert len(engine.logger_manager.stat_loggers[0].per_engine_stat_loggers) == 1 assert isinstance( - engine.logger_manager.per_engine_logger_dict[0][0], DummyStatLogger + 
engine.logger_manager.stat_loggers[0].per_engine_stat_loggers[0], + DummyStatLogger, ) # log_stats is still True, since custom stat loggers are used diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 54a0539f40479..f0eb3c2213384 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -410,6 +410,7 @@ class EngineArgs: max_logprobs: int = ModelConfig.max_logprobs logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode disable_log_stats: bool = False + aggregate_engine_logging: bool = False revision: str | None = ModelConfig.revision code_revision: str | None = ModelConfig.code_revision rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling") @@ -1043,6 +1044,12 @@ class EngineArgs: help="Disable logging statistics.", ) + parser.add_argument( + "--aggregate-engine-logging", + action="store_true", + help="Log aggregate rather than per-engine statistics " + "when using data parallelism.", + ) return parser @classmethod diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 96a0947c4bd31..ec5632523fe3c 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -239,6 +239,7 @@ async def build_async_engine_client_from_engine_args( vllm_config=vllm_config, usage_context=usage_context, enable_log_requests=engine_args.enable_log_requests, + aggregate_engine_logging=engine_args.aggregate_engine_logging, disable_log_stats=engine_args.disable_log_stats, client_addresses=client_config, client_count=client_count, diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index fbbe15b7b04f2..39cd1d97c280a 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -57,6 +57,7 @@ class AsyncLLM(EngineClient): log_requests: bool = True, start_engine_loop: bool = True, stat_loggers: list[StatLoggerFactory] | None = None, + aggregate_engine_logging: bool = False, client_addresses: dict[str, str] | None = None, client_count: 
int = 1, client_index: int = 0, @@ -144,6 +145,7 @@ class AsyncLLM(EngineClient): custom_stat_loggers=stat_loggers, enable_default_loggers=log_stats, client_count=client_count, + aggregate_engine_logging=aggregate_engine_logging, ) self.logger_manager.log_engine_initialized() @@ -187,6 +189,7 @@ class AsyncLLM(EngineClient): usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: list[StatLoggerFactory] | None = None, enable_log_requests: bool = False, + aggregate_engine_logging: bool = False, disable_log_stats: bool = False, client_addresses: dict[str, str] | None = None, client_count: int = 1, @@ -209,6 +212,7 @@ class AsyncLLM(EngineClient): stat_loggers=stat_loggers, log_requests=enable_log_requests, log_stats=not disable_log_stats, + aggregate_engine_logging=aggregate_engine_logging, usage_context=usage_context, client_addresses=client_addresses, client_count=client_count, diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index debf8a2192548..538fb6a04bd7b 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -51,6 +51,7 @@ class LLMEngine: vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool, + aggregate_engine_logging: bool = False, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: list[StatLoggerFactory] | None = None, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, @@ -132,6 +133,7 @@ class LLMEngine: vllm_config=vllm_config, custom_stat_loggers=stat_loggers, enable_default_loggers=log_stats, + aggregate_engine_logging=aggregate_engine_logging, ) self.logger_manager.log_engine_initialized() diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 8c5abae2ae652..1a8fefdd1ddf8 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -24,7 +24,9 @@ from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm logger = init_logger(__name__) -StatLoggerFactory = Callable[[VllmConfig, int], 
"StatLoggerBase"] +PerEngineStatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"] +AggregateStatLoggerFactory = type["AggregateStatLoggerBase"] +StatLoggerFactory = AggregateStatLoggerFactory | PerEngineStatLoggerFactory class StatLoggerBase(ABC): @@ -54,6 +56,14 @@ class StatLoggerBase(ABC): pass +class AggregateStatLoggerBase(StatLoggerBase): + """Abstract base class for loggers that + aggregate across multiple DP engines.""" + + @abstractmethod + def __init__(self, vllm_config: VllmConfig, engine_indexes: list[int]): ... + + class LoggingStatLogger(StatLoggerBase): def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self.engine_index = engine_index @@ -72,6 +82,8 @@ class LoggingStatLogger(StatLoggerBase): self.kv_connector_logging = KVConnectorLogging(kv_tranfer_config) self.last_prompt_throughput: float = 0.0 self.last_generation_throughput: float = 0.0 + self.engine_is_idle = False + self.aggregated = False def _reset(self, now): self.last_log_time = now @@ -92,6 +104,10 @@ class LoggingStatLogger(StatLoggerBase): return 0.0 return float(tracked_stats / delta_time) + @property + def log_prefix(self): + return "Engine {:03d}: ".format(self.engine_index) + def record( self, scheduler_stats: SchedulerStats | None, @@ -110,34 +126,37 @@ class LoggingStatLogger(StatLoggerBase): self.spec_decoding_logging.observe(scheduler_stats.spec_decoding_stats) if kv_connector_stats := scheduler_stats.kv_connector_stats: self.kv_connector_logging.observe(kv_connector_stats) - self.last_scheduler_stats = scheduler_stats - + if not self.aggregated: + self.last_scheduler_stats = scheduler_stats if mm_cache_stats: self.mm_caching_metrics.observe(mm_cache_stats) - def log(self): + def _update_stats(self): now = time.monotonic() prompt_throughput = self._get_throughput(self.num_prompt_tokens, now) generation_throughput = self._get_throughput(self.num_generation_tokens, now) self._reset(now) - - scheduler_stats = self.last_scheduler_stats - - log_fn = 
logger.info - if not any( + self.engine_is_idle = not any( ( prompt_throughput, generation_throughput, self.last_prompt_throughput, self.last_generation_throughput, ) - ): - # Avoid log noise on an idle production system - log_fn = logger.debug + ) self.last_generation_throughput = generation_throughput self.last_prompt_throughput = prompt_throughput + def aggregate_scheduler_stats(self): + # noop for per engine loggers + return + + def log(self): + self._update_stats() + self.aggregate_scheduler_stats() + # Avoid log noise on an idle production system + log_fn = logger.debug if self.engine_is_idle else logger.info # Format and print output. log_parts = [ "Avg prompt throughput: %.1f tokens/s", @@ -148,11 +167,11 @@ class LoggingStatLogger(StatLoggerBase): "Prefix cache hit rate: %.1f%%", ] log_args = [ - prompt_throughput, - generation_throughput, - scheduler_stats.num_running_reqs, - scheduler_stats.num_waiting_reqs, - scheduler_stats.kv_cache_usage * 100, + self.last_prompt_throughput, + self.last_generation_throughput, + self.last_scheduler_stats.num_running_reqs, + self.last_scheduler_stats.num_waiting_reqs, + self.last_scheduler_stats.kv_cache_usage * 100, self.prefix_caching_metrics.hit_rate * 100, ] if not self.mm_caching_metrics.empty: @@ -160,8 +179,7 @@ class LoggingStatLogger(StatLoggerBase): log_args.append(self.mm_caching_metrics.hit_rate * 100) log_fn( - "Engine %03d: " + ", ".join(log_parts), - self.engine_index, + self.log_prefix + ", ".join(log_parts), *log_args, ) @@ -178,7 +196,114 @@ class LoggingStatLogger(StatLoggerBase): ) -class PrometheusStatLogger(StatLoggerBase): +class AggregatedLoggingStatLogger(LoggingStatLogger, AggregateStatLoggerBase): + def __init__( + self, + vllm_config: VllmConfig, + engine_indexes: list[int], + ): + self.engine_indexes = engine_indexes + self.last_scheduler_stats_dict: dict[int, SchedulerStats] = { + idx: SchedulerStats() for idx in self.engine_indexes + } + LoggingStatLogger.__init__(self, vllm_config, 
engine_index=-1) + self.aggregated = True + + @property + def log_prefix(self): + return "{} Engines Aggregated: ".format(len(self.engine_indexes)) + + def record( + self, + scheduler_stats: SchedulerStats | None, + iteration_stats: IterationStats | None, + mm_cache_stats: MultiModalCacheStats | None = None, + engine_idx: int = 0, + ): + if engine_idx not in self.engine_indexes: + logger.warning("Unexpected engine_idx: %d", engine_idx) + return + LoggingStatLogger.record( + self, + scheduler_stats, + iteration_stats, + mm_cache_stats=mm_cache_stats, + engine_idx=engine_idx, + ) + if scheduler_stats is not None: + self.last_scheduler_stats_dict[engine_idx] = scheduler_stats + + def aggregate_scheduler_stats(self): + self.last_scheduler_stats = SchedulerStats() + for last_scheduler_stats in self.last_scheduler_stats_dict.values(): + self.last_scheduler_stats.num_waiting_reqs += ( + last_scheduler_stats.num_waiting_reqs + ) + self.last_scheduler_stats.num_running_reqs += ( + last_scheduler_stats.num_running_reqs + ) + self.last_scheduler_stats.num_corrupted_reqs += ( + last_scheduler_stats.num_corrupted_reqs + ) + self.last_scheduler_stats.kv_cache_usage += ( + last_scheduler_stats.kv_cache_usage + ) + self.last_scheduler_stats.kv_cache_usage /= len(self.last_scheduler_stats_dict) + + def log(self): + LoggingStatLogger.log(self) + + def log_engine_initialized(self): + if self.vllm_config.cache_config.num_gpu_blocks: + logger.info( + "%d Engines: vllm cache_config_info with initialization " + "after num_gpu_blocks is: %d", + len(self.engine_indexes), + self.vllm_config.cache_config.num_gpu_blocks, + ) + + +class PerEngineStatLoggerAdapter(AggregateStatLoggerBase): + def __init__( + self, + vllm_config: VllmConfig, + engine_indexes: list[int], + per_engine_stat_logger_factory: PerEngineStatLoggerFactory, + ) -> None: + self.per_engine_stat_loggers = {} + self.engine_indexes = engine_indexes + for engine_index in engine_indexes: + 
self.per_engine_stat_loggers[engine_index] = per_engine_stat_logger_factory( + vllm_config, engine_index + ) + + def record( + self, + scheduler_stats: SchedulerStats | None, + iteration_stats: IterationStats | None, + mm_cache_stats: MultiModalCacheStats | None = None, + engine_idx: int = 0, + ): + if engine_idx not in self.per_engine_stat_loggers: + logger.warning("Unexpected engine_idx: %d", engine_idx) + return + self.per_engine_stat_loggers[engine_idx].record( + scheduler_stats, + iteration_stats, + mm_cache_stats=mm_cache_stats, + engine_idx=engine_idx, + ) + + def log(self): + for per_engine_stat_logger in self.per_engine_stat_loggers.values(): + per_engine_stat_logger.log() + + def log_engine_initialized(self): + for per_engine_stat_logger in self.per_engine_stat_loggers.values(): + per_engine_stat_logger.log_engine_initialized() + + +class PrometheusStatLogger(AggregateStatLoggerBase): _gauge_cls = Gauge _counter_cls = Counter _histogram_cls = Histogram @@ -189,6 +314,7 @@ class PrometheusStatLogger(StatLoggerBase): ): if engine_indexes is None: engine_indexes = [0] + self.engine_indexes = engine_indexes unregister_vllm_metrics() @@ -880,14 +1006,14 @@ class StatLoggerManager: engine_idxs: list[int] | None = None, custom_stat_loggers: list[StatLoggerFactory] | None = None, enable_default_loggers: bool = True, + aggregate_engine_logging: bool = False, client_count: int = 1, ): - self.engine_idxs = engine_idxs if engine_idxs else [0] - - factories: list[StatLoggerFactory] = [] + self.engine_indexes = engine_idxs if engine_idxs else [0] + self.stat_loggers: list[AggregateStatLoggerBase] = [] + stat_logger_factories: list[StatLoggerFactory] = [] if custom_stat_loggers is not None: - factories.extend(custom_stat_loggers) - + stat_logger_factories.extend(custom_stat_loggers) if enable_default_loggers and logger.isEnabledFor(logging.INFO): if client_count > 1: logger.warning( @@ -895,27 +1021,35 @@ class StatLoggerManager: "disabling stats logging to avoid 
incomplete stats." ) else: - factories.append(LoggingStatLogger) - - # engine_idx: StatLogger - self.per_engine_logger_dict: dict[int, list[StatLoggerBase]] = {} - prometheus_factory = PrometheusStatLogger - for engine_idx in self.engine_idxs: - loggers: list[StatLoggerBase] = [] - for logger_factory in factories: - # If we get a custom prometheus logger, use that - # instead. This is typically used for the ray case. - if isinstance(logger_factory, type) and issubclass( - logger_factory, PrometheusStatLogger - ): - prometheus_factory = logger_factory - continue - loggers.append(logger_factory(vllm_config, engine_idx)) # type: ignore - self.per_engine_logger_dict[engine_idx] = loggers - - # For Prometheus, need to share the metrics between EngineCores. - # Each EngineCore's metrics are expressed as a unique label. - self.prometheus_logger = prometheus_factory(vllm_config, engine_idxs) + default_logger_factory = ( + AggregatedLoggingStatLogger + if aggregate_engine_logging + else LoggingStatLogger + ) + stat_logger_factories.append(default_logger_factory) + custom_prometheus_logger: bool = False + for stat_logger_factory in stat_logger_factories: + if isinstance(stat_logger_factory, type) and issubclass( + stat_logger_factory, AggregateStatLoggerBase + ): + global_stat_logger = stat_logger_factory( + vllm_config=vllm_config, + engine_indexes=self.engine_indexes, + ) + if isinstance(global_stat_logger, PrometheusStatLogger): + custom_prometheus_logger = True + else: + # per engine logger + global_stat_logger = PerEngineStatLoggerAdapter( + vllm_config=vllm_config, + engine_indexes=self.engine_indexes, + per_engine_stat_logger_factory=stat_logger_factory, # type: ignore[arg-type] + ) + self.stat_loggers.append(global_stat_logger) + if not custom_prometheus_logger: + self.stat_loggers.append( + PrometheusStatLogger(vllm_config, self.engine_indexes) + ) def record( self, @@ -926,9 +1060,7 @@ class StatLoggerManager: ): if engine_idx is None: engine_idx = 0 - - 
per_engine_loggers = self.per_engine_logger_dict[engine_idx] - for logger in per_engine_loggers: + for logger in self.stat_loggers: logger.record( scheduler_stats, iteration_stats, @@ -936,21 +1068,10 @@ class StatLoggerManager: engine_idx=engine_idx, ) - self.prometheus_logger.record( - scheduler_stats, - iteration_stats, - mm_cache_stats=mm_cache_stats, - engine_idx=engine_idx, - ) - def log(self): - for per_engine_loggers in self.per_engine_logger_dict.values(): - for logger in per_engine_loggers: - logger.log() + for logger in self.stat_loggers: + logger.log() def log_engine_initialized(self): - self.prometheus_logger.log_engine_initialized() - - for per_engine_loggers in self.per_engine_logger_dict.values(): - for logger in per_engine_loggers: - logger.log_engine_initialized() + for agg_logger in self.stat_loggers: + agg_logger.log_engine_initialized()