mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-04 05:15:42 +08:00
updated
Signed-off-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
parent
896b0a271e
commit
54e405bd92
@ -36,10 +36,9 @@ from vllm.v1.engine.output_processor import (OutputProcessor,
|
|||||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||||
from vllm.v1.engine.processor import Processor
|
from vllm.v1.engine.processor import Processor
|
||||||
from vllm.v1.executor.abstract import Executor
|
from vllm.v1.executor.abstract import Executor
|
||||||
from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
|
from vllm.v1.metrics.loggers import StatLoggerFactory, setup_default_loggers
|
||||||
StatLoggerFactory, setup_default_loggers)
|
|
||||||
from vllm.v1.metrics.prometheus import shutdown_prometheus
|
from vllm.v1.metrics.prometheus import shutdown_prometheus
|
||||||
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
|
from vllm.v1.metrics.stats import IterationStats
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -103,7 +102,7 @@ class AsyncLLM(EngineClient):
|
|||||||
engine_idxs = [
|
engine_idxs = [
|
||||||
idx for idx in range(start_idx, start_idx + local_engines)
|
idx for idx in range(start_idx, start_idx + local_engines)
|
||||||
]
|
]
|
||||||
self.stat_loggers = setup_default_loggers(
|
self.logger_manager = setup_default_loggers(
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
log_stats=self.log_stats,
|
log_stats=self.log_stats,
|
||||||
engine_idxs=engine_idxs,
|
engine_idxs=engine_idxs,
|
||||||
@ -136,12 +135,8 @@ class AsyncLLM(EngineClient):
|
|||||||
client_addresses=client_addresses,
|
client_addresses=client_addresses,
|
||||||
client_index=client_index,
|
client_index=client_index,
|
||||||
)
|
)
|
||||||
if self.stat_loggers:
|
if self.logger_manager:
|
||||||
# loggers, prom_logger
|
self.logger_manager.log_engine_initialized()
|
||||||
loggers, _ = self.stat_loggers
|
|
||||||
for per_engine_loggers in loggers.values():
|
|
||||||
for logger in per_engine_loggers:
|
|
||||||
logger.log_engine_initialized()
|
|
||||||
self.output_handler: Optional[asyncio.Task] = None
|
self.output_handler: Optional[asyncio.Task] = None
|
||||||
try:
|
try:
|
||||||
# Start output handler eagerly if we are in the asyncio eventloop.
|
# Start output handler eagerly if we are in the asyncio eventloop.
|
||||||
@ -380,7 +375,7 @@ class AsyncLLM(EngineClient):
|
|||||||
engine_core = self.engine_core
|
engine_core = self.engine_core
|
||||||
output_processor = self.output_processor
|
output_processor = self.output_processor
|
||||||
log_stats = self.log_stats
|
log_stats = self.log_stats
|
||||||
stat_loggers = self.stat_loggers if log_stats else None
|
logger_manager = self.logger_manager
|
||||||
|
|
||||||
async def output_handler():
|
async def output_handler():
|
||||||
try:
|
try:
|
||||||
@ -420,13 +415,14 @@ class AsyncLLM(EngineClient):
|
|||||||
# 4) Logging.
|
# 4) Logging.
|
||||||
# TODO(rob): make into a coroutine and launch it in
|
# TODO(rob): make into a coroutine and launch it in
|
||||||
# background thread once Prometheus overhead is non-trivial.
|
# background thread once Prometheus overhead is non-trivial.
|
||||||
if stat_loggers:
|
# NOTE: we do not use self.logger_manager here, to avoid a circular ref to AsyncLLM from the output_handler task.
|
||||||
AsyncLLM._record_stats(
|
if logger_manager:
|
||||||
stat_loggers,
|
logger_manager.record(
|
||||||
outputs.engine_index,
|
|
||||||
scheduler_stats=outputs.scheduler_stats,
|
scheduler_stats=outputs.scheduler_stats,
|
||||||
iteration_stats=iteration_stats,
|
iteration_stats=iteration_stats,
|
||||||
|
engine_idx=outputs.engine_index,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("AsyncLLM output_handler failed.")
|
logger.exception("AsyncLLM output_handler failed.")
|
||||||
output_processor.propagate_error(e)
|
output_processor.propagate_error(e)
|
||||||
@ -442,26 +438,6 @@ class AsyncLLM(EngineClient):
|
|||||||
if self.log_requests:
|
if self.log_requests:
|
||||||
logger.info("Aborted request %s.", request_id)
|
logger.info("Aborted request %s.", request_id)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _record_stats(
|
|
||||||
stat_loggers: tuple[dict[int, list[StatLoggerBase]],
|
|
||||||
PrometheusStatLogger],
|
|
||||||
engine_idx: int,
|
|
||||||
scheduler_stats: Optional[SchedulerStats],
|
|
||||||
iteration_stats: Optional[IterationStats],
|
|
||||||
):
|
|
||||||
"""static so that it can be used from the output_handler task
|
|
||||||
without a circular ref to AsyncLLM."""
|
|
||||||
|
|
||||||
per_engine_loggers, prom_logger = stat_loggers
|
|
||||||
for stat_logger in per_engine_loggers[engine_idx]:
|
|
||||||
stat_logger.record(engine_idx=engine_idx,
|
|
||||||
scheduler_stats=scheduler_stats,
|
|
||||||
iteration_stats=iteration_stats)
|
|
||||||
prom_logger.record(engine_idx=engine_idx,
|
|
||||||
scheduler_stats=scheduler_stats,
|
|
||||||
iteration_stats=iteration_stats)
|
|
||||||
|
|
||||||
async def encode(
|
async def encode(
|
||||||
self,
|
self,
|
||||||
prompt: PromptType,
|
prompt: PromptType,
|
||||||
|
|||||||
@ -605,27 +605,75 @@ def setup_default_loggers(
|
|||||||
log_stats: bool,
|
log_stats: bool,
|
||||||
engine_idxs: list[int],
|
engine_idxs: list[int],
|
||||||
custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
|
custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
|
||||||
) -> Optional[tuple[dict[int, list[StatLoggerBase]], PrometheusStatLogger]]:
|
) -> Optional["StatLoggerManager"]:
|
||||||
"""Setup logging and prometheus metrics."""
|
"""Setup logging and prometheus metrics."""
|
||||||
if not log_stats:
|
return (None if not log_stats else StatLoggerManager(
|
||||||
return None
|
vllm_config, engine_idxs, custom_stat_loggers))
|
||||||
|
|
||||||
factories: list[StatLoggerFactory]
|
|
||||||
if custom_stat_loggers is not None:
|
|
||||||
factories = custom_stat_loggers
|
|
||||||
else:
|
|
||||||
factories = []
|
|
||||||
if logger.isEnabledFor(logging.INFO):
|
|
||||||
factories.append(LoggingStatLogger)
|
|
||||||
|
|
||||||
# engine_idx: Logger
|
class StatLoggerManager:
|
||||||
stat_loggers: dict[int, list[StatLoggerBase]] = {}
|
"""
|
||||||
for engine_idx in engine_idxs:
|
StatLoggerManager:
|
||||||
per_engine_stat_loggers: list[StatLoggerBase] = []
|
Logging happens at the level of the EngineCore (per scheduler).
|
||||||
for logger_factory in factories:
|
* DP: >1 EngineCore per AsyncLLM - loggers for each EngineCore.
|
||||||
per_engine_stat_loggers.append(
|
* With Local Logger, just make N copies for N EngineCores.
|
||||||
logger_factory(vllm_config, engine_idx))
|
* With Prometheus, we need a single logger with N "labels"
|
||||||
stat_loggers[engine_idx] = per_engine_stat_loggers
|
|
||||||
|
|
||||||
prom_stat_logger = PrometheusStatLogger(vllm_config, engine_idxs)
|
This class abstracts away this implementation detail from
|
||||||
return stat_loggers, prom_stat_logger
|
the AsyncLLM, allowing the AsyncLLM to just call .record()
|
||||||
|
and .log() through a simple interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
engine_idxs: Optional[list[int]] = None,
|
||||||
|
custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
|
||||||
|
):
|
||||||
|
self.engine_idxs = set([0]) if not engine_idxs else set(engine_idxs)
|
||||||
|
|
||||||
|
factories: list[StatLoggerFactory]
|
||||||
|
if custom_stat_loggers is not None:
|
||||||
|
factories = custom_stat_loggers
|
||||||
|
else:
|
||||||
|
factories = []
|
||||||
|
if logger.isEnabledFor(logging.INFO):
|
||||||
|
factories.append(LoggingStatLogger)
|
||||||
|
|
||||||
|
# engine_idx: StatLogger
|
||||||
|
self.per_engine_logger_dict: dict[int, list[StatLoggerBase]] = {}
|
||||||
|
for engine_idx in self.engine_idxs:
|
||||||
|
loggers: list[StatLoggerBase] = []
|
||||||
|
for logger_factory in factories:
|
||||||
|
loggers.append(logger_factory(vllm_config, engine_idx))
|
||||||
|
self.per_engine_logger_dict[engine_idx] = loggers
|
||||||
|
|
||||||
|
# For Prometheus, need to share the metrics between EngineCores.
|
||||||
|
# Each EngineCore's metrics are expressed as a unique label.
|
||||||
|
self.prometheus_logger = PrometheusStatLogger(vllm_config, engine_idxs)
|
||||||
|
|
||||||
|
def record(
|
||||||
|
self,
|
||||||
|
scheduler_stats: Optional[SchedulerStats],
|
||||||
|
iteration_stats: Optional[IterationStats],
|
||||||
|
engine_idx: Optional[int] = None,
|
||||||
|
):
|
||||||
|
if engine_idx is None:
|
||||||
|
engine_idx = 0
|
||||||
|
|
||||||
|
per_engine_loggers = self.per_engine_logger_dict[engine_idx]
|
||||||
|
for logger in per_engine_loggers:
|
||||||
|
logger.record(scheduler_stats, iteration_stats, engine_idx)
|
||||||
|
|
||||||
|
self.prometheus_logger.record(scheduler_stats, iteration_stats,
|
||||||
|
engine_idx)
|
||||||
|
|
||||||
|
def log(self):
|
||||||
|
for per_engine_loggers in self.per_engine_logger_dict.values():
|
||||||
|
for logger in per_engine_loggers:
|
||||||
|
logger.log()
|
||||||
|
|
||||||
|
def log_engine_initialized(self):
|
||||||
|
for per_engine_loggers in self.per_engine_logger_dict.values():
|
||||||
|
for logger in per_engine_loggers:
|
||||||
|
logger.log_engine_initialized()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user