diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 040b44dc5d2ca..6e58d158c3f4b 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -18,12 +18,18 @@ import torch from vllm import LLM from vllm.config import KVTransferConfig +from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( + KVConnectorStats) +from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( + MultiKVConnectorStats) from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata, - NixlConnectorWorker) + NixlConnectorWorker, NixlKVConnectorStats) from vllm.forward_context import ForwardContext from vllm.sampling_params import SamplingParams from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend +from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput from .utils import create_request, create_scheduler, create_vllm_config @@ -475,6 +481,209 @@ class TestNixlHandshake: # NOTE: resource cleanup in mp backend is a bit finicky, so the order in which # we put here is important. First run ray, it will clean up the resources, then # the rest of the tests. +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper) +def test_kv_connector_stats(dist_init): + """Test that KV transfer stats are properly recorded and retrieved.""" + vllm_config = create_vllm_config() + + # Test worker role in decode server. + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector.connector_worker = FakeNixlConnectorWorker(vllm_config, + connector.engine_id, + hand_shake_latency=0) + + # Verify that xfer_stats starts empty + initial_stats = connector.get_kv_connector_stats() + assert initial_stats is None + + # Create transfer metadata + request_id = "test_req_for_stats" + metadata = NixlConnectorMetadata() + metadata.add_new_req(request_id=request_id, + local_block_ids=[1, 2, 3], + kv_transfer_params={ + "remote_block_ids": [4, 5, 6], + "remote_engine_id": + FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_host": "localhost", + "remote_port": 1234, + "remote_tp_size": 1, + }) + connector.bind_connector_metadata(metadata) + + # Start the transfer + dummy_ctx = ForwardContext( + no_compile_layers={}, + attn_metadata={}, + virtual_engine=0, + ) + connector.start_load_kv(dummy_ctx) + + # Verify stats are recorded after transfer is complete + max_iterations = 2 + # Clear metadata before start_load_kv to prevent reprocessing same request + connector.bind_connector_metadata(NixlConnectorMetadata()) + for _ in range(max_iterations): + # Need to call start_load_kv to process completed handshakes + connector.start_load_kv(dummy_ctx) + _, done_recving = connector.get_finished(finished_req_ids=set()) + if len(done_recving) > 0 and request_id in done_recving: + break + time.sleep( + 0.1) # Small delay to allow background handshake to complete + else: + assert "Transfer did not complete within expected iterations" + + # Now check that stats were recorded + stats_after_transfer = connector.get_kv_connector_stats() + assert isinstance(stats_after_transfer, NixlKVConnectorStats) + + # Verify stats values are recorded + assert not stats_after_transfer.is_empty() + assert stats_after_transfer.data["num_successful_transfers"] == 1 + + # Verify stats are reset after retrieval + stats_after_reset = connector.get_kv_connector_stats() + assert stats_after_reset is None + + +def test_kv_connector_stats_aggregation(): + """ + Test KV transfer stats aggregation across TP ranks using + KVOutputAggregator (used by MultiprocExecutor). + """ + + # Create KVOutputAggregator for 3 workers (simulating TP=3), same thing + # done in MultiprocExecutor.execute_model + aggregator = KVOutputAggregator(world_size=3) + + # Create stats for multiple workers with different transfer patterns + worker1_stats = NixlKVConnectorStats() + worker2_stats = NixlKVConnectorStats() + worker3_stats = NixlKVConnectorStats() + + # Record different transfers on each worker + # Worker 1: 2 transfers + worker1_stats.record_transfer() + worker1_stats.record_transfer() + + # Worker 2: 1 transfer + worker2_stats.record_transfer() + + # Worker 3: 3 transfers + worker3_stats.record_transfer() + worker3_stats.record_transfer() + worker3_stats.record_transfer() + + # Create ModelRunnerOutput instances for each worker + worker_outputs = [] + for i, worker_stats in enumerate( + [worker1_stats, worker2_stats, worker3_stats]): + output = ModelRunnerOutput( + req_ids=[f"req_{i}"], + req_id_to_index={f"req_{i}": 0}, + sampled_token_ids=[[123]], # dummy token + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[None], + kv_connector_output=KVConnectorOutput( + finished_sending=set([f"req_{i}_send"]) + if i < 2 else None, # Workers 0,1 finished sending + finished_recving=set([f"req_{i}_recv"]) + if i > 0 else None, # Workers 1,2 finished receiving + kv_connector_stats=worker_stats, + )) + worker_outputs.append(output) + + # Use the real aggregation mechanism (like MultiprocExecutor.execute_model) + aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0) + kv_connector_stats = \ + aggregated_output.kv_connector_output.kv_connector_stats + assert isinstance(kv_connector_stats, NixlKVConnectorStats) + # Number of total transfers across all workers. + assert kv_connector_stats.data["num_successful_transfers"] == 6 + + +def test_multi_kv_connector_stats_aggregation(): + """ + Test MultiKVConnectorStats aggregation across TP ranks using + KVOutputAggregator (used by MultiprocExecutor). + """ + + aggregator = KVOutputAggregator(world_size=3) + + from dataclasses import dataclass + + @dataclass + class FooKVConnectorStats(KVConnectorStats): + + def reset(self): + self.data = {"num_foo_transfers": 0} + + def record_transfer(self): + if "num_foo_transfers" not in self.data: + self.data["num_foo_transfers"] = 0 + self.data["num_foo_transfers"] += 1 + + def is_empty(self) -> bool: + return self.data["num_foo_transfers"] == 0 + + def aggregate(self, + other: "FooKVConnectorStats") -> "FooKVConnectorStats": + if not other.is_empty(): + self.data["num_foo_transfers"] += other.data[ + "num_foo_transfers"] + return self + + def make_multi_stats(nixl_count: int, + foo_count: int) -> MultiKVConnectorStats: + data: dict[str, KVConnectorStats] = {} + if nixl_count > 0: + nixl_stats = NixlKVConnectorStats() + for _ in range(nixl_count): + nixl_stats.record_transfer() + data["NixlConnector"] = nixl_stats + if foo_count > 0: + foo_stats = FooKVConnectorStats() + for _ in range(foo_count): + foo_stats.record_transfer() + data["FooConnector"] = foo_stats + return MultiKVConnectorStats(data=data) + + # Create heterogeneous stats across 3 workers + worker_patterns = [(2, 1), (3, 0), (0, 5)] # (Nixl, Foo) + + worker_outputs: list[ModelRunnerOutput] = [] + for i, (nixl, foo) in enumerate(worker_patterns): + stats = make_multi_stats(nixl, foo) + output = ModelRunnerOutput( + req_ids=[f"req_{i}"], + req_id_to_index={f"req_{i}": 0}, + sampled_token_ids=[[123]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[None], + kv_connector_output=KVConnectorOutput( + finished_sending=set([f"req_{i}_send"]) if i < 2 else None, + finished_recving=set([f"req_{i}_recv"]) if i > 0 else None, + kv_connector_stats=stats, + ), + ) + worker_outputs.append(output) + + aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0) + kv_connector_stats = \ + aggregated_output.kv_connector_output.kv_connector_stats + assert isinstance(kv_connector_stats, MultiKVConnectorStats) + + # Validate per-connector totals across workers + assert kv_connector_stats["NixlConnector"].data[ + "num_successful_transfers"] == 5 + assert kv_connector_stats["FooConnector"].data["num_foo_transfers"] == 6 + + @pytest.mark.parametrize("distributed_executor_backend", ["ray", None]) @patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index f4dc248a12794..911d77ba36fa0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -129,7 +129,7 @@ class KVOutputAggregator: def aggregate(self, outputs: list[ModelRunnerOutput], output_rank: int = 0) -> ModelRunnerOutput: - # aggregate kv_connector_output from all workers + # Aggregate kv_connector_output from all workers def update_finished_set(req_ids: Optional[set[str]], remaining_count_dict: dict[str, int], @@ -142,8 +142,9 @@ class KVOutputAggregator: finished_sending = set[str]() finished_recving = set[str]() - for output in outputs: - output = output.kv_connector_output + aggregated_kv_connector_stats = None + for model_runner_output in outputs: + output = model_runner_output.kv_connector_output if not output: continue update_finished_set(output.finished_sending, @@ -151,12 +152,26 @@ class KVOutputAggregator: update_finished_set(output.finished_recving, self._recv_remaining_count, finished_recving) + # Aggregate kv_connector_stats from all workers. + if aggregated_kv_connector_stats is None: + # Use the first worker's kv_connector_stats as accumulator. + aggregated_kv_connector_stats = output.kv_connector_stats + elif kv_connector_stats := output.kv_connector_stats: + if aggregated_kv_connector_stats is None: + aggregated_kv_connector_stats = kv_connector_stats + else: + assert isinstance(aggregated_kv_connector_stats, + type(kv_connector_stats)) + aggregated_kv_connector_stats = \ + aggregated_kv_connector_stats.aggregate(kv_connector_stats) + # select output of the worker specified by output_rank output = outputs[output_rank] output.kv_connector_output = KVConnectorOutput( finished_sending=finished_sending or None, finished_recving=finished_recving or None, + kv_connector_stats=aggregated_kv_connector_stats or None, ) return output diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 70c07eac6304b..184d0a62f2c30 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -49,6 +49,8 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig from vllm.distributed.kv_events import KVCacheEvent + from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( + KVConnectorStats) from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request @@ -235,6 +237,12 @@ class KVConnectorBase_V1(ABC): """ return None + def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]: + """ + Get the KV connector stats collected during the last interval. + """ + return None + # ============================== # Scheduler-side methods # ============================== @@ -365,4 +373,16 @@ class KVConnectorBase_V1(ABC): int: expected sending or receiving completion count. """ - return None \ No newline at end of file + return None + + @classmethod + def build_kv_connector_stats( + cls, + data: Optional[dict[str, + Any]] = None) -> Optional["KVConnectorStats"]: + """ + KVConnectorStats resolution method. This method allows dynamically + registered connectors to return their own KVConnectorStats object, + which can implement custom aggregation logic on the data dict. + """ + return None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py b/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py new file mode 100644 index 0000000000000..e40007230ba45 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass, field +from typing import Any, Optional, Union + +from vllm.config.kv_transfer import KVTransferConfig +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_transfer_state import ( + has_kv_transfer_group) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@dataclass +class KVConnectorStats: + """ + Base class for KV Connector Stats, a container for transfer performance + metrics or otherwise important telemetry from the connector. + All sub-classes need to be serializable as stats are sent from worker to + logger process. + """ + data: dict[str, Any] = field(default_factory=dict) + + def reset(self): + """Reset the stats, clear the state.""" + raise NotImplementedError + + def aggregate(self, other: "KVConnectorStats") -> "KVConnectorStats": + """ + Aggregate stats with another `KVConnectorStats` object. + """ + raise NotImplementedError + + def reduce(self) -> dict[str, Union[int, float]]: + """ + Reduce the observations collected during a time interval to one or + more representative values (eg avg/median/sum of the series). + This is meant to be called by the logger to produce a summary of the + stats for the last time interval. + """ + raise NotImplementedError + + def is_empty(self) -> bool: + """Return True if the stats are empty.""" + raise NotImplementedError + + +class KVConnectorLogging: + + def __init__(self, kv_tranfer_config: KVTransferConfig): + # This should be called on frontend process. + assert not has_kv_transfer_group() + # Instantiate the connector's stats class. + if kv_tranfer_config and kv_tranfer_config.kv_connector: + self.connector_cls = KVConnectorFactory.get_connector_class( + kv_tranfer_config) + self.reset() + + def reset(self): + self.transfer_stats_accumulator: Optional[KVConnectorStats] = None + + def observe(self, transfer_stats_data: dict[str, Any]): + # Should not be called when a KVConnector is not configured. + assert self.connector_cls is not None + # Called periodically when connector syncs with the scheduler. + # Note that this is not the same as the logging interval. + # We expect transfer_stats_data to be aggregated across all workers and + # consist of observations from a single connector or a MultiConnector. + transfer_stats = self.connector_cls.build_kv_connector_stats( + transfer_stats_data) + if transfer_stats is None: + logger.warning_once( + "The connector %s is collecting stats but " + "does not implement the " + "`build_kv_connector_stats` method. " + "Stats will not be logged.", self.connector_cls) + return + + if self.transfer_stats_accumulator is None: + self.transfer_stats_accumulator = transfer_stats + else: + # Accumulate last interval stats. + self.transfer_stats_accumulator = \ + self.transfer_stats_accumulator.aggregate(transfer_stats) + + def log(self, log_fn=logger.info): + """Log transfer metrics periodically, similar to throughput logging""" + if (self.transfer_stats_accumulator + and not self.transfer_stats_accumulator.is_empty()): + # Produce a single cumulative stats object for the last time + # interval from the recorded observations. + xfer_metrics = self.transfer_stats_accumulator.reduce() + xfer_metrics_str = ", ".join(f"{k}={v}" + for k, v in xfer_metrics.items()) + log_fn("KV Transfer metrics: %s", xfer_metrics_str) + + # Reset metrics for next interval + self.reset() \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 616d158d67670..6836a71e58d62 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -9,19 +9,21 @@ import torch from vllm.config import VllmConfig from vllm.config.kv_transfer import KVTransferConfig -from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( + KVConnectorStats) from vllm.logger import init_logger -from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.outputs import KVConnectorOutput if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata + from vllm.distributed.kv_events import KVCacheEvent from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request logger = init_logger(__name__) @@ -33,6 +35,43 @@ class MultiKVConnectorMetadata(KVConnectorMetadata): extra_async_saves: Optional[dict[str, int]] = None +@dataclass +class MultiKVConnectorStats(KVConnectorStats): + """ + Maintain a dict of KVConnectorStats objects, one for each connector. + This is used to aggregate the stats from all connectors separately. + """ + + def aggregate(self, other: KVConnectorStats) -> KVConnectorStats: + for connector_id, stats in other.data.items(): + if connector_id not in self.data: + self[connector_id] = stats + else: + assert isinstance(stats, type(self.data[connector_id])) + self[connector_id] = self[connector_id].aggregate(stats) + return self + + def reset(self): + for stats in self.data.values(): + stats.reset() + + def reduce(self) -> dict[str, Any]: + # TODO (NickLucche) Adjust for logging on separate lines + return { + connector_id: stats.reduce() + for connector_id, stats in self.data.items() + } + + def is_empty(self) -> bool: + return all(stats.is_empty() for stats in self.data.values()) + + def __getitem__(self, connector_id: str) -> KVConnectorStats: + return self.data[connector_id] + + def __setitem__(self, connector_id: str, stats: KVConnectorStats): + self.data[connector_id] = stats + + class MultiConnector(KVConnectorBase_V1): """ A wrapper for using multiple KVConnectors at the same time. @@ -46,6 +85,7 @@ class MultiConnector(KVConnectorBase_V1): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): super().__init__(vllm_config=vllm_config, role=role) self._connectors: list[KVConnectorBase_V1] = [] + self._ktc_kv_transfer_config = [] ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get( "connectors") assert ktcs is not None @@ -57,6 +97,7 @@ class MultiConnector(KVConnectorBase_V1): **ktc, engine_id=engine_id) self._connectors.append( KVConnectorFactory.create_connector(temp_config, role)) + self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config) # A mapping from request id to the index of the connector chosen to # load the request from (if any). @@ -227,7 +268,7 @@ class MultiConnector(KVConnectorBase_V1): return async_saves > 0, kv_txfer_params - def take_events(self) -> Iterable[KVCacheEvent]: + def take_events(self) -> Iterable["KVCacheEvent"]: for c in self._connectors: yield from c.take_events() @@ -264,3 +305,24 @@ class MultiConnector(KVConnectorBase_V1): f"({', '.join(layouts) })." f"All connectors must use the same layout.") return next(iter(layouts), None) + + @classmethod + def build_kv_connector_stats( + cls, + data: Optional[dict[str, + Any]] = None) -> Optional[KVConnectorStats]: + return MultiKVConnectorStats(data=data) if data is not None \ + else MultiKVConnectorStats() + + def get_kv_connector_stats(self) -> Optional[MultiKVConnectorStats]: + # Group connector stats by connector type. + stats_by_connector: Optional[MultiKVConnectorStats] = None + for c in self._connectors: + stats = c.get_kv_connector_stats() + if stats is None: + continue + if stats_by_connector is None: + # Lazy init to allow optional return value. + stats_by_connector = MultiKVConnectorStats() + stats_by_connector[c.__class__.__name__] = stats + return stats_by_connector diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 1ff1407aeb99b..ff62f60e5a42c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib +import copy import logging import math import queue @@ -11,7 +12,7 @@ from collections import defaultdict from collections.abc import Iterator from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union import msgspec import numpy as np @@ -23,6 +24,8 @@ from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( CopyBlocksOp, KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( + KVConnectorStats) from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group) @@ -33,7 +36,6 @@ from vllm.platforms import _Backend, current_platform from vllm.utils import make_zmq_path, make_zmq_socket from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.request import RequestStatus if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -206,6 +208,18 @@ class NixlConnector(KVConnectorBase_V1): assert self.connector_worker is not None return self.connector_worker.get_finished() + def get_kv_connector_stats(self) -> Optional[KVConnectorStats]: + assert self.connector_worker is not None + return self.connector_worker.get_kv_connector_stats() + + @classmethod + def build_kv_connector_stats( + cls, + data: Optional[dict[str, + Any]] = None) -> Optional[KVConnectorStats]: + return NixlKVConnectorStats(data=data) if data is not None \ + else NixlKVConnectorStats() + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: assert self.connector_worker is not None @@ -377,6 +391,7 @@ class NixlConnectorScheduler: Once a request is finished, determine whether request blocks should be freed now or will be sent asynchronously and freed later. """ + from vllm.v1.request import RequestStatus params = request.kv_transfer_params logger.debug( @@ -550,6 +565,7 @@ class NixlConnectorWorker: # With heterogeneous TP, P must wait for all assigned D TP workers to # finish reading before safely freeing the blocks. self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) + self.xfer_stats = NixlKVConnectorStats() def __del__(self): """Cleanup background threads on destruction.""" @@ -1097,6 +1113,8 @@ class NixlConnectorWorker: xfer_state = self.nixl_wrapper.check_xfer_state(handle) if xfer_state == "DONE": self.nixl_wrapper.release_xfer_handle(handle) + # TODO (NickLucche) Get from NIXL telemetry once integrated + self.xfer_stats.record_transfer() elif xfer_state == "PROC": in_progress = True continue @@ -1248,7 +1266,6 @@ class NixlConnectorWorker: self.nixl_wrapper.transfer(handle) # Use handle to check completion in future step(). - # TODO (NickLucche) surface xfer elapsed time self._recving_transfers[request_id].append( (handle, time.perf_counter())) @@ -1300,6 +1317,15 @@ class NixlConnectorWorker: block_len = self.block_len return block_len + def get_kv_connector_stats(self) -> Optional[KVConnectorStats]: + """ + Get the KV transfer stats for the connector. + """ + # Clear stats for next iteration + if not self.xfer_stats.is_empty(): + return self.xfer_stats.clone_and_reset() + return None + @contextlib.contextmanager def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: @@ -1318,3 +1344,39 @@ def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: finally: if ctx is not None: ctx.destroy(linger=0) + + +@dataclass +class NixlKVConnectorStats(KVConnectorStats): + """Container for transfer performance metrics""" + + def __post_init__(self): + if "num_successful_transfers" not in self.data: + self.data["num_successful_transfers"] = 0 + + def reset(self): + self.data = {"num_successful_transfers": 0} + + def record_transfer(self): + # TODO: record actual transfer stats when available + self.data["num_successful_transfers"] += 1 + + def clone_and_reset(self) -> "NixlKVConnectorStats": + old = copy.copy(self) + self.reset() + return old + + def is_empty(self) -> bool: + return self.data["num_successful_transfers"] == 0 + + def aggregate(self, other: KVConnectorStats) -> KVConnectorStats: + if not other.is_empty(): + self.data["num_successful_transfers"] += other.data[ + "num_successful_transfers"] + return self + + def reduce(self) -> dict[str, Union[int, float]]: + # TODO: reduce stats to a single value, calculate latency/throughput + return { + "num_successful_transfers": self.data["num_successful_transfers"] + } \ No newline at end of file diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 85ca858ad7bd6..b08898d253cab 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -15,6 +15,8 @@ from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, KVConnectorRole) +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( + KVConnectorStats) from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, @@ -869,9 +871,12 @@ class Scheduler(SchedulerInterface): num_scheduled_tokens = scheduler_output.num_scheduled_tokens pooler_outputs = model_runner_output.pooler_output num_nans_in_logits = model_runner_output.num_nans_in_logits + kv_connector_output = model_runner_output.kv_connector_output outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: Optional[SpecDecodingStats] = None + kv_connector_stats = (kv_connector_output.kv_connector_stats + if kv_connector_output else None) # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, # the below loop can be a performance bottleneck. We should do our best @@ -1007,7 +1012,8 @@ class Scheduler(SchedulerInterface): finished_requests=finished_set) finished_req_ids.clear() - if (stats := self.make_stats(spec_decoding_stats)) is not None: + if (stats := self.make_stats(spec_decoding_stats, + kv_connector_stats)) is not None: # Return stats to only one of the front-ends. if (eco := next(iter(engine_core_outputs.values()), None)) is None: # We must return the stats even if there are no request @@ -1172,20 +1178,21 @@ class Scheduler(SchedulerInterface): def make_stats( self, spec_decoding_stats: Optional[SpecDecodingStats] = None, + kv_connector_stats: Optional[KVConnectorStats] = None, ) -> Optional[SchedulerStats]: if not self.log_stats: return None prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats() assert prefix_cache_stats is not None - return SchedulerStats( - num_running_reqs=len(self.running), - num_waiting_reqs=len(self.waiting), - kv_cache_usage=self.kv_cache_manager.usage, - prefix_cache_stats=prefix_cache_stats, - spec_decoding_stats=spec_decoding_stats, - num_corrupted_reqs=sum(req.is_output_corrupted - for req in self.running), - ) + return SchedulerStats(num_running_reqs=len(self.running), + num_waiting_reqs=len(self.waiting), + kv_cache_usage=self.kv_cache_manager.usage, + prefix_cache_stats=prefix_cache_stats, + spec_decoding_stats=spec_decoding_stats, + num_corrupted_reqs=sum(req.is_output_corrupted + for req in self.running), + kv_connector_stats=kv_connector_stats.data + if kv_connector_stats else None) def make_spec_decoding_stats( self, diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index b30036a6f8e80..f0076b2d81dbf 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -9,6 +9,8 @@ from typing import Callable, Optional, Union import prometheus_client from vllm.config import SupportsMetricsInfo, VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( + KVConnectorLogging) from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics from vllm.v1.engine import FinishReason @@ -59,6 +61,8 @@ class LoggingStatLogger(StatLoggerBase): # TODO: Make the interval configurable. self.prefix_caching_metrics = PrefixCachingMetrics() self.spec_decoding_logging = SpecDecodingLogging() + kv_tranfer_config = self.vllm_config.kv_transfer_config + self.kv_transfer_logging = KVConnectorLogging(kv_tranfer_config) self.last_prompt_throughput: float = 0.0 self.last_generation_throughput: float = 0.0 @@ -97,7 +101,8 @@ class LoggingStatLogger(StatLoggerBase): if scheduler_stats.spec_decoding_stats is not None: self.spec_decoding_logging.observe( scheduler_stats.spec_decoding_stats) - + if kv_connector_stats := scheduler_stats.kv_connector_stats: + self.kv_transfer_logging.observe(kv_connector_stats) self.last_scheduler_stats = scheduler_stats def log(self): @@ -136,6 +141,7 @@ class LoggingStatLogger(StatLoggerBase): self.prefix_caching_metrics.hit_rate * 100, ) self.spec_decoding_logging.log(log_fn=log_fn) + self.kv_transfer_logging.log(log_fn=log_fn) def log_engine_initialized(self): if self.vllm_config.cache_config.num_gpu_blocks: diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index e6c344d193df2..0eff557336bc0 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -3,7 +3,7 @@ import time from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional from vllm.v1.spec_decode.metrics import SpecDecodingStats @@ -43,6 +43,7 @@ class SchedulerStats: default_factory=PrefixCacheStats) spec_decoding_stats: Optional[SpecDecodingStats] = None + kv_connector_stats: Optional[dict[str, Any]] = None num_corrupted_reqs: int = 0 diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 1b2da8addb19e..e6cc6019b1728 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -3,10 +3,14 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import NamedTuple, Optional +from typing import TYPE_CHECKING, NamedTuple, Optional import torch +if TYPE_CHECKING: + from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( + KVConnectorStats) + class LogprobsLists(NamedTuple): @@ -77,6 +81,11 @@ class KVConnectorOutput: # [req_ids] finished_sending: Optional[set[str]] = None finished_recving: Optional[set[str]] = None + kv_connector_stats: Optional["KVConnectorStats"] = None + + def is_empty(self): + return (not self.finished_sending and not self.finished_recving + and not self.kv_connector_stats) # ModelRunnerOutput is serialized and sent to the scheduler process. diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index 3eb9f26e9f5b6..016a90c196ba3 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -13,6 +13,8 @@ from vllm.distributed.kv_transfer import (ensure_kv_transfer_shutdown, get_kv_transfer_group, has_kv_transfer_group) from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( + KVConnectorStats) from vllm.forward_context import get_forward_context, set_forward_context from vllm.logger import init_logger from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput, @@ -119,4 +121,11 @@ class KVConnectorModelRunnerMixin: output.finished_sending, output.finished_recving = ( kv_connector.get_finished(scheduler_output.finished_req_ids)) - kv_connector.clear_connector_metadata() + output.kv_connector_stats = KVConnectorModelRunnerMixin.\ + get_kv_connector_stats() + + @staticmethod + def get_kv_connector_stats() -> Optional[KVConnectorStats]: + if has_kv_transfer_group(): + return get_kv_transfer_group().get_kv_connector_stats() + return None