[KVConnector] Add metrics to Prometheus-Grafana dashboard (#26811)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
Nicolò Lucchesi 2025-10-29 19:44:49 +01:00 committed by GitHub
parent 5b0448104f
commit accb8fab07
6 changed files with 365 additions and 29 deletions


@@ -50,7 +50,12 @@ if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
@@ -471,3 +476,18 @@ class KVConnectorBase_V1(ABC):
which can implement custom aggregation logic on the data dict.
"""
return None
@classmethod
def build_prom_metrics(
cls,
vllm_config: "VllmConfig",
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
) -> Optional["KVConnectorPromMetrics"]:
"""
Create and return an instance of a KVConnectorPromMetrics subclass,
which should register per-connector Prometheus metrics and implement
observe() to expose connector transfer stats via Prometheus.
"""
return None
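As an illustrative sketch only (MyConnector and MyConnectorPromMetrics are hypothetical names, not part of this commit), a connector would typically pair this hook with its own KVConnectorPromMetrics subclass:

# Hypothetical sketch; assumes the TYPE_CHECKING imports shown in the hunk above.
class MyConnector(KVConnectorBase_V1):
    @classmethod
    def build_prom_metrics(
        cls,
        vllm_config: "VllmConfig",
        metric_types: dict[type["PromMetric"], type["PromMetricT"]],
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[str]],
    ) -> Optional["KVConnectorPromMetrics"]:
        # Returning None (the base-class default) would leave Prometheus
        # metrics disabled for this connector.
        return MyConnectorPromMetrics(
            vllm_config, metric_types, labelnames, per_engine_labelvalues
        )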


@@ -1,13 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import Any
from typing import Any, TypeAlias, TypeVar
from vllm.config.kv_transfer import KVTransferConfig
from prometheus_client import Counter, Gauge, Histogram
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_transfer_state import has_kv_transfer_group
from vllm.logger import init_logger
PromMetric: TypeAlias = Gauge | Counter | Histogram
PromMetricT = TypeVar("PromMetricT", bound=PromMetric)
logger = init_logger(__name__)
@@ -102,3 +107,83 @@ class KVConnectorLogging:
# Reset metrics for next interval
self.reset()
class KVConnectorPromMetrics:
"""
A base class for per-connector Prometheus metric registration
and recording.
"""
def __init__(
self,
vllm_config: VllmConfig,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
):
self._kv_transfer_config = vllm_config.kv_transfer_config
self._gauge_cls = metric_types[Gauge]
self._counter_cls = metric_types[Counter]
self._histogram_cls = metric_types[Histogram]
self._labelnames = labelnames
self._per_engine_labelvalues = per_engine_labelvalues
def make_per_engine(self, metric: PromMetric) -> dict[int, PromMetric]:
"""
Create per-engine children of a prometheus_client metric, keyed by
engine index, with the appropriate labels set. The parent metric
must be created using the labelnames list.
"""
return {
idx: metric.labels(*labelvalues)
for idx, labelvalues in self._per_engine_labelvalues.items()
}
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
"""
Record the supplied transfer statistics to Prometheus metrics. These
statistics are engine-specific, and should be recorded to a metric
with the appropriate 'engine' label. These metric instances can be
created using the make_per_engine() helper method.
"""
raise NotImplementedError
class KVConnectorPrometheus:
"""
Support for registering per-connector Prometheus metrics and
recording transfer statistics to those metrics. Uses
KVConnectorBase_V1.build_prom_metrics().
"""
_gauge_cls = Gauge
_counter_cls = Counter
_histogram_cls = Histogram
def __init__(
self,
vllm_config: VllmConfig,
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
):
self.prom_metrics: KVConnectorPromMetrics | None = None
kv_transfer_config = vllm_config.kv_transfer_config
if kv_transfer_config and kv_transfer_config.kv_connector:
connector_cls = KVConnectorFactory.get_connector_class(kv_transfer_config)
metric_types = {
Gauge: self._gauge_cls,
Counter: self._counter_cls,
Histogram: self._histogram_cls,
}
self.prom_metrics = connector_cls.build_prom_metrics(
vllm_config,
metric_types,
labelnames,
per_engine_labelvalues,
)
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
if self.prom_metrics is None:
return
self.prom_metrics.observe(transfer_stats_data, engine_idx)
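A minimal sketch of the intended subclassing pattern; ExamplePromMetrics, the metric name, and the "num_transfers" stats key are assumptions for illustration:

# Hypothetical connector-side subclass using the injected metric classes,
# make_per_engine() and observe().
class ExamplePromMetrics(KVConnectorPromMetrics):
    def __init__(self, vllm_config, metric_types, labelnames, per_engine_labelvalues):
        super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
        counter = self._counter_cls(  # prometheus_client.Counter, or a Ray wrapper
            name="vllm:example_num_transfers",  # assumed metric name
            documentation="Number of example KV cache transfers.",
            labelnames=labelnames,
        )
        # Dict mapping engine index -> labelled child metric.
        self.counter_num_transfers = self.make_per_engine(counter)

    def observe(self, transfer_stats_data, engine_idx=0):
        # "num_transfers" is an assumed key in this connector's stats dict.
        self.counter_num_transfers[engine_idx].inc(
            transfer_stats_data.get("num_transfers", 0)
        )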


@@ -9,13 +9,19 @@ import torch
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
@@ -72,6 +78,27 @@ class MultiKVConnectorStats(KVConnectorStats):
self.data[connector_id] = stats
class MultiKVConnectorPromMetrics(KVConnectorPromMetrics):
def __init__(
self,
vllm_config: "VllmConfig",
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
prom_metrics: dict[str, KVConnectorPromMetrics],
):
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
self._prom_metrics = prom_metrics
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
for connector_id, stats_data in transfer_stats_data.items():
assert connector_id in self._prom_metrics, (
f"{connector_id} is not contained in the list of registered connectors "
f"with Prometheus metrics support: {self._prom_metrics.keys()}"
)
self._prom_metrics[connector_id].observe(stats_data["data"], engine_idx)
class MultiConnector(KVConnectorBase_V1):
"""
A wrapper for using multiple KVConnectors at the same time.
@@ -84,19 +111,13 @@ class MultiConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._connectors: list[KVConnectorBase_V1] = []
self._ktc_kv_transfer_config = []
ktcs = self._kv_transfer_config.kv_connector_extra_config.get("connectors")
assert ktcs is not None
for ktc in ktcs:
temp_config = copy.copy(vllm_config)
engine_id = ktc.get("engine_id", self._kv_transfer_config.engine_id)
temp_config.kv_transfer_config = KVTransferConfig(
**ktc, engine_id=engine_id
)
self._connectors.append(
KVConnectorFactory.create_connector(temp_config, role)
)
for connector_cls, temp_config in self._get_connector_classes_and_configs(
vllm_config
):
self._connectors.append(connector_cls(temp_config, role))
self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config)
# A mapping from request id to the index of the connector chosen to
@@ -109,6 +130,32 @@ class MultiConnector(KVConnectorBase_V1):
# Propagated from scheduler to worker side via the connector metadata.
self._extra_async_saves: dict[str, int] = {}
@classmethod
def _get_connector_classes_and_configs(
cls, vllm_config: "VllmConfig"
) -> list[tuple[type[KVConnectorBaseType], "VllmConfig"]]:
assert vllm_config.kv_transfer_config is not None
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"connectors"
)
assert ktcs is not None
ret: list[tuple[type[KVConnectorBaseType], VllmConfig]] = []
for ktc in ktcs:
temp_config = copy.copy(vllm_config)
engine_id = ktc.get("engine_id", vllm_config.kv_transfer_config.engine_id)
temp_config.kv_transfer_config = KVTransferConfig(
**ktc, engine_id=engine_id
)
ret.append(
(
KVConnectorFactory.get_connector_class(
temp_config.kv_transfer_config
),
temp_config,
)
)
return ret
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
for c in self._connectors:
c.register_kv_caches(kv_caches)
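For reference, _get_connector_classes_and_configs() above expects a kv_connector_extra_config of roughly the following shape (a sketch; the nested connector names are only examples):

# Assumed example of a MultiConnector kv-transfer config; each "connectors"
# entry becomes one (connector class, per-connector VllmConfig) pair.
example_kv_transfer_config = {
    "kv_connector": "MultiConnector",
    "kv_role": "kv_both",
    "kv_connector_extra_config": {
        "connectors": [
            {"kv_connector": "NixlConnector", "kv_role": "kv_both"},
            {"kv_connector": "SharedStorageConnector", "kv_role": "kv_both"},
        ]
    },
}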
@@ -295,18 +342,12 @@ class MultiConnector(KVConnectorBase_V1):
None if the connector does not require a specific layout.
"""
assert vllm_config.kv_transfer_config is not None
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"connectors"
)
assert ktcs is not None
layouts: set[str] = set()
temp_vllm_config = copy.copy(vllm_config)
for ktc in ktcs:
kv_transfer_config = KVTransferConfig(**ktc)
temp_vllm_config.kv_transfer_config = kv_transfer_config
connector_cls = KVConnectorFactory.get_connector_class(kv_transfer_config)
for connector_cls, temp_config in cls._get_connector_classes_and_configs(
vllm_config
):
required_kvcache_layout = connector_cls.get_required_kvcache_layout(
temp_vllm_config
temp_config
)
if required_kvcache_layout is not None:
layouts.add(required_kvcache_layout)
@@ -372,3 +413,28 @@ class MultiConnector(KVConnectorBase_V1):
stats_by_connector = MultiKVConnectorStats()
stats_by_connector[c.__class__.__name__] = stats
return stats_by_connector
@classmethod
def build_prom_metrics(
cls,
vllm_config: "VllmConfig",
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
) -> KVConnectorPromMetrics:
prom_metrics: dict[str, KVConnectorPromMetrics] = {}
for connector_cls, temp_config in cls._get_connector_classes_and_configs(
vllm_config
):
connector_prom = connector_cls.build_prom_metrics(
temp_config, metric_types, labelnames, per_engine_labelvalues
)
if connector_prom is not None:
prom_metrics[connector_cls.__name__] = connector_prom
return MultiKVConnectorPromMetrics(
vllm_config,
metric_types,
labelnames,
per_engine_labelvalues,
prom_metrics,
)
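The transfer_stats_data handled by MultiKVConnectorPromMetrics.observe() is keyed by connector class name, with each value carrying that connector's serialized stats under "data"; a sketch with made-up values:

# Illustrative shape only; the numbers are invented.
multi_stats_data = {
    "NixlConnector": {
        "data": {
            "transfer_duration": [0.012, 0.034],
            "bytes_transferred": [1 << 20, 4 << 20],
            # ...remaining NixlKVConnectorStats fields
        }
    },
    # one entry per wrapped connector that reported stats
}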


@@ -30,7 +30,12 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -254,6 +259,18 @@ class NixlConnector(KVConnectorBase_V1):
else NixlKVConnectorStats()
)
@classmethod
def build_prom_metrics(
cls,
vllm_config: VllmConfig,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
) -> KVConnectorPromMetrics:
return NixlPromMetrics(
vllm_config, metric_types, labelnames, per_engine_labelvalues
)
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, NixlConnectorMetadata)
@@ -1960,3 +1977,125 @@ class NixlKVConnectorStats(KVConnectorStats):
@property
def num_successful_transfers(self) -> int:
return len(self.data["transfer_duration"])
class NixlPromMetrics(KVConnectorPromMetrics):
def __init__(
self,
vllm_config: VllmConfig,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
):
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
buckets = [
0.001,
0.005,
0.01,
0.025,
0.05,
0.075,
0.1,
0.2,
0.3,
0.5,
0.75,
1.0,
5.0,
]
nixl_histogram_xfer_time = self._histogram_cls(
name="vllm:nixl_xfer_time_seconds",
documentation="Histogram of transfer duration for NIXL KV Cache transfers.",
buckets=buckets[1:],
labelnames=labelnames,
)
self.nixl_histogram_xfer_time = self.make_per_engine(nixl_histogram_xfer_time)
nixl_histogram_post_time = self._histogram_cls(
name="vllm:nixl_post_time_seconds",
documentation="Histogram of transfer post time for NIXL KV"
" Cache transfers.",
buckets=buckets,
labelnames=labelnames,
)
self.nixl_histogram_post_time = self.make_per_engine(nixl_histogram_post_time)
# log-spaced buckets covering roughly 2 KiB to 8 GiB
buckets = [2 ** (10 + i) for i in range(1, 25, 2)]
nixl_histogram_bytes_transferred = self._histogram_cls(
name="vllm:nixl_bytes_transferred",
documentation="Histogram of bytes transferred per NIXL KV Cache transfers.",
buckets=buckets,
labelnames=labelnames,
)
self.nixl_histogram_bytes_transferred = self.make_per_engine(
nixl_histogram_bytes_transferred
)
buckets = [
10,
20,
30,
50,
75,
100,
200,
400,
1000,
2000,
4000,
10000,
20000,
50000,
]
nixl_histogram_num_descriptors = self._histogram_cls(
name="vllm:nixl_num_descriptors",
documentation="Histogram of number of descriptors per NIXL"
" KV Cache transfers.",
buckets=buckets,
labelnames=labelnames,
)
self.nixl_histogram_num_descriptors = self.make_per_engine(
nixl_histogram_num_descriptors
)
counter_nixl_num_failed_transfers = self._counter_cls(
name="vllm:nixl_num_failed_transfers",
documentation="Number of failed NIXL KV Cache transfers.",
labelnames=labelnames,
)
self.counter_nixl_num_failed_transfers = self.make_per_engine(
counter_nixl_num_failed_transfers
)
counter_nixl_num_failed_notifications = self._counter_cls(
name="vllm:nixl_num_failed_notifications",
documentation="Number of failed NIXL KV Cache notifications.",
labelnames=labelnames,
)
self.counter_nixl_num_failed_notifications = self.make_per_engine(
counter_nixl_num_failed_notifications
)
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
for prom_obj, list_item_key in zip(
[
self.nixl_histogram_xfer_time,
self.nixl_histogram_post_time,
self.nixl_histogram_bytes_transferred,
self.nixl_histogram_num_descriptors,
],
[
"transfer_duration",
"post_duration",
"bytes_transferred",
"num_descriptors",
],
):
for list_item in transfer_stats_data[list_item_key]:
prom_obj[engine_idx].observe(list_item)
for counter_obj, counter_item_key in zip(
[
self.counter_nixl_num_failed_transfers,
self.counter_nixl_num_failed_notifications,
],
["num_failed_transfers", "num_failed_notifications"],
):
for list_item in transfer_stats_data[counter_item_key]:
counter_obj[engine_idx].inc(list_item)
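Based on the keys read in observe() above, the per-engine stats dict consumed by NixlPromMetrics looks roughly like this (a sketch; list lengths and values are invented):

# Illustrative payload; keys mirror those used by NixlPromMetrics.observe().
nixl_stats_data = {
    "transfer_duration": [0.008, 0.021],       # seconds, one entry per transfer
    "post_duration": [0.002, 0.003],           # seconds to post each transfer
    "bytes_transferred": [4 << 20, 16 << 20],  # bytes per transfer
    "num_descriptors": [128, 512],             # descriptors per transfer
    "num_failed_transfers": [],                # each entry increments the counter
    "num_failed_notifications": [],
}
# e.g. nixl_prom_metrics.observe(nixl_stats_data, engine_idx=0) records every
# list element into the matching histogram/counter child for engine 0.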


@@ -11,7 +11,10 @@ from prometheus_client import Counter, Gauge, Histogram
import vllm.envs as envs
from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorLogging
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorLogging,
KVConnectorPrometheus,
)
from vllm.logger import init_logger
from vllm.plugins import load_plugins_by_group
from vllm.v1.engine import FinishReason
@@ -339,6 +342,7 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
_counter_cls = Counter
_histogram_cls = Histogram
_spec_decoding_cls = SpecDecodingProm
_kv_connector_cls = KVConnectorPrometheus
def __init__(
self, vllm_config: VllmConfig, engine_indexes: list[int] | None = None
@@ -358,12 +362,15 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
model_name = vllm_config.model_config.served_model_name
max_model_len = vllm_config.model_config.max_model_len
spec_decode_labelvalues: dict[int, list[str]] = {
per_engine_labelvalues: dict[int, list[str]] = {
idx: [model_name, str(idx)] for idx in engine_indexes
}
self.spec_decoding_prom = self._spec_decoding_cls(
vllm_config.speculative_config, labelnames, spec_decode_labelvalues
vllm_config.speculative_config, labelnames, per_engine_labelvalues
)
self.kv_connector_prom = self._kv_connector_cls(
vllm_config, labelnames, per_engine_labelvalues
)
#
@@ -962,6 +969,11 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
scheduler_stats.spec_decoding_stats, engine_idx
)
if scheduler_stats.kv_connector_stats is not None:
self.kv_connector_prom.observe(
scheduler_stats.kv_connector_stats, engine_idx
)
if mm_cache_stats is not None:
self.counter_mm_cache_queries[engine_idx].inc(mm_cache_stats.queries)
self.counter_mm_cache_hits[engine_idx].inc(mm_cache_stats.hits)


@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorPrometheus
from vllm.v1.metrics.loggers import PrometheusStatLogger
from vllm.v1.spec_decode.metrics import SpecDecodingProm
@@ -141,6 +142,18 @@ class RaySpecDecodingProm(SpecDecodingProm):
_counter_cls = RayCounterWrapper
class RayKVConnectorPrometheus(KVConnectorPrometheus):
"""
RayKVConnectorPrometheus is used by RayMetrics to log Ray
metrics. Provides the same metrics as KV connectors but
uses Ray's util.metrics library.
"""
_gauge_cls = RayGaugeWrapper
_counter_cls = RayCounterWrapper
_histogram_cls = RayHistogramWrapper
class RayPrometheusStatLogger(PrometheusStatLogger):
"""RayPrometheusStatLogger uses Ray metrics instead."""
@@ -148,6 +161,7 @@ class RayPrometheusStatLogger(PrometheusStatLogger):
_counter_cls = RayCounterWrapper
_histogram_cls = RayHistogramWrapper
_spec_decoding_cls = RaySpecDecodingProm
_kv_connector_cls = RayKVConnectorPrometheus
@staticmethod
def _unregister_vllm_metrics():