diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 2562eb9ce70e4..2ed0fe592e373 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -50,7 +50,12 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig from vllm.distributed.kv_events import KVCacheEvent - from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats + from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( + KVConnectorPromMetrics, + KVConnectorStats, + PromMetric, + PromMetricT, + ) from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request @@ -471,3 +476,18 @@ class KVConnectorBase_V1(ABC): which can implement custom aggregation logic on the data dict. """ return None + + @classmethod + def build_prom_metrics( + cls, + vllm_config: "VllmConfig", + metric_types: dict[type["PromMetric"], type["PromMetricT"]], + labelnames: list[str], + per_engine_labelvalues: dict[int, list[str]], + ) -> Optional["KVConnectorPromMetrics"]: + """ + Create a KVConnectorPromMetrics subclass which should register + per-connector Prometheus metrics and implement observe() to + expose connector transfer stats via Prometheus. + """ + return None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py b/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py index 21002fe572c52..d6ea4f1ab4cfc 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py @@ -1,13 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass, field -from typing import Any +from typing import Any, TypeAlias, TypeVar -from vllm.config.kv_transfer import KVTransferConfig +from prometheus_client import Counter, Gauge, Histogram + +from vllm.config import KVTransferConfig, VllmConfig from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.distributed.kv_transfer.kv_transfer_state import has_kv_transfer_group from vllm.logger import init_logger +PromMetric: TypeAlias = Gauge | Counter | Histogram +PromMetricT = TypeVar("PromMetricT", bound=PromMetric) + logger = init_logger(__name__) @@ -102,3 +107,83 @@ class KVConnectorLogging: # Reset metrics for next interval self.reset() + + +class KVConnectorPromMetrics: + """ + A base class for per-connector Prometheus metric registration + and recording. + """ + + def __init__( + self, + vllm_config: VllmConfig, + metric_types: dict[type[PromMetric], type[PromMetricT]], + labelnames: list[str], + per_engine_labelvalues: dict[int, list[str]], + ): + self._kv_transfer_config = vllm_config.kv_transfer_config + self._gauge_cls = metric_types[Gauge] + self._counter_cls = metric_types[Counter] + self._histogram_cls = metric_types[Histogram] + self._labelnames = labelnames + self._per_engine_labelvalues = per_engine_labelvalues + + def make_per_engine(self, metric: PromMetric) -> PromMetric: + """ + Create a per-engine child of a prometheus_client.Metric with + the appropriate labels set. The parent metric must be created + using the labelnames list. + """ + return { + idx: metric.labels(*labelvalues) + for idx, labelvalues in self._per_engine_labelvalues.items() + } + + def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0): + """ + Record the supplied transfer statistics to Prometheus metrics. These + statistics are engine-specific, and should be recorded to a metric + with the appropriate 'engine' label. These metric instances can be + created using the make_per_engine() helper method. + """ + raise NotImplementedError + + +class KVConnectorPrometheus: + """ + Support for registering per-connector Prometheus metrics, and + recording transfer statistics to those metrics. Uses + KVConnectorBase.build_prom_metrics(). + """ + + _gauge_cls = Gauge + _counter_cls = Counter + _histogram_cls = Histogram + + def __init__( + self, + vllm_config: VllmConfig, + labelnames: list[str], + per_engine_labelvalues: dict[int, list[str]], + ): + self.prom_metrics: KVConnectorPromMetrics | None = None + kv_transfer_config = vllm_config.kv_transfer_config + if kv_transfer_config and kv_transfer_config.kv_connector: + connector_cls = KVConnectorFactory.get_connector_class(kv_transfer_config) + metric_types = { + Gauge: self._gauge_cls, + Counter: self._counter_cls, + Histogram: self._histogram_cls, + } + self.prom_metrics = connector_cls.build_prom_metrics( + vllm_config, + metric_types, + labelnames, + per_engine_labelvalues, + ) + + def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0): + if self.prom_metrics is None: + return + self.prom_metrics.observe(transfer_stats_data, engine_idx) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index c1a2ac012415a..d56f30bd11e5b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -9,13 +9,19 @@ import torch from vllm.config import VllmConfig from vllm.config.kv_transfer import KVTransferConfig +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, ) -from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( + KVConnectorPromMetrics, + KVConnectorStats, + PromMetric, + PromMetricT, +) from vllm.logger import init_logger from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.outputs import KVConnectorOutput @@ -72,6 +78,27 @@ class MultiKVConnectorStats(KVConnectorStats): self.data[connector_id] = stats +class MultiKVConnectorPromMetrics(KVConnectorPromMetrics): + def __init__( + self, + vllm_config: "VllmConfig", + metric_types: dict[type[PromMetric], type[PromMetricT]], + labelnames: list[str], + per_engine_labelvalues: dict[int, list[str]], + prom_metrics: dict[str, KVConnectorPromMetrics], + ): + super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues) + self._prom_metrics = prom_metrics + + def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0): + for connector_id, stats_data in transfer_stats_data.items(): + assert connector_id in self._prom_metrics, ( + f"{connector_id} is not contained in the list of registered connectors " + f"with Prometheus metrics support: {self._prom_metrics.keys()}" + ) + self._prom_metrics[connector_id].observe(stats_data["data"], engine_idx) + + class MultiConnector(KVConnectorBase_V1): """ A wrapper for using multiple KVConnectors at the same time. @@ -84,19 +111,13 @@ class MultiConnector(KVConnectorBase_V1): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): super().__init__(vllm_config=vllm_config, role=role) + self._connectors: list[KVConnectorBase_V1] = [] self._ktc_kv_transfer_config = [] - ktcs = self._kv_transfer_config.kv_connector_extra_config.get("connectors") - assert ktcs is not None - for ktc in ktcs: - temp_config = copy.copy(vllm_config) - engine_id = ktc.get("engine_id", self._kv_transfer_config.engine_id) - temp_config.kv_transfer_config = KVTransferConfig( - **ktc, engine_id=engine_id - ) - self._connectors.append( - KVConnectorFactory.create_connector(temp_config, role) - ) + for connector_cls, temp_config in self._get_connector_classes_and_configs( + vllm_config + ): + self._connectors.append(connector_cls(temp_config, role)) self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config) # A mapping from request id to the index of the connector chosen to @@ -109,6 +130,32 @@ class MultiConnector(KVConnectorBase_V1): # Propagated from scheduler to worker side via the connector metadata. self._extra_async_saves: dict[str, int] = {} + @classmethod + def _get_connector_classes_and_configs( + cls, vllm_config: "VllmConfig" + ) -> list[tuple[type[KVConnectorBaseType], "VllmConfig"]]: + assert vllm_config.kv_transfer_config is not None + ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "connectors" + ) + assert ktcs is not None + ret: list[tuple[type[KVConnectorBaseType], VllmConfig]] = [] + for ktc in ktcs: + temp_config = copy.copy(vllm_config) + engine_id = ktc.get("engine_id", vllm_config.kv_transfer_config.engine_id) + temp_config.kv_transfer_config = KVTransferConfig( + **ktc, engine_id=engine_id + ) + ret.append( + ( + KVConnectorFactory.get_connector_class( + temp_config.kv_transfer_config + ), + temp_config, + ) + ) + return ret + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): for c in self._connectors: c.register_kv_caches(kv_caches) @@ -295,18 +342,12 @@ class MultiConnector(KVConnectorBase_V1): None if the connector does not require a specific layout. """ assert vllm_config.kv_transfer_config is not None - ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "connectors" - ) - assert ktcs is not None layouts: set[str] = set() - temp_vllm_config = copy.copy(vllm_config) - for ktc in ktcs: - kv_transfer_config = KVTransferConfig(**ktc) - temp_vllm_config.kv_transfer_config = kv_transfer_config - connector_cls = KVConnectorFactory.get_connector_class(kv_transfer_config) + for connector_cls, temp_config in cls._get_connector_classes_and_configs( + vllm_config + ): required_kvcache_layout = connector_cls.get_required_kvcache_layout( - temp_vllm_config + temp_config ) if required_kvcache_layout is not None: layouts.add(required_kvcache_layout) @@ -372,3 +413,28 @@ class MultiConnector(KVConnectorBase_V1): stats_by_connector = MultiKVConnectorStats() stats_by_connector[c.__class__.__name__] = stats return stats_by_connector + + @classmethod + def build_prom_metrics( + cls, + vllm_config: "VllmConfig", + metric_types: dict[type["PromMetric"], type["PromMetricT"]], + labelnames: list[str], + per_engine_labelvalues: dict[int, list[str]], + ) -> KVConnectorPromMetrics: + prom_metrics: dict[str, KVConnectorPromMetrics] = {} + for connector_cls, temp_config in cls._get_connector_classes_and_configs( + vllm_config + ): + connector_prom = connector_cls.build_prom_metrics( + temp_config, metric_types, labelnames, per_engine_labelvalues + ) + if connector_prom is not None: + prom_metrics[connector_cls.__name__] = connector_prom + return MultiKVConnectorPromMetrics( + vllm_config, + metric_types, + labelnames, + per_engine_labelvalues, + prom_metrics, + ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 72fcb5cd5bb78..275a8c734058b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -30,7 +30,12 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorMetadata, KVConnectorRole, ) -from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( + KVConnectorPromMetrics, + KVConnectorStats, + PromMetric, + PromMetricT, +) from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -254,6 +259,18 @@ class NixlConnector(KVConnectorBase_V1): else NixlKVConnectorStats() ) + @classmethod + def build_prom_metrics( + cls, + vllm_config: VllmConfig, + metric_types: dict[type[PromMetric], type[PromMetricT]], + labelnames: list[str], + per_engine_labelvalues: dict[int, list[str]], + ) -> KVConnectorPromMetrics: + return NixlPromMetrics( + vllm_config, metric_types, labelnames, per_engine_labelvalues + ) + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: assert self.connector_worker is not None assert isinstance(self._connector_metadata, NixlConnectorMetadata) @@ -1960,3 +1977,125 @@ class NixlKVConnectorStats(KVConnectorStats): @property def num_successful_transfers(self) -> int: return len(self.data["transfer_duration"]) + + +class NixlPromMetrics(KVConnectorPromMetrics): + def __init__( + self, + vllm_config: VllmConfig, + metric_types: dict[type[PromMetric], type[PromMetricT]], + labelnames: list[str], + per_engine_labelvalues: dict[int, list[str]], + ): + super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues) + + buckets = [ + 0.001, + 0.005, + 0.01, + 0.025, + 0.05, + 0.075, + 0.1, + 0.2, + 0.3, + 0.5, + 0.75, + 1.0, + 5.0, + ] + nixl_histogram_xfer_time = self._histogram_cls( + name="vllm:nixl_xfer_time_seconds", + documentation="Histogram of transfer duration for NIXL KV Cache transfers.", + buckets=buckets[1:], + labelnames=labelnames, + ) + self.nixl_histogram_xfer_time = self.make_per_engine(nixl_histogram_xfer_time) + nixl_histogram_post_time = self._histogram_cls( + name="vllm:nixl_post_time_seconds", + documentation="Histogram of transfer post time for NIXL KV" + " Cache transfers.", + buckets=buckets, + labelnames=labelnames, + ) + self.nixl_histogram_post_time = self.make_per_engine(nixl_histogram_post_time) + # uniform 2kb to 16gb range + buckets = [2 ** (10 + i) for i in range(1, 25, 2)] + nixl_histogram_bytes_transferred = self._histogram_cls( + name="vllm:nixl_bytes_transferred", + documentation="Histogram of bytes transferred per NIXL KV Cache transfers.", + buckets=buckets, + labelnames=labelnames, + ) + self.nixl_histogram_bytes_transferred = self.make_per_engine( + nixl_histogram_bytes_transferred + ) + buckets = [ + 10, + 20, + 30, + 50, + 75, + 100, + 200, + 400, + 1000, + 2000, + 4000, + 10000, + 20000, + 50000, + ] + nixl_histogram_num_descriptors = self._histogram_cls( + name="vllm:nixl_num_descriptors", + documentation="Histogram of number of descriptors per NIXL" + " KV Cache transfers.", + buckets=buckets, + labelnames=labelnames, + ) + self.nixl_histogram_num_descriptors = self.make_per_engine( + nixl_histogram_num_descriptors + ) + counter_nixl_num_failed_transfers = self._counter_cls( + name="vllm:nixl_num_failed_transfers", + documentation="Number of failed NIXL KV Cache transfers.", + labelnames=labelnames, + ) + self.counter_nixl_num_failed_transfers = self.make_per_engine( + counter_nixl_num_failed_transfers + ) + counter_nixl_num_failed_notifications = self._counter_cls( + name="vllm:nixl_num_failed_notifications", + documentation="Number of failed NIXL KV Cache notifications.", + labelnames=labelnames, + ) + self.counter_nixl_num_failed_notifications = self.make_per_engine( + counter_nixl_num_failed_notifications + ) + + def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0): + for prom_obj, list_item_key in zip( + [ + self.nixl_histogram_xfer_time, + self.nixl_histogram_post_time, + self.nixl_histogram_bytes_transferred, + self.nixl_histogram_num_descriptors, + ], + [ + "transfer_duration", + "post_duration", + "bytes_transferred", + "num_descriptors", + ], + ): + for list_item in transfer_stats_data[list_item_key]: + prom_obj[engine_idx].observe(list_item) + for counter_obj, counter_item_key in zip( + [ + self.counter_nixl_num_failed_transfers, + self.counter_nixl_num_failed_notifications, + ], + ["num_failed_transfers", "num_failed_notifications"], + ): + for list_item in transfer_stats_data[counter_item_key]: + counter_obj[engine_idx].inc(list_item) diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 055da5d856b25..3772f07066a12 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -11,7 +11,10 @@ from prometheus_client import Counter, Gauge, Histogram import vllm.envs as envs from vllm.config import SupportsMetricsInfo, VllmConfig -from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorLogging +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( + KVConnectorLogging, + KVConnectorPrometheus, +) from vllm.logger import init_logger from vllm.plugins import load_plugins_by_group from vllm.v1.engine import FinishReason @@ -339,6 +342,7 @@ class PrometheusStatLogger(AggregateStatLoggerBase): _counter_cls = Counter _histogram_cls = Histogram _spec_decoding_cls = SpecDecodingProm + _kv_connector_cls = KVConnectorPrometheus def __init__( self, vllm_config: VllmConfig, engine_indexes: list[int] | None = None @@ -358,12 +362,15 @@ class PrometheusStatLogger(AggregateStatLoggerBase): model_name = vllm_config.model_config.served_model_name max_model_len = vllm_config.model_config.max_model_len - spec_decode_labelvalues: dict[int, list[str]] = { + per_engine_labelvalues: dict[int, list[str]] = { idx: [model_name, str(idx)] for idx in engine_indexes } self.spec_decoding_prom = self._spec_decoding_cls( - vllm_config.speculative_config, labelnames, spec_decode_labelvalues + vllm_config.speculative_config, labelnames, per_engine_labelvalues + ) + self.kv_connector_prom = self._kv_connector_cls( + vllm_config, labelnames, per_engine_labelvalues ) # @@ -962,6 +969,11 @@ class PrometheusStatLogger(AggregateStatLoggerBase): scheduler_stats.spec_decoding_stats, engine_idx ) + if scheduler_stats.kv_connector_stats is not None: + self.kv_connector_prom.observe( + scheduler_stats.kv_connector_stats, engine_idx + ) + if mm_cache_stats is not None: self.counter_mm_cache_queries[engine_idx].inc(mm_cache_stats.queries) self.counter_mm_cache_hits[engine_idx].inc(mm_cache_stats.hits) diff --git a/vllm/v1/metrics/ray_wrappers.py b/vllm/v1/metrics/ray_wrappers.py index b845852a0c0d5..a319ffb1d2573 100644 --- a/vllm/v1/metrics/ray_wrappers.py +++ b/vllm/v1/metrics/ray_wrappers.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorPrometheus from vllm.v1.metrics.loggers import PrometheusStatLogger from vllm.v1.spec_decode.metrics import SpecDecodingProm @@ -141,6 +142,18 @@ class RaySpecDecodingProm(SpecDecodingProm): _counter_cls = RayCounterWrapper +class RayKVConnectorPrometheus(KVConnectorPrometheus): + """ + RayKVConnectorPrometheus is used by RayMetrics to log Ray + metrics. Provides the same metrics as KV connectors but + uses Ray's util.metrics library. + """ + + _gauge_cls = RayGaugeWrapper + _counter_cls = RayCounterWrapper + _histogram_cls = RayHistogramWrapper + + class RayPrometheusStatLogger(PrometheusStatLogger): """RayPrometheusStatLogger uses Ray metrics instead.""" @@ -148,6 +161,7 @@ class RayPrometheusStatLogger(PrometheusStatLogger): _counter_cls = RayCounterWrapper _histogram_cls = RayHistogramWrapper _spec_decoding_cls = RaySpecDecodingProm + _kv_connector_cls = RayKVConnectorPrometheus @staticmethod def _unregister_vllm_metrics():