diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 2758c59f49ee8..477a4f8a98477 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -96,10 +96,17 @@ class AsyncLLM(EngineClient): self.log_stats = log_stats # Set up stat loggers; independent set for each DP rank. + # HACK: asyncllm should not be aware of how many engines is it + # managing. + start_idx = vllm_config.parallel_config.data_parallel_rank + local_engines = vllm_config.parallel_config.data_parallel_size_local + engine_idxs = [ + idx for idx in range(start_idx, start_idx + local_engines) + ] self.stat_loggers = setup_default_loggers( vllm_config=vllm_config, log_stats=self.log_stats, - engine_num=vllm_config.parallel_config.data_parallel_size, + engine_idxs=engine_idxs, custom_stat_loggers=stat_loggers, ) @@ -130,9 +137,11 @@ class AsyncLLM(EngineClient): client_index=client_index, ) if self.stat_loggers: - per_engine_loggers, _ = self.stat_loggers - for stat_logger in per_engine_loggers[0]: - stat_logger.log_engine_initialized() + # loggers, prom_logger + loggers, _ = self.stat_loggers + for per_engine_loggers in loggers.values(): + for logger in per_engine_loggers: + logger.log_engine_initialized() self.output_handler: Optional[asyncio.Task] = None try: # Start output handler eagerly if we are in the asyncio eventloop. @@ -435,7 +444,8 @@ class AsyncLLM(EngineClient): @staticmethod def _record_stats( - stat_loggers: tuple[list[list[StatLoggerBase]], PrometheusStatLogger], + stat_loggers: tuple[dict[int, list[StatLoggerBase]], + PrometheusStatLogger], engine_idx: int, scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats], @@ -559,7 +569,7 @@ class AsyncLLM(EngineClient): if self.stat_loggers is None: return per_engine_loggers, _ = self.stat_loggers - for loggers in per_engine_loggers: + for loggers in per_engine_loggers.values(): for stat_logger in loggers: stat_logger.log() diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 6bb14d9c4e83a..7a89d05fcbd3e 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -150,7 +150,11 @@ class PrometheusStatLogger(StatLoggerBase): _histogram_cls = prometheus_client.Histogram _spec_decoding_cls = SpecDecodingProm - def __init__(self, vllm_config: VllmConfig, engine_num: int = 1): + def __init__(self, + vllm_config: VllmConfig, + engine_indexes: Optional[list[int]] = None): + if engine_indexes is None: + engine_indexes = [0] unregister_vllm_metrics() self.vllm_config = vllm_config @@ -162,7 +166,6 @@ class PrometheusStatLogger(StatLoggerBase): labelnames = ["model_name", "engine"] model_name = vllm_config.model_config.served_model_name max_model_len = vllm_config.model_config.max_model_len - engine_indexes = list(range(engine_num)) # self.spec_decoding_prom = self._spec_decoding_cls( # vllm_config.speculative_config, labelnames, labelvalues) @@ -600,9 +603,9 @@ def build_1_2_5_buckets(max_value: int) -> list[int]: def setup_default_loggers( vllm_config: VllmConfig, log_stats: bool, - engine_num: int, + engine_idxs: list[int], custom_stat_loggers: Optional[list[StatLoggerFactory]] = None, -) -> Optional[tuple[list[list[StatLoggerBase]], PrometheusStatLogger]]: +) -> Optional[tuple[dict[int, list[StatLoggerBase]], PrometheusStatLogger]]: """Setup logging and prometheus metrics.""" if not log_stats: return None @@ -615,13 +618,14 @@ def setup_default_loggers( if logger.isEnabledFor(logging.INFO): factories.append(LoggingStatLogger) - stat_loggers: list[list[StatLoggerBase]] = [] - for engine_idx in range(engine_num): + # engine_idx: Logger + stat_loggers: dict[int, list[StatLoggerBase]] = {} + for engine_idx in engine_idxs: per_engine_stat_loggers: list[StatLoggerBase] = [] for logger_factory in factories: per_engine_stat_loggers.append( logger_factory(vllm_config, engine_idx)) - stat_loggers.append(per_engine_stat_loggers) + stat_loggers[engine_idx] = per_engine_stat_loggers - prom_stat_logger = PrometheusStatLogger(vllm_config, engine_num) + prom_stat_logger = PrometheusStatLogger(vllm_config, engine_idxs) return stat_loggers, prom_stat_logger diff --git a/vllm/v1/metrics/ray_wrappers.py b/vllm/v1/metrics/ray_wrappers.py index cce692d6c09e7..73344937d89ba 100644 --- a/vllm/v1/metrics/ray_wrappers.py +++ b/vllm/v1/metrics/ray_wrappers.py @@ -122,8 +122,10 @@ class RayPrometheusStatLogger(PrometheusStatLogger): _histogram_cls = RayHistogramWrapper _spec_decoding_cls = RaySpecDecodingProm - def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): - super().__init__(vllm_config, engine_index) + def __init__(self, + vllm_config: VllmConfig, + engine_indexes: Optional[list[int]] = None): + super().__init__(vllm_config, engine_indexes) @staticmethod def _unregister_vllm_metrics():