mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-27 17:26:01 +08:00
[DP] Fix Prometheus Logging (#21257)
Signed-off-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
parent
304dce7ec0
commit
29d1ffc5b4
@ -336,9 +336,10 @@ async def test_customize_loggers(monkeypatch):
|
||||
|
||||
await engine.do_log_stats()
|
||||
|
||||
assert len(engine.stat_loggers) == 1
|
||||
assert len(engine.stat_loggers[0]) == 1
|
||||
engine.stat_loggers[0][0].log.assert_called_once()
|
||||
stat_loggers = engine.logger_manager.per_engine_logger_dict
|
||||
assert len(stat_loggers) == 1
|
||||
assert len(stat_loggers[0]) == 1
|
||||
stat_loggers[0][0].log.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="module")
|
||||
|
||||
@ -90,8 +90,10 @@ async def test_load(output_kind: RequestOutputKind,
|
||||
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
|
||||
stats_loggers[engine_index] = self
|
||||
|
||||
def record(self, scheduler_stats: Optional[SchedulerStats],
|
||||
iteration_stats: Optional[IterationStats]):
|
||||
def record(self,
|
||||
scheduler_stats: Optional[SchedulerStats],
|
||||
iteration_stats: Optional[IterationStats],
|
||||
engine_idx: int = 0):
|
||||
if iteration_stats:
|
||||
self.finished_req_count += len(
|
||||
iteration_stats.finished_requests)
|
||||
|
||||
@ -36,10 +36,9 @@ from vllm.v1.engine.output_processor import (OutputProcessor,
|
||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||
from vllm.v1.engine.processor import Processor
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
|
||||
setup_default_loggers)
|
||||
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
|
||||
from vllm.v1.metrics.prometheus import shutdown_prometheus
|
||||
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
|
||||
from vllm.v1.metrics.stats import IterationStats
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -95,14 +94,6 @@ class AsyncLLM(EngineClient):
|
||||
self.log_requests = log_requests
|
||||
self.log_stats = log_stats
|
||||
|
||||
# Set up stat loggers; independent set for each DP rank.
|
||||
self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
|
||||
vllm_config=vllm_config,
|
||||
log_stats=self.log_stats,
|
||||
engine_num=vllm_config.parallel_config.data_parallel_size,
|
||||
custom_stat_loggers=stat_loggers,
|
||||
)
|
||||
|
||||
# Tokenizer (+ ensure liveness if running in another process).
|
||||
self.tokenizer = init_tokenizer_from_configs(
|
||||
model_config=vllm_config.model_config,
|
||||
@ -121,7 +112,6 @@ class AsyncLLM(EngineClient):
|
||||
log_stats=self.log_stats)
|
||||
|
||||
# EngineCore (starts the engine in background process).
|
||||
|
||||
self.engine_core = EngineCoreClient.make_async_mp_client(
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
@ -129,9 +119,17 @@ class AsyncLLM(EngineClient):
|
||||
client_addresses=client_addresses,
|
||||
client_index=client_index,
|
||||
)
|
||||
if self.stat_loggers:
|
||||
for stat_logger in self.stat_loggers[0]:
|
||||
stat_logger.log_engine_initialized()
|
||||
|
||||
# Loggers.
|
||||
self.logger_manager: Optional[StatLoggerManager] = None
|
||||
if self.log_stats:
|
||||
self.logger_manager = StatLoggerManager(
|
||||
vllm_config=vllm_config,
|
||||
engine_idxs=self.engine_core.engine_ranks,
|
||||
custom_stat_loggers=stat_loggers,
|
||||
)
|
||||
self.logger_manager.log_engine_initialized()
|
||||
|
||||
self.output_handler: Optional[asyncio.Task] = None
|
||||
try:
|
||||
# Start output handler eagerly if we are in the asyncio eventloop.
|
||||
@ -370,7 +368,7 @@ class AsyncLLM(EngineClient):
|
||||
engine_core = self.engine_core
|
||||
output_processor = self.output_processor
|
||||
log_stats = self.log_stats
|
||||
stat_loggers = self.stat_loggers if log_stats else None
|
||||
logger_manager = self.logger_manager
|
||||
|
||||
async def output_handler():
|
||||
try:
|
||||
@ -410,9 +408,9 @@ class AsyncLLM(EngineClient):
|
||||
# 4) Logging.
|
||||
# TODO(rob): make into a coroutine and launch it in
|
||||
# background thread once Prometheus overhead is non-trivial.
|
||||
if stat_loggers:
|
||||
AsyncLLM._record_stats(
|
||||
stat_loggers[outputs.engine_index],
|
||||
if logger_manager:
|
||||
logger_manager.record(
|
||||
engine_idx=outputs.engine_index,
|
||||
scheduler_stats=outputs.scheduler_stats,
|
||||
iteration_stats=iteration_stats,
|
||||
)
|
||||
@ -431,18 +429,6 @@ class AsyncLLM(EngineClient):
|
||||
if self.log_requests:
|
||||
logger.info("Aborted request %s.", request_id)
|
||||
|
||||
@staticmethod
|
||||
def _record_stats(
|
||||
stat_loggers: list[StatLoggerBase],
|
||||
scheduler_stats: Optional[SchedulerStats],
|
||||
iteration_stats: Optional[IterationStats],
|
||||
):
|
||||
"""static so that it can be used from the output_handler task
|
||||
without a circular ref to AsyncLLM."""
|
||||
for stat_logger in stat_loggers:
|
||||
stat_logger.record(scheduler_stats=scheduler_stats,
|
||||
iteration_stats=iteration_stats)
|
||||
|
||||
async def encode(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
@ -547,9 +533,8 @@ class AsyncLLM(EngineClient):
|
||||
scheduler_outputs=None,
|
||||
model_output=None,
|
||||
) -> None:
|
||||
for loggers in self.stat_loggers:
|
||||
for stat_logger in loggers:
|
||||
stat_logger.log()
|
||||
if self.logger_manager:
|
||||
self.logger_manager.log()
|
||||
|
||||
async def check_health(self) -> None:
|
||||
logger.debug("Called check_health.")
|
||||
@ -653,18 +638,16 @@ class AsyncLLM(EngineClient):
|
||||
new_data_parallel_size
|
||||
|
||||
# recreate stat loggers
|
||||
if new_data_parallel_size > old_data_parallel_size:
|
||||
stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
|
||||
if new_data_parallel_size > old_data_parallel_size and self.log_stats:
|
||||
# TODO(rob): fix this after talking with Ray team.
|
||||
# This resets all the prometheus metrics since we
|
||||
# unregister during initialization. Need to understand
|
||||
# the intended behavior here better.
|
||||
self.logger_manager = StatLoggerManager(
|
||||
vllm_config=self.vllm_config,
|
||||
log_stats=self.log_stats,
|
||||
engine_num=new_data_parallel_size,
|
||||
engine_idxs=list(range(new_data_parallel_size)),
|
||||
custom_stat_loggers=None,
|
||||
)
|
||||
num_new_engines = len(stat_loggers) - len(self.stat_loggers)
|
||||
self.stat_loggers.extend(stat_loggers[-num_new_engines:])
|
||||
else:
|
||||
for _ in range(old_data_parallel_size - new_data_parallel_size):
|
||||
self.stat_loggers.pop()
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
|
||||
@ -432,14 +432,15 @@ class MPClient(EngineCoreClient):
|
||||
external_dp_lb = parallel_config.data_parallel_external_lb
|
||||
|
||||
offline_mode = parallel_config.data_parallel_rank_local is not None
|
||||
engine_ranks = [dp_rank] if (offline_mode
|
||||
or external_dp_lb) else range(dp_size)
|
||||
self.engine_ranks = ([dp_rank] if
|
||||
(offline_mode or external_dp_lb) else list(
|
||||
range(dp_size)))
|
||||
assert parallel_config.data_parallel_size_local <= len(
|
||||
engine_ranks)
|
||||
self.engine_ranks)
|
||||
|
||||
# ZMQ identity of each engine that this client will talk to.
|
||||
self.core_engines: list[EngineIdentity] = [
|
||||
index.to_bytes(2, "little") for index in engine_ranks
|
||||
index.to_bytes(2, "little") for index in self.engine_ranks
|
||||
]
|
||||
|
||||
# Wait for ready messages from each engine on the input socket.
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import prometheus_client
|
||||
@ -35,8 +35,10 @@ class StatLoggerBase(ABC):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def record(self, scheduler_stats: Optional[SchedulerStats],
|
||||
iteration_stats: Optional[IterationStats]):
|
||||
def record(self,
|
||||
scheduler_stats: Optional[SchedulerStats],
|
||||
iteration_stats: Optional[IterationStats],
|
||||
engine_idx: int = 0):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
@ -78,8 +80,10 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
# Compute summary metrics for tracked stats
|
||||
return float(np.sum(tracked_stats) / (now - self.last_log_time))
|
||||
|
||||
def record(self, scheduler_stats: Optional[SchedulerStats],
|
||||
iteration_stats: Optional[IterationStats]):
|
||||
def record(self,
|
||||
scheduler_stats: Optional[SchedulerStats],
|
||||
iteration_stats: Optional[IterationStats],
|
||||
engine_idx: int = 0):
|
||||
"""Log Stats to standard output."""
|
||||
|
||||
if iteration_stats:
|
||||
@ -146,233 +150,290 @@ class PrometheusStatLogger(StatLoggerBase):
|
||||
_histogram_cls = prometheus_client.Histogram
|
||||
_spec_decoding_cls = SpecDecodingProm
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
engine_indexes: Optional[list[int]] = None):
|
||||
if engine_indexes is None:
|
||||
engine_indexes = [0]
|
||||
self.engine_indexes = engine_indexes
|
||||
|
||||
unregister_vllm_metrics()
|
||||
self.vllm_config = vllm_config
|
||||
self.engine_index = engine_index
|
||||
# Use this flag to hide metrics that were deprecated in
|
||||
# a previous release and which will be removed future
|
||||
self.show_hidden_metrics = \
|
||||
vllm_config.observability_config.show_hidden_metrics
|
||||
|
||||
labelnames = ["model_name", "engine"]
|
||||
labelvalues = [
|
||||
vllm_config.model_config.served_model_name,
|
||||
str(engine_index)
|
||||
]
|
||||
|
||||
model_name = vllm_config.model_config.served_model_name
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
|
||||
if (len(self.engine_indexes) > 1
|
||||
and vllm_config.speculative_config is not None):
|
||||
raise NotImplementedError("Prometheus metrics with Spec Decoding "
|
||||
"with >1 EngineCore per AsyncLLM is not "
|
||||
"supported yet.")
|
||||
spec_decode_labelvalues = [
|
||||
vllm_config.model_config.served_model_name,
|
||||
str(self.engine_indexes[0])
|
||||
]
|
||||
self.spec_decoding_prom = self._spec_decoding_cls(
|
||||
vllm_config.speculative_config, labelnames, labelvalues)
|
||||
vllm_config.speculative_config, labelnames,
|
||||
spec_decode_labelvalues)
|
||||
|
||||
#
|
||||
# Scheduler state
|
||||
#
|
||||
self.gauge_scheduler_running = self._gauge_cls(
|
||||
gauge_scheduler_running = self._gauge_cls(
|
||||
name="vllm:num_requests_running",
|
||||
documentation="Number of requests in model execution batches.",
|
||||
multiprocess_mode="mostrecent",
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
labelnames=labelnames)
|
||||
self.gauge_scheduler_running = make_per_engine(gauge_scheduler_running,
|
||||
engine_indexes,
|
||||
model_name)
|
||||
|
||||
self.gauge_scheduler_waiting = self._gauge_cls(
|
||||
gauge_scheduler_waiting = self._gauge_cls(
|
||||
name="vllm:num_requests_waiting",
|
||||
documentation="Number of requests waiting to be processed.",
|
||||
multiprocess_mode="mostrecent",
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
labelnames=labelnames)
|
||||
self.gauge_scheduler_waiting = make_per_engine(gauge_scheduler_waiting,
|
||||
engine_indexes,
|
||||
model_name)
|
||||
|
||||
#
|
||||
# GPU cache
|
||||
#
|
||||
# Deprecated in 0.9 - Renamed as vllm:kv_cache_usage_perc
|
||||
# TODO: in 0.10, only enable if show_hidden_metrics=True
|
||||
self.gauge_gpu_cache_usage = self._gauge_cls(
|
||||
gauge_gpu_cache_usage = self._gauge_cls(
|
||||
name="vllm:gpu_cache_usage_perc",
|
||||
documentation=(
|
||||
"GPU KV-cache usage. 1 means 100 percent usage."
|
||||
"DEPRECATED: Use vllm:kv_cache_usage_perc instead."),
|
||||
multiprocess_mode="mostrecent",
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
labelnames=labelnames)
|
||||
self.gauge_gpu_cache_usage = make_per_engine(gauge_gpu_cache_usage,
|
||||
engine_indexes,
|
||||
model_name)
|
||||
|
||||
# Deprecated in 0.9 - Renamed as vllm:prefix_cache_queries
|
||||
# TODO: in 0.10, only enable if show_hidden_metrics=True
|
||||
self.counter_gpu_prefix_cache_queries = self._counter_cls(
|
||||
counter_gpu_prefix_cache_queries = self._counter_cls(
|
||||
name="vllm:gpu_prefix_cache_queries",
|
||||
documentation=
|
||||
("GPU prefix cache queries, in terms of number of queried tokens."
|
||||
"DEPRECATED: Use vllm:prefix_cache_queries instead."),
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
documentation=(
|
||||
"GPU prefix cache queries, in terms of number of queried"
|
||||
"tokens. DEPRECATED: Use vllm:prefix_cache_queries instead."),
|
||||
labelnames=labelnames)
|
||||
self.counter_gpu_prefix_cache_queries = make_per_engine(
|
||||
counter_gpu_prefix_cache_queries, engine_indexes, model_name)
|
||||
|
||||
# Deprecated in 0.9 - Renamed as vllm:prefix_cache_hits
|
||||
# TODO: in 0.10, only enable if show_hidden_metrics=True
|
||||
self.counter_gpu_prefix_cache_hits = self._counter_cls(
|
||||
counter_gpu_prefix_cache_hits = self._counter_cls(
|
||||
name="vllm:gpu_prefix_cache_hits",
|
||||
documentation=(
|
||||
"GPU prefix cache hits, in terms of number of cached tokens."
|
||||
"DEPRECATED: Use vllm:prefix_cache_hits instead."),
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
"GPU prefix cache hits, in terms of number of cached "
|
||||
"tokens. DEPRECATED: Use vllm:prefix_cache_hits instead."),
|
||||
labelnames=labelnames)
|
||||
self.counter_gpu_prefix_cache_hits = make_per_engine(
|
||||
counter_gpu_prefix_cache_hits, engine_indexes, model_name)
|
||||
|
||||
self.gauge_kv_cache_usage = self._gauge_cls(
|
||||
gauge_kv_cache_usage = self._gauge_cls(
|
||||
name="vllm:kv_cache_usage_perc",
|
||||
documentation="KV-cache usage. 1 means 100 percent usage.",
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
labelnames=labelnames)
|
||||
self.gauge_kv_cache_usage = make_per_engine(gauge_kv_cache_usage,
|
||||
engine_indexes, model_name)
|
||||
|
||||
self.counter_prefix_cache_queries = self._counter_cls(
|
||||
counter_prefix_cache_queries = self._counter_cls(
|
||||
name="vllm:prefix_cache_queries",
|
||||
documentation=(
|
||||
"Prefix cache queries, in terms of number of queried tokens."),
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
labelnames=labelnames)
|
||||
self.counter_prefix_cache_queries = make_per_engine(
|
||||
counter_prefix_cache_queries, engine_indexes, model_name)
|
||||
|
||||
self.counter_prefix_cache_hits = self._counter_cls(
|
||||
counter_prefix_cache_hits = self._counter_cls(
|
||||
name="vllm:prefix_cache_hits",
|
||||
documentation=(
|
||||
"Prefix cache hits, in terms of number of cached tokens."),
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
labelnames=labelnames)
|
||||
self.counter_prefix_cache_hits = make_per_engine(
|
||||
counter_prefix_cache_hits, engine_indexes, model_name)
|
||||
|
||||
#
|
||||
# Counters
|
||||
#
|
||||
self.counter_num_preempted_reqs = self._counter_cls(
|
||||
counter_num_preempted_reqs = self._counter_cls(
|
||||
name="vllm:num_preemptions",
|
||||
documentation="Cumulative number of preemption from the engine.",
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
labelnames=labelnames)
|
||||
self.counter_num_preempted_reqs = make_per_engine(
|
||||
counter_num_preempted_reqs, engine_indexes, model_name)
|
||||
|
||||
self.counter_prompt_tokens = self._counter_cls(
|
||||
counter_prompt_tokens = self._counter_cls(
|
||||
name="vllm:prompt_tokens",
|
||||
documentation="Number of prefill tokens processed.",
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
labelnames=labelnames)
|
||||
self.counter_prompt_tokens = make_per_engine(counter_prompt_tokens,
|
||||
engine_indexes,
|
||||
model_name)
|
||||
|
||||
self.counter_generation_tokens = self._counter_cls(
|
||||
counter_generation_tokens = self._counter_cls(
|
||||
name="vllm:generation_tokens",
|
||||
documentation="Number of generation tokens processed.",
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
labelnames=labelnames)
|
||||
self.counter_generation_tokens = make_per_engine(
|
||||
counter_generation_tokens, engine_indexes, model_name)
|
||||
|
||||
self.counter_request_success: dict[FinishReason,
|
||||
prometheus_client.Counter] = {}
|
||||
self.counter_request_success: dict[FinishReason, dict[
|
||||
int, prometheus_client.Counter]] = {}
|
||||
counter_request_success_base = self._counter_cls(
|
||||
name="vllm:request_success",
|
||||
documentation="Count of successfully processed requests.",
|
||||
labelnames=labelnames + ["finished_reason"])
|
||||
for reason in FinishReason:
|
||||
self.counter_request_success[
|
||||
reason] = counter_request_success_base.labels(*(labelvalues +
|
||||
[str(reason)]))
|
||||
self.counter_request_success[reason] = {
|
||||
idx:
|
||||
counter_request_success_base.labels(model_name, str(idx),
|
||||
str(reason))
|
||||
for idx in engine_indexes
|
||||
}
|
||||
|
||||
#
|
||||
# Histograms of counts
|
||||
#
|
||||
self.histogram_num_prompt_tokens_request = \
|
||||
self._histogram_cls(
|
||||
name="vllm:request_prompt_tokens",
|
||||
documentation="Number of prefill tokens processed.",
|
||||
buckets=build_1_2_5_buckets(max_model_len),
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
histogram_num_prompt_tokens_request = self._histogram_cls(
|
||||
name="vllm:request_prompt_tokens",
|
||||
documentation="Number of prefill tokens processed.",
|
||||
buckets=build_1_2_5_buckets(max_model_len),
|
||||
labelnames=labelnames)
|
||||
self.histogram_num_prompt_tokens_request = make_per_engine(
|
||||
histogram_num_prompt_tokens_request, engine_indexes, model_name)
|
||||
|
||||
self.histogram_num_generation_tokens_request = \
|
||||
self._histogram_cls(
|
||||
name="vllm:request_generation_tokens",
|
||||
documentation="Number of generation tokens processed.",
|
||||
buckets=build_1_2_5_buckets(max_model_len),
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
histogram_num_generation_tokens_request = self._histogram_cls(
|
||||
name="vllm:request_generation_tokens",
|
||||
documentation="Number of generation tokens processed.",
|
||||
buckets=build_1_2_5_buckets(max_model_len),
|
||||
labelnames=labelnames)
|
||||
self.histogram_num_generation_tokens_request = make_per_engine(
|
||||
histogram_num_generation_tokens_request, engine_indexes,
|
||||
model_name)
|
||||
|
||||
# TODO: This metric might be incorrect in case of using multiple
|
||||
# api_server counts which uses prometheus mp.
|
||||
# See: https://github.com/vllm-project/vllm/pull/18053
|
||||
self.histogram_iteration_tokens = \
|
||||
self._histogram_cls(
|
||||
name="vllm:iteration_tokens_total",
|
||||
documentation="Histogram of number of tokens per engine_step.",
|
||||
buckets=[
|
||||
1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192,
|
||||
16384
|
||||
],
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
histogram_iteration_tokens = self._histogram_cls(
|
||||
name="vllm:iteration_tokens_total",
|
||||
documentation="Histogram of number of tokens per engine_step.",
|
||||
buckets=[
|
||||
1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384
|
||||
],
|
||||
labelnames=labelnames)
|
||||
self.histogram_iteration_tokens = make_per_engine(
|
||||
histogram_iteration_tokens, engine_indexes, model_name)
|
||||
|
||||
self.histogram_max_num_generation_tokens_request = \
|
||||
self._histogram_cls(
|
||||
name="vllm:request_max_num_generation_tokens",
|
||||
documentation=
|
||||
"Histogram of maximum number of requested generation tokens.",
|
||||
buckets=build_1_2_5_buckets(max_model_len),
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
histogram_max_num_generation_tokens_request = self._histogram_cls(
|
||||
name="vllm:request_max_num_generation_tokens",
|
||||
documentation=
|
||||
"Histogram of maximum number of requested generation tokens.",
|
||||
buckets=build_1_2_5_buckets(max_model_len),
|
||||
labelnames=labelnames)
|
||||
self.histogram_max_num_generation_tokens_request = make_per_engine(
|
||||
histogram_max_num_generation_tokens_request, engine_indexes,
|
||||
model_name)
|
||||
|
||||
self.histogram_n_request = \
|
||||
self._histogram_cls(
|
||||
name="vllm:request_params_n",
|
||||
documentation="Histogram of the n request parameter.",
|
||||
buckets=[1, 2, 5, 10, 20],
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
histogram_n_request = self._histogram_cls(
|
||||
name="vllm:request_params_n",
|
||||
documentation="Histogram of the n request parameter.",
|
||||
buckets=[1, 2, 5, 10, 20],
|
||||
labelnames=labelnames)
|
||||
self.histogram_n_request = make_per_engine(histogram_n_request,
|
||||
engine_indexes, model_name)
|
||||
|
||||
self.histogram_max_tokens_request = \
|
||||
self._histogram_cls(
|
||||
name="vllm:request_params_max_tokens",
|
||||
documentation="Histogram of the max_tokens request parameter.",
|
||||
buckets=build_1_2_5_buckets(max_model_len),
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
histogram_max_tokens_request = self._histogram_cls(
|
||||
name="vllm:request_params_max_tokens",
|
||||
documentation="Histogram of the max_tokens request parameter.",
|
||||
buckets=build_1_2_5_buckets(max_model_len),
|
||||
labelnames=labelnames)
|
||||
self.histogram_max_tokens_request = make_per_engine(
|
||||
histogram_max_tokens_request, engine_indexes, model_name)
|
||||
|
||||
#
|
||||
# Histogram of timing intervals
|
||||
#
|
||||
self.histogram_time_to_first_token = \
|
||||
self._histogram_cls(
|
||||
name="vllm:time_to_first_token_seconds",
|
||||
documentation="Histogram of time to first token in seconds.",
|
||||
buckets=[
|
||||
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
|
||||
0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0,
|
||||
640.0, 2560.0
|
||||
],
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
histogram_time_to_first_token = self._histogram_cls(
|
||||
name="vllm:time_to_first_token_seconds",
|
||||
documentation="Histogram of time to first token in seconds.",
|
||||
buckets=[
|
||||
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
|
||||
0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0, 640.0,
|
||||
2560.0
|
||||
],
|
||||
labelnames=labelnames)
|
||||
self.histogram_time_to_first_token = make_per_engine(
|
||||
histogram_time_to_first_token, engine_indexes, model_name)
|
||||
|
||||
self.histogram_time_per_output_token = \
|
||||
self._histogram_cls(
|
||||
name="vllm:time_per_output_token_seconds",
|
||||
documentation="Histogram of time per output token in seconds.",
|
||||
buckets=[
|
||||
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5,
|
||||
0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0
|
||||
],
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
histogram_time_per_output_token = self._histogram_cls(
|
||||
name="vllm:time_per_output_token_seconds",
|
||||
documentation="Histogram of time per output token in seconds.",
|
||||
buckets=[
|
||||
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
|
||||
1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0
|
||||
],
|
||||
labelnames=labelnames)
|
||||
self.histogram_time_per_output_token = make_per_engine(
|
||||
histogram_time_per_output_token, engine_indexes, model_name)
|
||||
|
||||
request_latency_buckets = [
|
||||
0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0,
|
||||
40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0
|
||||
]
|
||||
self.histogram_e2e_time_request = \
|
||||
self._histogram_cls(
|
||||
name="vllm:e2e_request_latency_seconds",
|
||||
documentation="Histogram of e2e request latency in seconds.",
|
||||
buckets=request_latency_buckets,
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
self.histogram_queue_time_request = \
|
||||
self._histogram_cls(
|
||||
name="vllm:request_queue_time_seconds",
|
||||
documentation=
|
||||
"Histogram of time spent in WAITING phase for request.",
|
||||
buckets=request_latency_buckets,
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
self.histogram_inference_time_request = \
|
||||
self._histogram_cls(
|
||||
name="vllm:request_inference_time_seconds",
|
||||
documentation=
|
||||
"Histogram of time spent in RUNNING phase for request.",
|
||||
buckets=request_latency_buckets,
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
self.histogram_prefill_time_request = \
|
||||
self._histogram_cls(
|
||||
name="vllm:request_prefill_time_seconds",
|
||||
documentation=
|
||||
"Histogram of time spent in PREFILL phase for request.",
|
||||
buckets=request_latency_buckets,
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
self.histogram_decode_time_request = \
|
||||
self._histogram_cls(
|
||||
name="vllm:request_decode_time_seconds",
|
||||
documentation=
|
||||
"Histogram of time spent in DECODE phase for request.",
|
||||
buckets=request_latency_buckets,
|
||||
labelnames=labelnames).labels(*labelvalues)
|
||||
histogram_e2e_time_request = self._histogram_cls(
|
||||
name="vllm:e2e_request_latency_seconds",
|
||||
documentation="Histogram of e2e request latency in seconds.",
|
||||
buckets=request_latency_buckets,
|
||||
labelnames=labelnames)
|
||||
self.histogram_e2e_time_request = make_per_engine(
|
||||
histogram_e2e_time_request, engine_indexes, model_name)
|
||||
|
||||
histogram_queue_time_request = self._histogram_cls(
|
||||
name="vllm:request_queue_time_seconds",
|
||||
documentation=
|
||||
"Histogram of time spent in WAITING phase for request.",
|
||||
buckets=request_latency_buckets,
|
||||
labelnames=labelnames)
|
||||
self.histogram_queue_time_request = make_per_engine(
|
||||
histogram_queue_time_request, engine_indexes, model_name)
|
||||
|
||||
histogram_inference_time_request = self._histogram_cls(
|
||||
name="vllm:request_inference_time_seconds",
|
||||
documentation=
|
||||
"Histogram of time spent in RUNNING phase for request.",
|
||||
buckets=request_latency_buckets,
|
||||
labelnames=labelnames)
|
||||
self.histogram_inference_time_request = make_per_engine(
|
||||
histogram_inference_time_request, engine_indexes, model_name)
|
||||
|
||||
histogram_prefill_time_request = self._histogram_cls(
|
||||
name="vllm:request_prefill_time_seconds",
|
||||
documentation=
|
||||
"Histogram of time spent in PREFILL phase for request.",
|
||||
buckets=request_latency_buckets,
|
||||
labelnames=labelnames)
|
||||
self.histogram_prefill_time_request = make_per_engine(
|
||||
histogram_prefill_time_request, engine_indexes, model_name)
|
||||
|
||||
histogram_decode_time_request = self._histogram_cls(
|
||||
name="vllm:request_decode_time_seconds",
|
||||
documentation=
|
||||
"Histogram of time spent in DECODE phase for request.",
|
||||
buckets=request_latency_buckets,
|
||||
labelnames=labelnames)
|
||||
self.histogram_decode_time_request = make_per_engine(
|
||||
histogram_decode_time_request, engine_indexes, model_name)
|
||||
|
||||
#
|
||||
# LoRA metrics
|
||||
@ -382,6 +443,9 @@ class PrometheusStatLogger(StatLoggerBase):
|
||||
# api_server counts which uses prometheus mp.
|
||||
self.gauge_lora_info: Optional[prometheus_client.Gauge] = None
|
||||
if vllm_config.lora_config is not None:
|
||||
if len(self.engine_indexes) > 1:
|
||||
raise NotImplementedError(
|
||||
"LoRA in DP mode is not supported yet.")
|
||||
self.labelname_max_lora = "max_lora"
|
||||
self.labelname_waiting_lora_adapters = "waiting_lora_adapters"
|
||||
self.labelname_running_lora_adapters = "running_lora_adapters"
|
||||
@ -399,9 +463,8 @@ class PrometheusStatLogger(StatLoggerBase):
|
||||
)
|
||||
|
||||
def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo):
|
||||
|
||||
metrics_info = config_obj.metrics_info()
|
||||
metrics_info["engine"] = self.engine_index
|
||||
metrics_info["engine"] = ""
|
||||
|
||||
name, documentation = None, None
|
||||
if type == "cache_config":
|
||||
@ -417,27 +480,36 @@ class PrometheusStatLogger(StatLoggerBase):
|
||||
documentation=documentation,
|
||||
multiprocess_mode="mostrecent",
|
||||
labelnames=metrics_info.keys(),
|
||||
).labels(**metrics_info)
|
||||
info_gauge.set(1)
|
||||
)
|
||||
for engine_index in self.engine_indexes:
|
||||
metrics_info = config_obj.metrics_info()
|
||||
metrics_info["engine"] = str(engine_index)
|
||||
info_gauge.labels(**metrics_info).set(1)
|
||||
|
||||
def record(self, scheduler_stats: Optional[SchedulerStats],
|
||||
iteration_stats: Optional[IterationStats]):
|
||||
def record(self,
|
||||
scheduler_stats: Optional[SchedulerStats],
|
||||
iteration_stats: Optional[IterationStats],
|
||||
engine_idx: int = 0):
|
||||
"""Log to prometheus."""
|
||||
if scheduler_stats is not None:
|
||||
self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs)
|
||||
self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs)
|
||||
self.gauge_scheduler_running[engine_idx].set(
|
||||
scheduler_stats.num_running_reqs)
|
||||
self.gauge_scheduler_waiting[engine_idx].set(
|
||||
scheduler_stats.num_waiting_reqs)
|
||||
|
||||
self.gauge_gpu_cache_usage.set(scheduler_stats.kv_cache_usage)
|
||||
self.gauge_kv_cache_usage.set(scheduler_stats.kv_cache_usage)
|
||||
self.gauge_gpu_cache_usage[engine_idx].set(
|
||||
scheduler_stats.kv_cache_usage)
|
||||
self.gauge_kv_cache_usage[engine_idx].set(
|
||||
scheduler_stats.kv_cache_usage)
|
||||
|
||||
self.counter_gpu_prefix_cache_queries.inc(
|
||||
self.counter_gpu_prefix_cache_queries[engine_idx].inc(
|
||||
scheduler_stats.prefix_cache_stats.queries)
|
||||
self.counter_gpu_prefix_cache_hits.inc(
|
||||
self.counter_gpu_prefix_cache_hits[engine_idx].inc(
|
||||
scheduler_stats.prefix_cache_stats.hits)
|
||||
|
||||
self.counter_prefix_cache_queries.inc(
|
||||
self.counter_prefix_cache_queries[engine_idx].inc(
|
||||
scheduler_stats.prefix_cache_stats.queries)
|
||||
self.counter_prefix_cache_hits.inc(
|
||||
self.counter_prefix_cache_hits[engine_idx].inc(
|
||||
scheduler_stats.prefix_cache_stats.hits)
|
||||
|
||||
if scheduler_stats.spec_decoding_stats is not None:
|
||||
@ -447,42 +519,45 @@ class PrometheusStatLogger(StatLoggerBase):
|
||||
if iteration_stats is None:
|
||||
return
|
||||
|
||||
self.counter_num_preempted_reqs.inc(iteration_stats.num_preempted_reqs)
|
||||
self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens)
|
||||
self.counter_generation_tokens.inc(
|
||||
self.counter_num_preempted_reqs[engine_idx].inc(
|
||||
iteration_stats.num_preempted_reqs)
|
||||
self.counter_prompt_tokens[engine_idx].inc(
|
||||
iteration_stats.num_prompt_tokens)
|
||||
self.counter_generation_tokens[engine_idx].inc(
|
||||
iteration_stats.num_generation_tokens)
|
||||
self.histogram_iteration_tokens.observe(
|
||||
self.histogram_iteration_tokens[engine_idx].observe(
|
||||
iteration_stats.num_prompt_tokens + \
|
||||
iteration_stats.num_generation_tokens)
|
||||
|
||||
for max_gen_tokens in iteration_stats.max_num_generation_tokens_iter:
|
||||
self.histogram_max_num_generation_tokens_request.observe(
|
||||
max_gen_tokens)
|
||||
self.histogram_max_num_generation_tokens_request[
|
||||
engine_idx].observe(max_gen_tokens)
|
||||
for n_param in iteration_stats.n_params_iter:
|
||||
self.histogram_n_request.observe(n_param)
|
||||
self.histogram_n_request[engine_idx].observe(n_param)
|
||||
for ttft in iteration_stats.time_to_first_tokens_iter:
|
||||
self.histogram_time_to_first_token.observe(ttft)
|
||||
self.histogram_time_to_first_token[engine_idx].observe(ttft)
|
||||
for tpot in iteration_stats.time_per_output_tokens_iter:
|
||||
self.histogram_time_per_output_token.observe(tpot)
|
||||
self.histogram_time_per_output_token[engine_idx].observe(tpot)
|
||||
|
||||
for finished_request in iteration_stats.finished_requests:
|
||||
self.counter_request_success[finished_request.finish_reason].inc()
|
||||
self.histogram_e2e_time_request.observe(
|
||||
self.counter_request_success[
|
||||
finished_request.finish_reason][engine_idx].inc()
|
||||
self.histogram_e2e_time_request[engine_idx].observe(
|
||||
finished_request.e2e_latency)
|
||||
self.histogram_queue_time_request.observe(
|
||||
self.histogram_queue_time_request[engine_idx].observe(
|
||||
finished_request.queued_time)
|
||||
self.histogram_prefill_time_request.observe(
|
||||
self.histogram_prefill_time_request[engine_idx].observe(
|
||||
finished_request.prefill_time)
|
||||
self.histogram_inference_time_request.observe(
|
||||
self.histogram_inference_time_request[engine_idx].observe(
|
||||
finished_request.inference_time)
|
||||
self.histogram_decode_time_request.observe(
|
||||
self.histogram_decode_time_request[engine_idx].observe(
|
||||
finished_request.decode_time)
|
||||
self.histogram_num_prompt_tokens_request.observe(
|
||||
self.histogram_num_prompt_tokens_request[engine_idx].observe(
|
||||
finished_request.num_prompt_tokens)
|
||||
self.histogram_num_generation_tokens_request.observe(
|
||||
self.histogram_num_generation_tokens_request[engine_idx].observe(
|
||||
finished_request.num_generation_tokens)
|
||||
if finished_request.max_tokens_param:
|
||||
self.histogram_max_tokens_request.observe(
|
||||
self.histogram_max_tokens_request[engine_idx].observe(
|
||||
finished_request.max_tokens_param)
|
||||
|
||||
if self.gauge_lora_info is not None:
|
||||
@ -502,6 +577,18 @@ class PrometheusStatLogger(StatLoggerBase):
|
||||
self.log_metrics_info("cache_config", self.vllm_config.cache_config)
|
||||
|
||||
|
||||
PromMetric = Union[
|
||||
prometheus_client.Gauge,
|
||||
prometheus_client.Counter,
|
||||
prometheus_client.Histogram,
|
||||
]
|
||||
|
||||
|
||||
def make_per_engine(metric: PromMetric, engine_idxs: list[int],
|
||||
model_name: str) -> dict[int, PromMetric]:
|
||||
return {idx: metric.labels(model_name, str(idx)) for idx in engine_idxs}
|
||||
|
||||
|
||||
def build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]:
|
||||
"""
|
||||
Builds a list of buckets with increasing powers of 10 multiplied by
|
||||
@ -529,29 +616,79 @@ def build_1_2_5_buckets(max_value: int) -> list[int]:
|
||||
return build_buckets([1, 2, 5], max_value)
|
||||
|
||||
|
||||
def setup_default_loggers(
|
||||
vllm_config: VllmConfig,
|
||||
log_stats: bool,
|
||||
engine_num: int,
|
||||
custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
|
||||
) -> list[list[StatLoggerBase]]:
|
||||
"""Setup logging and prometheus metrics."""
|
||||
if not log_stats:
|
||||
return []
|
||||
class StatLoggerManager:
|
||||
"""
|
||||
StatLoggerManager:
|
||||
Logging happens at the level of the EngineCore (per scheduler).
|
||||
* DP: >1 EngineCore per AsyncLLM - loggers for each EngineCore.
|
||||
* With Local Logger, just make N copies for N EngineCores.
|
||||
* With Prometheus, we need a single logger with N "labels"
|
||||
|
||||
factories: list[StatLoggerFactory]
|
||||
if custom_stat_loggers is not None:
|
||||
factories = custom_stat_loggers
|
||||
else:
|
||||
factories = [PrometheusStatLogger]
|
||||
if logger.isEnabledFor(logging.INFO):
|
||||
factories.append(LoggingStatLogger)
|
||||
This class abstracts away this implementation detail from
|
||||
the AsyncLLM, allowing the AsyncLLM to just call .record()
|
||||
and .log() to a simple interface.
|
||||
"""
|
||||
|
||||
stat_loggers: list[list[StatLoggerBase]] = []
|
||||
for i in range(engine_num):
|
||||
per_engine_stat_loggers: list[StatLoggerBase] = []
|
||||
for logger_factory in factories:
|
||||
per_engine_stat_loggers.append(logger_factory(vllm_config, i))
|
||||
stat_loggers.append(per_engine_stat_loggers)
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
engine_idxs: Optional[list[int]] = None,
|
||||
custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
|
||||
):
|
||||
self.engine_idxs = engine_idxs if engine_idxs else [0]
|
||||
|
||||
return stat_loggers
|
||||
factories: list[StatLoggerFactory]
|
||||
if custom_stat_loggers is not None:
|
||||
factories = custom_stat_loggers
|
||||
else:
|
||||
factories = []
|
||||
if logger.isEnabledFor(logging.INFO):
|
||||
factories.append(LoggingStatLogger)
|
||||
|
||||
# engine_idx: StatLogger
|
||||
self.per_engine_logger_dict: dict[int, list[StatLoggerBase]] = {}
|
||||
prometheus_factory = PrometheusStatLogger
|
||||
for engine_idx in self.engine_idxs:
|
||||
loggers: list[StatLoggerBase] = []
|
||||
for logger_factory in factories:
|
||||
# If we get a custom prometheus logger, use that
|
||||
# instead. This is typically used for the ray case.
|
||||
if (isinstance(logger_factory, type)
|
||||
and issubclass(logger_factory, PrometheusStatLogger)):
|
||||
prometheus_factory = logger_factory
|
||||
continue
|
||||
loggers.append(logger_factory(vllm_config,
|
||||
engine_idx)) # type: ignore
|
||||
self.per_engine_logger_dict[engine_idx] = loggers
|
||||
|
||||
# For Prometheus, need to share the metrics between EngineCores.
|
||||
# Each EngineCore's metrics are expressed as a unique label.
|
||||
self.prometheus_logger = prometheus_factory(vllm_config, engine_idxs)
|
||||
|
||||
def record(
|
||||
self,
|
||||
scheduler_stats: Optional[SchedulerStats],
|
||||
iteration_stats: Optional[IterationStats],
|
||||
engine_idx: Optional[int] = None,
|
||||
):
|
||||
if engine_idx is None:
|
||||
engine_idx = 0
|
||||
|
||||
per_engine_loggers = self.per_engine_logger_dict[engine_idx]
|
||||
for logger in per_engine_loggers:
|
||||
logger.record(scheduler_stats, iteration_stats, engine_idx)
|
||||
|
||||
self.prometheus_logger.record(scheduler_stats, iteration_stats,
|
||||
engine_idx)
|
||||
|
||||
def log(self):
|
||||
for per_engine_loggers in self.per_engine_logger_dict.values():
|
||||
for logger in per_engine_loggers:
|
||||
logger.log()
|
||||
|
||||
def log_engine_initialized(self):
|
||||
self.prometheus_logger.log_engine_initialized()
|
||||
|
||||
for per_engine_loggers in self.per_engine_logger_dict.values():
|
||||
for logger in per_engine_loggers:
|
||||
logger.log_engine_initialized()
|
||||
|
||||
@ -3,7 +3,6 @@
|
||||
import time
|
||||
from typing import Optional, Union
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.metrics.loggers import PrometheusStatLogger
|
||||
from vllm.v1.spec_decode.metrics import SpecDecodingProm
|
||||
|
||||
@ -128,9 +127,6 @@ class RayPrometheusStatLogger(PrometheusStatLogger):
|
||||
_histogram_cls = RayHistogramWrapper
|
||||
_spec_decoding_cls = RaySpecDecodingProm
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
|
||||
super().__init__(vllm_config, engine_index)
|
||||
|
||||
@staticmethod
|
||||
def _unregister_vllm_metrics():
|
||||
# No-op on purpose
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user