Signed-off-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
Robert Shaw 2025-07-20 15:56:50 +00:00
parent cad9670547
commit fd0650f258
3 changed files with 32 additions and 16 deletions

View File

@ -96,10 +96,17 @@ class AsyncLLM(EngineClient):
self.log_stats = log_stats self.log_stats = log_stats
# Set up stat loggers; independent set for each DP rank. # Set up stat loggers; independent set for each DP rank.
# HACK: asyncllm should not be aware of how many engines is it
# managing.
start_idx = vllm_config.parallel_config.data_parallel_rank
local_engines = vllm_config.parallel_config.data_parallel_size_local
engine_idxs = [
idx for idx in range(start_idx, start_idx + local_engines)
]
self.stat_loggers = setup_default_loggers( self.stat_loggers = setup_default_loggers(
vllm_config=vllm_config, vllm_config=vllm_config,
log_stats=self.log_stats, log_stats=self.log_stats,
engine_num=vllm_config.parallel_config.data_parallel_size, engine_idxs=engine_idxs,
custom_stat_loggers=stat_loggers, custom_stat_loggers=stat_loggers,
) )
@ -130,9 +137,11 @@ class AsyncLLM(EngineClient):
client_index=client_index, client_index=client_index,
) )
if self.stat_loggers: if self.stat_loggers:
per_engine_loggers, _ = self.stat_loggers # loggers, prom_logger
for stat_logger in per_engine_loggers[0]: loggers, _ = self.stat_loggers
stat_logger.log_engine_initialized() for per_engine_loggers in loggers.values():
for logger in per_engine_loggers:
logger.log_engine_initialized()
self.output_handler: Optional[asyncio.Task] = None self.output_handler: Optional[asyncio.Task] = None
try: try:
# Start output handler eagerly if we are in the asyncio eventloop. # Start output handler eagerly if we are in the asyncio eventloop.
@ -435,7 +444,8 @@ class AsyncLLM(EngineClient):
@staticmethod @staticmethod
def _record_stats( def _record_stats(
stat_loggers: tuple[list[list[StatLoggerBase]], PrometheusStatLogger], stat_loggers: tuple[dict[int, list[StatLoggerBase]],
PrometheusStatLogger],
engine_idx: int, engine_idx: int,
scheduler_stats: Optional[SchedulerStats], scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats], iteration_stats: Optional[IterationStats],
@ -559,7 +569,7 @@ class AsyncLLM(EngineClient):
if self.stat_loggers is None: if self.stat_loggers is None:
return return
per_engine_loggers, _ = self.stat_loggers per_engine_loggers, _ = self.stat_loggers
for loggers in per_engine_loggers: for loggers in per_engine_loggers.values():
for stat_logger in loggers: for stat_logger in loggers:
stat_logger.log() stat_logger.log()

View File

@ -150,7 +150,11 @@ class PrometheusStatLogger(StatLoggerBase):
_histogram_cls = prometheus_client.Histogram _histogram_cls = prometheus_client.Histogram
_spec_decoding_cls = SpecDecodingProm _spec_decoding_cls = SpecDecodingProm
def __init__(self, vllm_config: VllmConfig, engine_num: int = 1): def __init__(self,
vllm_config: VllmConfig,
engine_indexes: Optional[list[int]] = None):
if engine_indexes is None:
engine_indexes = [0]
unregister_vllm_metrics() unregister_vllm_metrics()
self.vllm_config = vllm_config self.vllm_config = vllm_config
@ -162,7 +166,6 @@ class PrometheusStatLogger(StatLoggerBase):
labelnames = ["model_name", "engine"] labelnames = ["model_name", "engine"]
model_name = vllm_config.model_config.served_model_name model_name = vllm_config.model_config.served_model_name
max_model_len = vllm_config.model_config.max_model_len max_model_len = vllm_config.model_config.max_model_len
engine_indexes = list(range(engine_num))
# self.spec_decoding_prom = self._spec_decoding_cls( # self.spec_decoding_prom = self._spec_decoding_cls(
# vllm_config.speculative_config, labelnames, labelvalues) # vllm_config.speculative_config, labelnames, labelvalues)
@ -600,9 +603,9 @@ def build_1_2_5_buckets(max_value: int) -> list[int]:
def setup_default_loggers( def setup_default_loggers(
vllm_config: VllmConfig, vllm_config: VllmConfig,
log_stats: bool, log_stats: bool,
engine_num: int, engine_idxs: list[int],
custom_stat_loggers: Optional[list[StatLoggerFactory]] = None, custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
) -> Optional[tuple[list[list[StatLoggerBase]], PrometheusStatLogger]]: ) -> Optional[tuple[dict[int, list[StatLoggerBase]], PrometheusStatLogger]]:
"""Setup logging and prometheus metrics.""" """Setup logging and prometheus metrics."""
if not log_stats: if not log_stats:
return None return None
@ -615,13 +618,14 @@ def setup_default_loggers(
if logger.isEnabledFor(logging.INFO): if logger.isEnabledFor(logging.INFO):
factories.append(LoggingStatLogger) factories.append(LoggingStatLogger)
stat_loggers: list[list[StatLoggerBase]] = [] # engine_idx: Logger
for engine_idx in range(engine_num): stat_loggers: dict[int, list[StatLoggerBase]] = {}
for engine_idx in engine_idxs:
per_engine_stat_loggers: list[StatLoggerBase] = [] per_engine_stat_loggers: list[StatLoggerBase] = []
for logger_factory in factories: for logger_factory in factories:
per_engine_stat_loggers.append( per_engine_stat_loggers.append(
logger_factory(vllm_config, engine_idx)) logger_factory(vllm_config, engine_idx))
stat_loggers.append(per_engine_stat_loggers) stat_loggers[engine_idx] = per_engine_stat_loggers
prom_stat_logger = PrometheusStatLogger(vllm_config, engine_num) prom_stat_logger = PrometheusStatLogger(vllm_config, engine_idxs)
return stat_loggers, prom_stat_logger return stat_loggers, prom_stat_logger

View File

@ -122,8 +122,10 @@ class RayPrometheusStatLogger(PrometheusStatLogger):
_histogram_cls = RayHistogramWrapper _histogram_cls = RayHistogramWrapper
_spec_decoding_cls = RaySpecDecodingProm _spec_decoding_cls = RaySpecDecodingProm
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): def __init__(self,
super().__init__(vllm_config, engine_index) vllm_config: VllmConfig,
engine_indexes: Optional[list[int]] = None):
super().__init__(vllm_config, engine_indexes)
@staticmethod @staticmethod
def _unregister_vllm_metrics(): def _unregister_vllm_metrics():