[DP] Fix Prometheus Logging (#21257)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Robert Shaw 2025-07-21 12:11:35 -04:00 committed by GitHub
parent 304dce7ec0
commit 29d1ffc5b4
6 changed files with 369 additions and 249 deletions
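This change replaces the per-rank list-of-lists of stat loggers with a single StatLoggerManager: per-engine logging loggers are kept per DP rank, while one shared PrometheusStatLogger distinguishes EngineCores via the "engine" label. A minimal sketch of the resulting metric shape, using illustrative values and an isolated registry (not part of the diff):

import prometheus_client

registry = prometheus_client.CollectorRegistry()  # isolated registry for this demo
gauge = prometheus_client.Gauge(
    name="vllm:num_requests_running",  # one metric family shared by all DP ranks
    documentation="Number of requests in model execution batches.",
    labelnames=["model_name", "engine"],
    registry=registry,
)
gauge.labels("demo-model", "0").set(3)  # EngineCore 0
gauge.labels("demo-model", "1").set(5)  # EngineCore 1

# The scrape output then carries one sample per engine, e.g.:
#   vllm:num_requests_running{engine="0",model_name="demo-model"} 3.0
#   vllm:num_requests_running{engine="1",model_name="demo-model"} 5.0
print(prometheus_client.generate_latest(registry).decode())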

View File

@@ -336,9 +336,10 @@ async def test_customize_loggers(monkeypatch):
await engine.do_log_stats()
assert len(engine.stat_loggers) == 1
assert len(engine.stat_loggers[0]) == 1
engine.stat_loggers[0][0].log.assert_called_once()
stat_loggers = engine.logger_manager.per_engine_logger_dict
assert len(stat_loggers) == 1
assert len(stat_loggers[0]) == 1
stat_loggers[0][0].log.assert_called_once()
@pytest.mark.asyncio(scope="module")

View File

@@ -90,8 +90,10 @@ async def test_load(output_kind: RequestOutputKind,
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
stats_loggers[engine_index] = self
def record(self, scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats]):
def record(self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
engine_idx: int = 0):
if iteration_stats:
self.finished_req_count += len(
iteration_stats.finished_requests)
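The updated test hooks in a custom logger whose record() now also receives an engine_idx. A minimal sketch of such a logger under the new signature (class name and counted field are illustrative; the exact set of abstract hooks is defined by StatLoggerBase):

from typing import Optional

from vllm.config import VllmConfig
from vllm.v1.metrics.loggers import StatLoggerBase
from vllm.v1.metrics.stats import IterationStats, SchedulerStats


class FinishedRequestCounter(StatLoggerBase):
    """Illustrative custom stat logger; not part of the diff."""

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        self.engine_index = engine_index
        self.finished_req_count = 0

    def record(self,
               scheduler_stats: Optional[SchedulerStats],
               iteration_stats: Optional[IterationStats],
               engine_idx: int = 0):
        # Count requests finished on this engine's iterations.
        if iteration_stats:
            self.finished_req_count += len(iteration_stats.finished_requests)

    def log(self):
        pass

    def log_engine_initialized(self):
        pass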

View File

@@ -36,10 +36,9 @@ from vllm.v1.engine.output_processor import (OutputProcessor,
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
setup_default_loggers)
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
from vllm.v1.metrics.prometheus import shutdown_prometheus
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from vllm.v1.metrics.stats import IterationStats
logger = init_logger(__name__)
@@ -95,14 +94,6 @@ class AsyncLLM(EngineClient):
self.log_requests = log_requests
self.log_stats = log_stats
# Set up stat loggers; independent set for each DP rank.
self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
vllm_config=vllm_config,
log_stats=self.log_stats,
engine_num=vllm_config.parallel_config.data_parallel_size,
custom_stat_loggers=stat_loggers,
)
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
@@ -121,7 +112,6 @@ class AsyncLLM(EngineClient):
log_stats=self.log_stats)
# EngineCore (starts the engine in background process).
self.engine_core = EngineCoreClient.make_async_mp_client(
vllm_config=vllm_config,
executor_class=executor_class,
@@ -129,9 +119,17 @@
client_addresses=client_addresses,
client_index=client_index,
)
if self.stat_loggers:
for stat_logger in self.stat_loggers[0]:
stat_logger.log_engine_initialized()
# Loggers.
self.logger_manager: Optional[StatLoggerManager] = None
if self.log_stats:
self.logger_manager = StatLoggerManager(
vllm_config=vllm_config,
engine_idxs=self.engine_core.engine_ranks,
custom_stat_loggers=stat_loggers,
)
self.logger_manager.log_engine_initialized()
self.output_handler: Optional[asyncio.Task] = None
try:
# Start output handler eagerly if we are in the asyncio eventloop.
@@ -370,7 +368,7 @@ class AsyncLLM(EngineClient):
engine_core = self.engine_core
output_processor = self.output_processor
log_stats = self.log_stats
stat_loggers = self.stat_loggers if log_stats else None
logger_manager = self.logger_manager
async def output_handler():
try:
@@ -410,9 +408,9 @@
# 4) Logging.
# TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial.
if stat_loggers:
AsyncLLM._record_stats(
stat_loggers[outputs.engine_index],
if logger_manager:
logger_manager.record(
engine_idx=outputs.engine_index,
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
)
@@ -431,18 +429,6 @@
if self.log_requests:
logger.info("Aborted request %s.", request_id)
@staticmethod
def _record_stats(
stat_loggers: list[StatLoggerBase],
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
):
"""static so that it can be used from the output_handler task
without a circular ref to AsyncLLM."""
for stat_logger in stat_loggers:
stat_logger.record(scheduler_stats=scheduler_stats,
iteration_stats=iteration_stats)
async def encode(
self,
prompt: PromptType,
@@ -547,9 +533,8 @@ class AsyncLLM(EngineClient):
scheduler_outputs=None,
model_output=None,
) -> None:
for loggers in self.stat_loggers:
for stat_logger in loggers:
stat_logger.log()
if self.logger_manager:
self.logger_manager.log()
async def check_health(self) -> None:
logger.debug("Called check_health.")
@@ -653,18 +638,16 @@ class AsyncLLM(EngineClient):
new_data_parallel_size
# recreate stat loggers
if new_data_parallel_size > old_data_parallel_size:
stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
if new_data_parallel_size > old_data_parallel_size and self.log_stats:
# TODO(rob): fix this after talking with Ray team.
# This resets all the prometheus metrics since we
# unregister during initialization. Need to understand
# the intended behavior here better.
self.logger_manager = StatLoggerManager(
vllm_config=self.vllm_config,
log_stats=self.log_stats,
engine_num=new_data_parallel_size,
engine_idxs=list(range(new_data_parallel_size)),
custom_stat_loggers=None,
)
num_new_engines = len(stat_loggers) - len(self.stat_loggers)
self.stat_loggers.extend(stat_loggers[-num_new_engines:])
else:
for _ in range(old_data_parallel_size - new_data_parallel_size):
self.stat_loggers.pop()
@property
def is_running(self) -> bool:

View File

@@ -432,14 +432,15 @@ class MPClient(EngineCoreClient):
external_dp_lb = parallel_config.data_parallel_external_lb
offline_mode = parallel_config.data_parallel_rank_local is not None
engine_ranks = [dp_rank] if (offline_mode
or external_dp_lb) else range(dp_size)
self.engine_ranks = ([dp_rank] if
(offline_mode or external_dp_lb) else list(
range(dp_size)))
assert parallel_config.data_parallel_size_local <= len(
engine_ranks)
self.engine_ranks)
# ZMQ identity of each engine that this client will talk to.
self.core_engines: list[EngineIdentity] = [
index.to_bytes(2, "little") for index in engine_ranks
index.to_bytes(2, "little") for index in self.engine_ranks
]
# Wait for ready messages from each engine on the input socket.
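For reference, the rank-to-identity mapping used here is just the DP rank packed into two little-endian bytes; a small self-contained sketch with illustrative values:

# Illustrative values; offline_mode / external_dp_lb mirror the flags above.
dp_rank, dp_size = 1, 4
offline_mode, external_dp_lb = False, False

engine_ranks = [dp_rank] if (offline_mode or external_dp_lb) else list(range(dp_size))

# ZMQ identity of each engine: its rank as two little-endian bytes.
core_engines = [index.to_bytes(2, "little") for index in engine_ranks]
assert core_engines == [b"\x00\x00", b"\x01\x00", b"\x02\x00", b"\x03\x00"]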

View File

@@ -4,7 +4,7 @@
import logging
import time
from abc import ABC, abstractmethod
from typing import Callable, Optional
from typing import Callable, Optional, Union
import numpy as np
import prometheus_client
@@ -35,8 +35,10 @@ class StatLoggerBase(ABC):
...
@abstractmethod
def record(self, scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats]):
def record(self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
engine_idx: int = 0):
...
@abstractmethod
@@ -78,8 +80,10 @@ class LoggingStatLogger(StatLoggerBase):
# Compute summary metrics for tracked stats
return float(np.sum(tracked_stats) / (now - self.last_log_time))
def record(self, scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats]):
def record(self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
engine_idx: int = 0):
"""Log Stats to standard output."""
if iteration_stats:
@@ -146,233 +150,290 @@ class PrometheusStatLogger(StatLoggerBase):
_histogram_cls = prometheus_client.Histogram
_spec_decoding_cls = SpecDecodingProm
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
def __init__(self,
vllm_config: VllmConfig,
engine_indexes: Optional[list[int]] = None):
if engine_indexes is None:
engine_indexes = [0]
self.engine_indexes = engine_indexes
unregister_vllm_metrics()
self.vllm_config = vllm_config
self.engine_index = engine_index
# Use this flag to hide metrics that were deprecated in
# a previous release and which will be removed future
self.show_hidden_metrics = \
vllm_config.observability_config.show_hidden_metrics
labelnames = ["model_name", "engine"]
labelvalues = [
vllm_config.model_config.served_model_name,
str(engine_index)
]
model_name = vllm_config.model_config.served_model_name
max_model_len = vllm_config.model_config.max_model_len
if (len(self.engine_indexes) > 1
and vllm_config.speculative_config is not None):
raise NotImplementedError("Prometheus metrics with Spec Decoding "
"with >1 EngineCore per AsyncLLM is not "
"supported yet.")
spec_decode_labelvalues = [
vllm_config.model_config.served_model_name,
str(self.engine_indexes[0])
]
self.spec_decoding_prom = self._spec_decoding_cls(
vllm_config.speculative_config, labelnames, labelvalues)
vllm_config.speculative_config, labelnames,
spec_decode_labelvalues)
#
# Scheduler state
#
self.gauge_scheduler_running = self._gauge_cls(
gauge_scheduler_running = self._gauge_cls(
name="vllm:num_requests_running",
documentation="Number of requests in model execution batches.",
multiprocess_mode="mostrecent",
labelnames=labelnames).labels(*labelvalues)
labelnames=labelnames)
self.gauge_scheduler_running = make_per_engine(gauge_scheduler_running,
engine_indexes,
model_name)
self.gauge_scheduler_waiting = self._gauge_cls(
gauge_scheduler_waiting = self._gauge_cls(
name="vllm:num_requests_waiting",
documentation="Number of requests waiting to be processed.",
multiprocess_mode="mostrecent",
labelnames=labelnames).labels(*labelvalues)
labelnames=labelnames)
self.gauge_scheduler_waiting = make_per_engine(gauge_scheduler_waiting,
engine_indexes,
model_name)
#
# GPU cache
#
# Deprecated in 0.9 - Renamed as vllm:kv_cache_usage_perc
# TODO: in 0.10, only enable if show_hidden_metrics=True
self.gauge_gpu_cache_usage = self._gauge_cls(
gauge_gpu_cache_usage = self._gauge_cls(
name="vllm:gpu_cache_usage_perc",
documentation=(
"GPU KV-cache usage. 1 means 100 percent usage."
"DEPRECATED: Use vllm:kv_cache_usage_perc instead."),
multiprocess_mode="mostrecent",
labelnames=labelnames).labels(*labelvalues)
labelnames=labelnames)
self.gauge_gpu_cache_usage = make_per_engine(gauge_gpu_cache_usage,
engine_indexes,
model_name)
# Deprecated in 0.9 - Renamed as vllm:prefix_cache_queries
# TODO: in 0.10, only enable if show_hidden_metrics=True
self.counter_gpu_prefix_cache_queries = self._counter_cls(
counter_gpu_prefix_cache_queries = self._counter_cls(
name="vllm:gpu_prefix_cache_queries",
documentation=
("GPU prefix cache queries, in terms of number of queried tokens."
"DEPRECATED: Use vllm:prefix_cache_queries instead."),
labelnames=labelnames).labels(*labelvalues)
documentation=(
"GPU prefix cache queries, in terms of number of queried"
"tokens. DEPRECATED: Use vllm:prefix_cache_queries instead."),
labelnames=labelnames)
self.counter_gpu_prefix_cache_queries = make_per_engine(
counter_gpu_prefix_cache_queries, engine_indexes, model_name)
# Deprecated in 0.9 - Renamed as vllm:prefix_cache_hits
# TODO: in 0.10, only enable if show_hidden_metrics=True
self.counter_gpu_prefix_cache_hits = self._counter_cls(
counter_gpu_prefix_cache_hits = self._counter_cls(
name="vllm:gpu_prefix_cache_hits",
documentation=(
"GPU prefix cache hits, in terms of number of cached tokens."
"DEPRECATED: Use vllm:prefix_cache_hits instead."),
labelnames=labelnames).labels(*labelvalues)
"GPU prefix cache hits, in terms of number of cached "
"tokens. DEPRECATED: Use vllm:prefix_cache_hits instead."),
labelnames=labelnames)
self.counter_gpu_prefix_cache_hits = make_per_engine(
counter_gpu_prefix_cache_hits, engine_indexes, model_name)
self.gauge_kv_cache_usage = self._gauge_cls(
gauge_kv_cache_usage = self._gauge_cls(
name="vllm:kv_cache_usage_perc",
documentation="KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames).labels(*labelvalues)
labelnames=labelnames)
self.gauge_kv_cache_usage = make_per_engine(gauge_kv_cache_usage,
engine_indexes, model_name)
self.counter_prefix_cache_queries = self._counter_cls(
counter_prefix_cache_queries = self._counter_cls(
name="vllm:prefix_cache_queries",
documentation=(
"Prefix cache queries, in terms of number of queried tokens."),
labelnames=labelnames).labels(*labelvalues)
labelnames=labelnames)
self.counter_prefix_cache_queries = make_per_engine(
counter_prefix_cache_queries, engine_indexes, model_name)
self.counter_prefix_cache_hits = self._counter_cls(
counter_prefix_cache_hits = self._counter_cls(
name="vllm:prefix_cache_hits",
documentation=(
"Prefix cache hits, in terms of number of cached tokens."),
labelnames=labelnames).labels(*labelvalues)
labelnames=labelnames)
self.counter_prefix_cache_hits = make_per_engine(
counter_prefix_cache_hits, engine_indexes, model_name)
#
# Counters
#
self.counter_num_preempted_reqs = self._counter_cls(
counter_num_preempted_reqs = self._counter_cls(
name="vllm:num_preemptions",
documentation="Cumulative number of preemption from the engine.",
labelnames=labelnames).labels(*labelvalues)
labelnames=labelnames)
self.counter_num_preempted_reqs = make_per_engine(
counter_num_preempted_reqs, engine_indexes, model_name)
self.counter_prompt_tokens = self._counter_cls(
counter_prompt_tokens = self._counter_cls(
name="vllm:prompt_tokens",
documentation="Number of prefill tokens processed.",
labelnames=labelnames).labels(*labelvalues)
labelnames=labelnames)
self.counter_prompt_tokens = make_per_engine(counter_prompt_tokens,
engine_indexes,
model_name)
self.counter_generation_tokens = self._counter_cls(
counter_generation_tokens = self._counter_cls(
name="vllm:generation_tokens",
documentation="Number of generation tokens processed.",
labelnames=labelnames).labels(*labelvalues)
labelnames=labelnames)
self.counter_generation_tokens = make_per_engine(
counter_generation_tokens, engine_indexes, model_name)
self.counter_request_success: dict[FinishReason,
prometheus_client.Counter] = {}
self.counter_request_success: dict[FinishReason, dict[
int, prometheus_client.Counter]] = {}
counter_request_success_base = self._counter_cls(
name="vllm:request_success",
documentation="Count of successfully processed requests.",
labelnames=labelnames + ["finished_reason"])
for reason in FinishReason:
self.counter_request_success[
reason] = counter_request_success_base.labels(*(labelvalues +
[str(reason)]))
self.counter_request_success[reason] = {
idx:
counter_request_success_base.labels(model_name, str(idx),
str(reason))
for idx in engine_indexes
}
#
# Histograms of counts
#
self.histogram_num_prompt_tokens_request = \
self._histogram_cls(
name="vllm:request_prompt_tokens",
documentation="Number of prefill tokens processed.",
buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames).labels(*labelvalues)
histogram_num_prompt_tokens_request = self._histogram_cls(
name="vllm:request_prompt_tokens",
documentation="Number of prefill tokens processed.",
buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames)
self.histogram_num_prompt_tokens_request = make_per_engine(
histogram_num_prompt_tokens_request, engine_indexes, model_name)
self.histogram_num_generation_tokens_request = \
self._histogram_cls(
name="vllm:request_generation_tokens",
documentation="Number of generation tokens processed.",
buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames).labels(*labelvalues)
histogram_num_generation_tokens_request = self._histogram_cls(
name="vllm:request_generation_tokens",
documentation="Number of generation tokens processed.",
buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames)
self.histogram_num_generation_tokens_request = make_per_engine(
histogram_num_generation_tokens_request, engine_indexes,
model_name)
# TODO: This metric might be incorrect in case of using multiple
# api_server counts which uses prometheus mp.
# See: https://github.com/vllm-project/vllm/pull/18053
self.histogram_iteration_tokens = \
self._histogram_cls(
name="vllm:iteration_tokens_total",
documentation="Histogram of number of tokens per engine_step.",
buckets=[
1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192,
16384
],
labelnames=labelnames).labels(*labelvalues)
histogram_iteration_tokens = self._histogram_cls(
name="vllm:iteration_tokens_total",
documentation="Histogram of number of tokens per engine_step.",
buckets=[
1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384
],
labelnames=labelnames)
self.histogram_iteration_tokens = make_per_engine(
histogram_iteration_tokens, engine_indexes, model_name)
self.histogram_max_num_generation_tokens_request = \
self._histogram_cls(
name="vllm:request_max_num_generation_tokens",
documentation=
"Histogram of maximum number of requested generation tokens.",
buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames).labels(*labelvalues)
histogram_max_num_generation_tokens_request = self._histogram_cls(
name="vllm:request_max_num_generation_tokens",
documentation=
"Histogram of maximum number of requested generation tokens.",
buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames)
self.histogram_max_num_generation_tokens_request = make_per_engine(
histogram_max_num_generation_tokens_request, engine_indexes,
model_name)
self.histogram_n_request = \
self._histogram_cls(
name="vllm:request_params_n",
documentation="Histogram of the n request parameter.",
buckets=[1, 2, 5, 10, 20],
labelnames=labelnames).labels(*labelvalues)
histogram_n_request = self._histogram_cls(
name="vllm:request_params_n",
documentation="Histogram of the n request parameter.",
buckets=[1, 2, 5, 10, 20],
labelnames=labelnames)
self.histogram_n_request = make_per_engine(histogram_n_request,
engine_indexes, model_name)
self.histogram_max_tokens_request = \
self._histogram_cls(
name="vllm:request_params_max_tokens",
documentation="Histogram of the max_tokens request parameter.",
buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames).labels(*labelvalues)
histogram_max_tokens_request = self._histogram_cls(
name="vllm:request_params_max_tokens",
documentation="Histogram of the max_tokens request parameter.",
buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames)
self.histogram_max_tokens_request = make_per_engine(
histogram_max_tokens_request, engine_indexes, model_name)
#
# Histogram of timing intervals
#
self.histogram_time_to_first_token = \
self._histogram_cls(
name="vllm:time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.",
buckets=[
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0,
640.0, 2560.0
],
labelnames=labelnames).labels(*labelvalues)
histogram_time_to_first_token = self._histogram_cls(
name="vllm:time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.",
buckets=[
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0, 640.0,
2560.0
],
labelnames=labelnames)
self.histogram_time_to_first_token = make_per_engine(
histogram_time_to_first_token, engine_indexes, model_name)
self.histogram_time_per_output_token = \
self._histogram_cls(
name="vllm:time_per_output_token_seconds",
documentation="Histogram of time per output token in seconds.",
buckets=[
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5,
0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0
],
labelnames=labelnames).labels(*labelvalues)
histogram_time_per_output_token = self._histogram_cls(
name="vllm:time_per_output_token_seconds",
documentation="Histogram of time per output token in seconds.",
buckets=[
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0
],
labelnames=labelnames)
self.histogram_time_per_output_token = make_per_engine(
histogram_time_per_output_token, engine_indexes, model_name)
request_latency_buckets = [
0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0,
40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0
]
self.histogram_e2e_time_request = \
self._histogram_cls(
name="vllm:e2e_request_latency_seconds",
documentation="Histogram of e2e request latency in seconds.",
buckets=request_latency_buckets,
labelnames=labelnames).labels(*labelvalues)
self.histogram_queue_time_request = \
self._histogram_cls(
name="vllm:request_queue_time_seconds",
documentation=
"Histogram of time spent in WAITING phase for request.",
buckets=request_latency_buckets,
labelnames=labelnames).labels(*labelvalues)
self.histogram_inference_time_request = \
self._histogram_cls(
name="vllm:request_inference_time_seconds",
documentation=
"Histogram of time spent in RUNNING phase for request.",
buckets=request_latency_buckets,
labelnames=labelnames).labels(*labelvalues)
self.histogram_prefill_time_request = \
self._histogram_cls(
name="vllm:request_prefill_time_seconds",
documentation=
"Histogram of time spent in PREFILL phase for request.",
buckets=request_latency_buckets,
labelnames=labelnames).labels(*labelvalues)
self.histogram_decode_time_request = \
self._histogram_cls(
name="vllm:request_decode_time_seconds",
documentation=
"Histogram of time spent in DECODE phase for request.",
buckets=request_latency_buckets,
labelnames=labelnames).labels(*labelvalues)
histogram_e2e_time_request = self._histogram_cls(
name="vllm:e2e_request_latency_seconds",
documentation="Histogram of e2e request latency in seconds.",
buckets=request_latency_buckets,
labelnames=labelnames)
self.histogram_e2e_time_request = make_per_engine(
histogram_e2e_time_request, engine_indexes, model_name)
histogram_queue_time_request = self._histogram_cls(
name="vllm:request_queue_time_seconds",
documentation=
"Histogram of time spent in WAITING phase for request.",
buckets=request_latency_buckets,
labelnames=labelnames)
self.histogram_queue_time_request = make_per_engine(
histogram_queue_time_request, engine_indexes, model_name)
histogram_inference_time_request = self._histogram_cls(
name="vllm:request_inference_time_seconds",
documentation=
"Histogram of time spent in RUNNING phase for request.",
buckets=request_latency_buckets,
labelnames=labelnames)
self.histogram_inference_time_request = make_per_engine(
histogram_inference_time_request, engine_indexes, model_name)
histogram_prefill_time_request = self._histogram_cls(
name="vllm:request_prefill_time_seconds",
documentation=
"Histogram of time spent in PREFILL phase for request.",
buckets=request_latency_buckets,
labelnames=labelnames)
self.histogram_prefill_time_request = make_per_engine(
histogram_prefill_time_request, engine_indexes, model_name)
histogram_decode_time_request = self._histogram_cls(
name="vllm:request_decode_time_seconds",
documentation=
"Histogram of time spent in DECODE phase for request.",
buckets=request_latency_buckets,
labelnames=labelnames)
self.histogram_decode_time_request = make_per_engine(
histogram_decode_time_request, engine_indexes, model_name)
#
# LoRA metrics
@@ -382,6 +443,9 @@ class PrometheusStatLogger(StatLoggerBase):
# api_server counts which uses prometheus mp.
self.gauge_lora_info: Optional[prometheus_client.Gauge] = None
if vllm_config.lora_config is not None:
if len(self.engine_indexes) > 1:
raise NotImplementedError(
"LoRA in DP mode is not supported yet.")
self.labelname_max_lora = "max_lora"
self.labelname_waiting_lora_adapters = "waiting_lora_adapters"
self.labelname_running_lora_adapters = "running_lora_adapters"
@@ -399,9 +463,8 @@ class PrometheusStatLogger(StatLoggerBase):
)
def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo):
metrics_info = config_obj.metrics_info()
metrics_info["engine"] = self.engine_index
metrics_info["engine"] = ""
name, documentation = None, None
if type == "cache_config":
@@ -417,27 +480,36 @@
documentation=documentation,
multiprocess_mode="mostrecent",
labelnames=metrics_info.keys(),
).labels(**metrics_info)
info_gauge.set(1)
)
for engine_index in self.engine_indexes:
metrics_info = config_obj.metrics_info()
metrics_info["engine"] = str(engine_index)
info_gauge.labels(**metrics_info).set(1)
def record(self, scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats]):
def record(self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
engine_idx: int = 0):
"""Log to prometheus."""
if scheduler_stats is not None:
self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs)
self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs)
self.gauge_scheduler_running[engine_idx].set(
scheduler_stats.num_running_reqs)
self.gauge_scheduler_waiting[engine_idx].set(
scheduler_stats.num_waiting_reqs)
self.gauge_gpu_cache_usage.set(scheduler_stats.kv_cache_usage)
self.gauge_kv_cache_usage.set(scheduler_stats.kv_cache_usage)
self.gauge_gpu_cache_usage[engine_idx].set(
scheduler_stats.kv_cache_usage)
self.gauge_kv_cache_usage[engine_idx].set(
scheduler_stats.kv_cache_usage)
self.counter_gpu_prefix_cache_queries.inc(
self.counter_gpu_prefix_cache_queries[engine_idx].inc(
scheduler_stats.prefix_cache_stats.queries)
self.counter_gpu_prefix_cache_hits.inc(
self.counter_gpu_prefix_cache_hits[engine_idx].inc(
scheduler_stats.prefix_cache_stats.hits)
self.counter_prefix_cache_queries.inc(
self.counter_prefix_cache_queries[engine_idx].inc(
scheduler_stats.prefix_cache_stats.queries)
self.counter_prefix_cache_hits.inc(
self.counter_prefix_cache_hits[engine_idx].inc(
scheduler_stats.prefix_cache_stats.hits)
if scheduler_stats.spec_decoding_stats is not None:
@@ -447,42 +519,45 @@ class PrometheusStatLogger(StatLoggerBase):
if iteration_stats is None:
return
self.counter_num_preempted_reqs.inc(iteration_stats.num_preempted_reqs)
self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens)
self.counter_generation_tokens.inc(
self.counter_num_preempted_reqs[engine_idx].inc(
iteration_stats.num_preempted_reqs)
self.counter_prompt_tokens[engine_idx].inc(
iteration_stats.num_prompt_tokens)
self.counter_generation_tokens[engine_idx].inc(
iteration_stats.num_generation_tokens)
self.histogram_iteration_tokens.observe(
self.histogram_iteration_tokens[engine_idx].observe(
iteration_stats.num_prompt_tokens + \
iteration_stats.num_generation_tokens)
for max_gen_tokens in iteration_stats.max_num_generation_tokens_iter:
self.histogram_max_num_generation_tokens_request.observe(
max_gen_tokens)
self.histogram_max_num_generation_tokens_request[
engine_idx].observe(max_gen_tokens)
for n_param in iteration_stats.n_params_iter:
self.histogram_n_request.observe(n_param)
self.histogram_n_request[engine_idx].observe(n_param)
for ttft in iteration_stats.time_to_first_tokens_iter:
self.histogram_time_to_first_token.observe(ttft)
self.histogram_time_to_first_token[engine_idx].observe(ttft)
for tpot in iteration_stats.time_per_output_tokens_iter:
self.histogram_time_per_output_token.observe(tpot)
self.histogram_time_per_output_token[engine_idx].observe(tpot)
for finished_request in iteration_stats.finished_requests:
self.counter_request_success[finished_request.finish_reason].inc()
self.histogram_e2e_time_request.observe(
self.counter_request_success[
finished_request.finish_reason][engine_idx].inc()
self.histogram_e2e_time_request[engine_idx].observe(
finished_request.e2e_latency)
self.histogram_queue_time_request.observe(
self.histogram_queue_time_request[engine_idx].observe(
finished_request.queued_time)
self.histogram_prefill_time_request.observe(
self.histogram_prefill_time_request[engine_idx].observe(
finished_request.prefill_time)
self.histogram_inference_time_request.observe(
self.histogram_inference_time_request[engine_idx].observe(
finished_request.inference_time)
self.histogram_decode_time_request.observe(
self.histogram_decode_time_request[engine_idx].observe(
finished_request.decode_time)
self.histogram_num_prompt_tokens_request.observe(
self.histogram_num_prompt_tokens_request[engine_idx].observe(
finished_request.num_prompt_tokens)
self.histogram_num_generation_tokens_request.observe(
self.histogram_num_generation_tokens_request[engine_idx].observe(
finished_request.num_generation_tokens)
if finished_request.max_tokens_param:
self.histogram_max_tokens_request.observe(
self.histogram_max_tokens_request[engine_idx].observe(
finished_request.max_tokens_param)
if self.gauge_lora_info is not None:
@@ -502,6 +577,18 @@ class PrometheusStatLogger(StatLoggerBase):
self.log_metrics_info("cache_config", self.vllm_config.cache_config)
PromMetric = Union[
prometheus_client.Gauge,
prometheus_client.Counter,
prometheus_client.Histogram,
]
def make_per_engine(metric: PromMetric, engine_idxs: list[int],
model_name: str) -> dict[int, PromMetric]:
return {idx: metric.labels(model_name, str(idx)) for idx in engine_idxs}
def build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]:
"""
Builds a list of buckets with increasing powers of 10 multiplied by
@@ -529,29 +616,79 @@ def build_1_2_5_buckets(max_value: int) -> list[int]:
return build_buckets([1, 2, 5], max_value)
def setup_default_loggers(
vllm_config: VllmConfig,
log_stats: bool,
engine_num: int,
custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
) -> list[list[StatLoggerBase]]:
"""Setup logging and prometheus metrics."""
if not log_stats:
return []
class StatLoggerManager:
"""
StatLoggerManager:
Logging happens at the level of the EngineCore (per scheduler).
* DP: >1 EngineCore per AsyncLLM - loggers for each EngineCore.
* With Local Logger, just make N copies for N EngineCores.
* With Prometheus, we need a single logger with N "labels"
factories: list[StatLoggerFactory]
if custom_stat_loggers is not None:
factories = custom_stat_loggers
else:
factories = [PrometheusStatLogger]
if logger.isEnabledFor(logging.INFO):
factories.append(LoggingStatLogger)
This class abstracts away this implementation detail from
the AsyncLLM, allowing the AsyncLLM to just call .record()
and .log() to a simple interface.
"""
stat_loggers: list[list[StatLoggerBase]] = []
for i in range(engine_num):
per_engine_stat_loggers: list[StatLoggerBase] = []
for logger_factory in factories:
per_engine_stat_loggers.append(logger_factory(vllm_config, i))
stat_loggers.append(per_engine_stat_loggers)
def __init__(
self,
vllm_config: VllmConfig,
engine_idxs: Optional[list[int]] = None,
custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
):
self.engine_idxs = engine_idxs if engine_idxs else [0]
return stat_loggers
factories: list[StatLoggerFactory]
if custom_stat_loggers is not None:
factories = custom_stat_loggers
else:
factories = []
if logger.isEnabledFor(logging.INFO):
factories.append(LoggingStatLogger)
# engine_idx: StatLogger
self.per_engine_logger_dict: dict[int, list[StatLoggerBase]] = {}
prometheus_factory = PrometheusStatLogger
for engine_idx in self.engine_idxs:
loggers: list[StatLoggerBase] = []
for logger_factory in factories:
# If we get a custom prometheus logger, use that
# instead. This is typically used for the ray case.
if (isinstance(logger_factory, type)
and issubclass(logger_factory, PrometheusStatLogger)):
prometheus_factory = logger_factory
continue
loggers.append(logger_factory(vllm_config,
engine_idx)) # type: ignore
self.per_engine_logger_dict[engine_idx] = loggers
# For Prometheus, need to share the metrics between EngineCores.
# Each EngineCore's metrics are expressed as a unique label.
self.prometheus_logger = prometheus_factory(vllm_config, engine_idxs)
def record(
self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
engine_idx: Optional[int] = None,
):
if engine_idx is None:
engine_idx = 0
per_engine_loggers = self.per_engine_logger_dict[engine_idx]
for logger in per_engine_loggers:
logger.record(scheduler_stats, iteration_stats, engine_idx)
self.prometheus_logger.record(scheduler_stats, iteration_stats,
engine_idx)
def log(self):
for per_engine_loggers in self.per_engine_logger_dict.values():
for logger in per_engine_loggers:
logger.log()
def log_engine_initialized(self):
self.prometheus_logger.log_engine_initialized()
for per_engine_loggers in self.per_engine_logger_dict.values():
for logger in per_engine_loggers:
logger.log_engine_initialized()
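A minimal, runnable sketch of the make_per_engine() pattern used throughout the new PrometheusStatLogger: one metric family with ("model_name", "engine") labels, and the labeled children keyed by engine index so record(..., engine_idx=...) reduces to a dict lookup. The metric values below are illustrative and an isolated registry is used; this is not part of the diff:

import prometheus_client

registry = prometheus_client.CollectorRegistry()
counter = prometheus_client.Counter(
    name="vllm:prompt_tokens",
    documentation="Number of prefill tokens processed.",
    labelnames=["model_name", "engine"],
    registry=registry,
)

# Same shape as the make_per_engine() helper introduced above.
def make_per_engine(metric, engine_idxs, model_name):
    return {idx: metric.labels(model_name, str(idx)) for idx in engine_idxs}

per_engine = make_per_engine(counter, engine_idxs=[0, 1], model_name="demo-model")

# Recording for a given EngineCore is an O(1) lookup plus a normal inc().
per_engine[0].inc(128)
per_engine[1].inc(64)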

View File

@@ -3,7 +3,6 @@
import time
from typing import Optional, Union
from vllm.config import VllmConfig
from vllm.v1.metrics.loggers import PrometheusStatLogger
from vllm.v1.spec_decode.metrics import SpecDecodingProm
@@ -128,9 +127,6 @@ class RayPrometheusStatLogger(PrometheusStatLogger):
_histogram_cls = RayHistogramWrapper
_spec_decoding_cls = RaySpecDecodingProm
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
super().__init__(vllm_config, engine_index)
@staticmethod
def _unregister_vllm_metrics():
# No-op on purpose
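Because StatLoggerManager treats any PrometheusStatLogger subclass in the custom stat_loggers list as the shared Prometheus factory (the issubclass() check in the loggers.py hunk above), RayPrometheusStatLogger no longer needs its own __init__; it only swaps the metric classes. A hedged sketch of that pattern, with an illustrative class name that is not part of the diff:

from vllm.v1.metrics.loggers import PrometheusStatLogger


class CustomPrometheusStatLogger(PrometheusStatLogger):
    """Illustrative subclass; override _gauge_cls / _counter_cls /
    _histogram_cls here the way RayPrometheusStatLogger does above."""
    pass

# When passed via the stat_loggers list, StatLoggerManager uses this class as
# its prometheus_factory instead of the default PrometheusStatLogger, so all
# EngineCores still share a single Prometheus logger keyed by the "engine" label.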