[P/D][Nixl] Introduce KVTransferMetrics and aggregation strategy (#22188)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi 2025-09-19 13:09:14 +02:00 committed by GitHub
parent 058525b997
commit a3d087adec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 525 additions and 25 deletions

View File

@ -18,12 +18,18 @@ import torch
from vllm import LLM from vllm import LLM
from vllm.config import KVTransferConfig from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorStats)
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
MultiKVConnectorStats)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata, KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
NixlConnectorWorker) NixlConnectorWorker, NixlKVConnectorStats)
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from .utils import create_request, create_scheduler, create_vllm_config from .utils import create_request, create_scheduler, create_vllm_config
@ -475,6 +481,209 @@ class TestNixlHandshake:
# NOTE: resource cleanup in mp backend is a bit finicky, so the order in which # NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
# we put here is important. First run ray, it will clean up the resources, then # we put here is important. First run ray, it will clean up the resources, then
# the rest of the tests. # the rest of the tests.
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
def test_kv_connector_stats(dist_init):
"""Test that KV transfer stats are properly recorded and retrieved."""
vllm_config = create_vllm_config()
# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
connector.engine_id,
hand_shake_latency=0)
# Verify that xfer_stats starts empty
initial_stats = connector.get_kv_connector_stats()
assert initial_stats is None
# Create transfer metadata
request_id = "test_req_for_stats"
metadata = NixlConnectorMetadata()
metadata.add_new_req(request_id=request_id,
local_block_ids=[1, 2, 3],
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id":
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,
})
connector.bind_connector_metadata(metadata)
# Start the transfer
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
)
connector.start_load_kv(dummy_ctx)
# Verify stats are recorded after transfer is complete
max_iterations = 2
# Clear metadata before start_load_kv to prevent reprocessing same request
connector.bind_connector_metadata(NixlConnectorMetadata())
for _ in range(max_iterations):
# Need to call start_load_kv to process completed handshakes
connector.start_load_kv(dummy_ctx)
_, done_recving = connector.get_finished(finished_req_ids=set())
if len(done_recving) > 0 and request_id in done_recving:
break
time.sleep(
0.1) # Small delay to allow background handshake to complete
else:
assert "Transfer did not complete within expected iterations"
# Now check that stats were recorded
stats_after_transfer = connector.get_kv_connector_stats()
assert isinstance(stats_after_transfer, NixlKVConnectorStats)
# Verify stats values are recorded
assert not stats_after_transfer.is_empty()
assert stats_after_transfer.data["num_successful_transfers"] == 1
# Verify stats are reset after retrieval
stats_after_reset = connector.get_kv_connector_stats()
assert stats_after_reset is None
def test_kv_connector_stats_aggregation():
"""
Test KV transfer stats aggregation across TP ranks using
KVOutputAggregator (used by MultiprocExecutor).
"""
# Create KVOutputAggregator for 3 workers (simulating TP=3), same thing
# done in MultiprocExecutor.execute_model
aggregator = KVOutputAggregator(world_size=3)
# Create stats for multiple workers with different transfer patterns
worker1_stats = NixlKVConnectorStats()
worker2_stats = NixlKVConnectorStats()
worker3_stats = NixlKVConnectorStats()
# Record different transfers on each worker
# Worker 1: 2 transfers
worker1_stats.record_transfer()
worker1_stats.record_transfer()
# Worker 2: 1 transfer
worker2_stats.record_transfer()
# Worker 3: 3 transfers
worker3_stats.record_transfer()
worker3_stats.record_transfer()
worker3_stats.record_transfer()
# Create ModelRunnerOutput instances for each worker
worker_outputs = []
for i, worker_stats in enumerate(
[worker1_stats, worker2_stats, worker3_stats]):
output = ModelRunnerOutput(
req_ids=[f"req_{i}"],
req_id_to_index={f"req_{i}": 0},
sampled_token_ids=[[123]], # dummy token
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
kv_connector_output=KVConnectorOutput(
finished_sending=set([f"req_{i}_send"])
if i < 2 else None, # Workers 0,1 finished sending
finished_recving=set([f"req_{i}_recv"])
if i > 0 else None, # Workers 1,2 finished receiving
kv_connector_stats=worker_stats,
))
worker_outputs.append(output)
# Use the real aggregation mechanism (like MultiprocExecutor.execute_model)
aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
kv_connector_stats = \
aggregated_output.kv_connector_output.kv_connector_stats
assert isinstance(kv_connector_stats, NixlKVConnectorStats)
# Number of total transfers across all workers.
assert kv_connector_stats.data["num_successful_transfers"] == 6
def test_multi_kv_connector_stats_aggregation():
"""
Test MultiKVConnectorStats aggregation across TP ranks using
KVOutputAggregator (used by MultiprocExecutor).
"""
aggregator = KVOutputAggregator(world_size=3)
from dataclasses import dataclass
@dataclass
class FooKVConnectorStats(KVConnectorStats):
def reset(self):
self.data = {"num_foo_transfers": 0}
def record_transfer(self):
if "num_foo_transfers" not in self.data:
self.data["num_foo_transfers"] = 0
self.data["num_foo_transfers"] += 1
def is_empty(self) -> bool:
return self.data["num_foo_transfers"] == 0
def aggregate(self,
other: "FooKVConnectorStats") -> "FooKVConnectorStats":
if not other.is_empty():
self.data["num_foo_transfers"] += other.data[
"num_foo_transfers"]
return self
def make_multi_stats(nixl_count: int,
foo_count: int) -> MultiKVConnectorStats:
data: dict[str, KVConnectorStats] = {}
if nixl_count > 0:
nixl_stats = NixlKVConnectorStats()
for _ in range(nixl_count):
nixl_stats.record_transfer()
data["NixlConnector"] = nixl_stats
if foo_count > 0:
foo_stats = FooKVConnectorStats()
for _ in range(foo_count):
foo_stats.record_transfer()
data["FooConnector"] = foo_stats
return MultiKVConnectorStats(data=data)
# Create heterogeneous stats across 3 workers
worker_patterns = [(2, 1), (3, 0), (0, 5)] # (Nixl, Foo)
worker_outputs: list[ModelRunnerOutput] = []
for i, (nixl, foo) in enumerate(worker_patterns):
stats = make_multi_stats(nixl, foo)
output = ModelRunnerOutput(
req_ids=[f"req_{i}"],
req_id_to_index={f"req_{i}": 0},
sampled_token_ids=[[123]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
kv_connector_output=KVConnectorOutput(
finished_sending=set([f"req_{i}_send"]) if i < 2 else None,
finished_recving=set([f"req_{i}_recv"]) if i > 0 else None,
kv_connector_stats=stats,
),
)
worker_outputs.append(output)
aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
kv_connector_stats = \
aggregated_output.kv_connector_output.kv_connector_stats
assert isinstance(kv_connector_stats, MultiKVConnectorStats)
# Validate per-connector totals across workers
assert kv_connector_stats["NixlConnector"].data[
"num_successful_transfers"] == 5
assert kv_connector_stats["FooConnector"].data["num_foo_transfers"] == 6
@pytest.mark.parametrize("distributed_executor_backend", ["ray", None]) @pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",

View File

@ -129,7 +129,7 @@ class KVOutputAggregator:
def aggregate(self, def aggregate(self,
outputs: list[ModelRunnerOutput], outputs: list[ModelRunnerOutput],
output_rank: int = 0) -> ModelRunnerOutput: output_rank: int = 0) -> ModelRunnerOutput:
# aggregate kv_connector_output from all workers # Aggregate kv_connector_output from all workers
def update_finished_set(req_ids: Optional[set[str]], def update_finished_set(req_ids: Optional[set[str]],
remaining_count_dict: dict[str, int], remaining_count_dict: dict[str, int],
@ -142,8 +142,9 @@ class KVOutputAggregator:
finished_sending = set[str]() finished_sending = set[str]()
finished_recving = set[str]() finished_recving = set[str]()
for output in outputs: aggregated_kv_connector_stats = None
output = output.kv_connector_output for model_runner_output in outputs:
output = model_runner_output.kv_connector_output
if not output: if not output:
continue continue
update_finished_set(output.finished_sending, update_finished_set(output.finished_sending,
@ -151,12 +152,26 @@ class KVOutputAggregator:
update_finished_set(output.finished_recving, update_finished_set(output.finished_recving,
self._recv_remaining_count, finished_recving) self._recv_remaining_count, finished_recving)
# Aggregate kv_connector_stats from all workers.
if aggregated_kv_connector_stats is None:
# Use the first worker's kv_connector_stats as accumulator.
aggregated_kv_connector_stats = output.kv_connector_stats
elif kv_connector_stats := output.kv_connector_stats:
if aggregated_kv_connector_stats is None:
aggregated_kv_connector_stats = kv_connector_stats
else:
assert isinstance(aggregated_kv_connector_stats,
type(kv_connector_stats))
aggregated_kv_connector_stats = \
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
# select output of the worker specified by output_rank # select output of the worker specified by output_rank
output = outputs[output_rank] output = outputs[output_rank]
output.kv_connector_output = KVConnectorOutput( output.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending or None, finished_sending=finished_sending or None,
finished_recving=finished_recving or None, finished_recving=finished_recving or None,
kv_connector_stats=aggregated_kv_connector_stats or None,
) )
return output return output

View File

@ -49,6 +49,8 @@ if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorStats)
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request from vllm.v1.request import Request
@ -235,6 +237,12 @@ class KVConnectorBase_V1(ABC):
""" """
return None return None
def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
"""
Get the KV connector stats collected during the last interval.
"""
return None
# ============================== # ==============================
# Scheduler-side methods # Scheduler-side methods
# ============================== # ==============================
@ -366,3 +374,15 @@ class KVConnectorBase_V1(ABC):
""" """
return None return None
@classmethod
def build_kv_connector_stats(
cls,
data: Optional[dict[str,
Any]] = None) -> Optional["KVConnectorStats"]:
"""
KVConnectorStats resolution method. This method allows dynamically
registered connectors to return their own KVConnectorStats object,
which can implement custom aggregation logic on the data dict.
"""
return None

View File

@ -0,0 +1,100 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import Any, Optional, Union
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_transfer_state import (
has_kv_transfer_group)
from vllm.logger import init_logger
logger = init_logger(__name__)
@dataclass
class KVConnectorStats:
"""
Base class for KV Connector Stats, a container for transfer performance
metrics or otherwise important telemetry from the connector.
All sub-classes need to be serializable as stats are sent from worker to
logger process.
"""
data: dict[str, Any] = field(default_factory=dict)
def reset(self):
"""Reset the stats, clear the state."""
raise NotImplementedError
def aggregate(self, other: "KVConnectorStats") -> "KVConnectorStats":
"""
Aggregate stats with another `KVConnectorStats` object.
"""
raise NotImplementedError
def reduce(self) -> dict[str, Union[int, float]]:
"""
Reduce the observations collected during a time interval to one or
more representative values (eg avg/median/sum of the series).
This is meant to be called by the logger to produce a summary of the
stats for the last time interval.
"""
raise NotImplementedError
def is_empty(self) -> bool:
"""Return True if the stats are empty."""
raise NotImplementedError
class KVConnectorLogging:
def __init__(self, kv_tranfer_config: KVTransferConfig):
# This should be called on frontend process.
assert not has_kv_transfer_group()
# Instantiate the connector's stats class.
if kv_tranfer_config and kv_tranfer_config.kv_connector:
self.connector_cls = KVConnectorFactory.get_connector_class(
kv_tranfer_config)
self.reset()
def reset(self):
self.transfer_stats_accumulator: Optional[KVConnectorStats] = None
def observe(self, transfer_stats_data: dict[str, Any]):
# Should not be called when a KVConnector is not configured.
assert self.connector_cls is not None
# Called periodically when connector syncs with the scheduler.
# Note that this is not the same as the logging interval.
# We expect transfer_stats_data to be aggregated across all workers and
# consist of observations from a single connector or a MultiConnector.
transfer_stats = self.connector_cls.build_kv_connector_stats(
transfer_stats_data)
if transfer_stats is None:
logger.warning_once(
"The connector %s is collecting stats but "
"does not implement the "
"`build_kv_connector_stats` method. "
"Stats will not be logged.", self.connector_cls)
return
if self.transfer_stats_accumulator is None:
self.transfer_stats_accumulator = transfer_stats
else:
# Accumulate last interval stats.
self.transfer_stats_accumulator = \
self.transfer_stats_accumulator.aggregate(transfer_stats)
def log(self, log_fn=logger.info):
"""Log transfer metrics periodically, similar to throughput logging"""
if (self.transfer_stats_accumulator
and not self.transfer_stats_accumulator.is_empty()):
# Produce a single cumulative stats object for the last time
# interval from the recorded observations.
xfer_metrics = self.transfer_stats_accumulator.reduce()
xfer_metrics_str = ", ".join(f"{k}={v}"
for k, v in xfer_metrics.items())
log_fn("KV Transfer metrics: %s", xfer_metrics_str)
# Reset metrics for next interval
self.reset()

View File

@ -9,19 +9,21 @@ import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.factory import ( from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory) KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorStats)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_events import KVCacheEvent
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request from vllm.v1.request import Request
logger = init_logger(__name__) logger = init_logger(__name__)
@ -33,6 +35,43 @@ class MultiKVConnectorMetadata(KVConnectorMetadata):
extra_async_saves: Optional[dict[str, int]] = None extra_async_saves: Optional[dict[str, int]] = None
@dataclass
class MultiKVConnectorStats(KVConnectorStats):
"""
Maintain a dict of KVConnectorStats objects, one for each connector.
This is used to aggregate the stats from all connectors separately.
"""
def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
for connector_id, stats in other.data.items():
if connector_id not in self.data:
self[connector_id] = stats
else:
assert isinstance(stats, type(self.data[connector_id]))
self[connector_id] = self[connector_id].aggregate(stats)
return self
def reset(self):
for stats in self.data.values():
stats.reset()
def reduce(self) -> dict[str, Any]:
# TODO (NickLucche) Adjust for logging on separate lines
return {
connector_id: stats.reduce()
for connector_id, stats in self.data.items()
}
def is_empty(self) -> bool:
return all(stats.is_empty() for stats in self.data.values())
def __getitem__(self, connector_id: str) -> KVConnectorStats:
return self.data[connector_id]
def __setitem__(self, connector_id: str, stats: KVConnectorStats):
self.data[connector_id] = stats
class MultiConnector(KVConnectorBase_V1): class MultiConnector(KVConnectorBase_V1):
""" """
A wrapper for using multiple KVConnectors at the same time. A wrapper for using multiple KVConnectors at the same time.
@ -46,6 +85,7 @@ class MultiConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role) super().__init__(vllm_config=vllm_config, role=role)
self._connectors: list[KVConnectorBase_V1] = [] self._connectors: list[KVConnectorBase_V1] = []
self._ktc_kv_transfer_config = []
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get( ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"connectors") "connectors")
assert ktcs is not None assert ktcs is not None
@ -57,6 +97,7 @@ class MultiConnector(KVConnectorBase_V1):
**ktc, engine_id=engine_id) **ktc, engine_id=engine_id)
self._connectors.append( self._connectors.append(
KVConnectorFactory.create_connector(temp_config, role)) KVConnectorFactory.create_connector(temp_config, role))
self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config)
# A mapping from request id to the index of the connector chosen to # A mapping from request id to the index of the connector chosen to
# load the request from (if any). # load the request from (if any).
@ -227,7 +268,7 @@ class MultiConnector(KVConnectorBase_V1):
return async_saves > 0, kv_txfer_params return async_saves > 0, kv_txfer_params
def take_events(self) -> Iterable[KVCacheEvent]: def take_events(self) -> Iterable["KVCacheEvent"]:
for c in self._connectors: for c in self._connectors:
yield from c.take_events() yield from c.take_events()
@ -264,3 +305,24 @@ class MultiConnector(KVConnectorBase_V1):
f"({', '.join(layouts) })." f"({', '.join(layouts) })."
f"All connectors must use the same layout.") f"All connectors must use the same layout.")
return next(iter(layouts), None) return next(iter(layouts), None)
@classmethod
def build_kv_connector_stats(
cls,
data: Optional[dict[str,
Any]] = None) -> Optional[KVConnectorStats]:
return MultiKVConnectorStats(data=data) if data is not None \
else MultiKVConnectorStats()
def get_kv_connector_stats(self) -> Optional[MultiKVConnectorStats]:
# Group connector stats by connector type.
stats_by_connector: Optional[MultiKVConnectorStats] = None
for c in self._connectors:
stats = c.get_kv_connector_stats()
if stats is None:
continue
if stats_by_connector is None:
# Lazy init to allow optional return value.
stats_by_connector = MultiKVConnectorStats()
stats_by_connector[c.__class__.__name__] = stats
return stats_by_connector

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib import contextlib
import copy
import logging import logging
import math import math
import queue import queue
@ -11,7 +12,7 @@ from collections import defaultdict
from collections.abc import Iterator from collections.abc import Iterator
from concurrent.futures import Future, ThreadPoolExecutor from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional, Union
import msgspec import msgspec
import numpy as np import numpy as np
@ -23,6 +24,8 @@ from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
CopyBlocksOp, KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) CopyBlocksOp, KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorStats)
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
get_tp_group) get_tp_group)
@ -33,7 +36,6 @@ from vllm.platforms import _Backend, current_platform
from vllm.utils import make_zmq_path, make_zmq_socket from vllm.utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
@ -206,6 +208,18 @@ class NixlConnector(KVConnectorBase_V1):
assert self.connector_worker is not None assert self.connector_worker is not None
return self.connector_worker.get_finished() return self.connector_worker.get_finished()
def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:
assert self.connector_worker is not None
return self.connector_worker.get_kv_connector_stats()
@classmethod
def build_kv_connector_stats(
cls,
data: Optional[dict[str,
Any]] = None) -> Optional[KVConnectorStats]:
return NixlKVConnectorStats(data=data) if data is not None \
else NixlKVConnectorStats()
def start_load_kv(self, forward_context: "ForwardContext", def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None: **kwargs) -> None:
assert self.connector_worker is not None assert self.connector_worker is not None
@ -377,6 +391,7 @@ class NixlConnectorScheduler:
Once a request is finished, determine whether request blocks Once a request is finished, determine whether request blocks
should be freed now or will be sent asynchronously and freed later. should be freed now or will be sent asynchronously and freed later.
""" """
from vllm.v1.request import RequestStatus
params = request.kv_transfer_params params = request.kv_transfer_params
logger.debug( logger.debug(
@ -550,6 +565,7 @@ class NixlConnectorWorker:
# With heterogeneous TP, P must wait for all assigned D TP workers to # With heterogeneous TP, P must wait for all assigned D TP workers to
# finish reading before safely freeing the blocks. # finish reading before safely freeing the blocks.
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
self.xfer_stats = NixlKVConnectorStats()
def __del__(self): def __del__(self):
"""Cleanup background threads on destruction.""" """Cleanup background threads on destruction."""
@ -1097,6 +1113,8 @@ class NixlConnectorWorker:
xfer_state = self.nixl_wrapper.check_xfer_state(handle) xfer_state = self.nixl_wrapper.check_xfer_state(handle)
if xfer_state == "DONE": if xfer_state == "DONE":
self.nixl_wrapper.release_xfer_handle(handle) self.nixl_wrapper.release_xfer_handle(handle)
# TODO (NickLucche) Get from NIXL telemetry once integrated
self.xfer_stats.record_transfer()
elif xfer_state == "PROC": elif xfer_state == "PROC":
in_progress = True in_progress = True
continue continue
@ -1248,7 +1266,6 @@ class NixlConnectorWorker:
self.nixl_wrapper.transfer(handle) self.nixl_wrapper.transfer(handle)
# Use handle to check completion in future step(). # Use handle to check completion in future step().
# TODO (NickLucche) surface xfer elapsed time
self._recving_transfers[request_id].append( self._recving_transfers[request_id].append(
(handle, time.perf_counter())) (handle, time.perf_counter()))
@ -1300,6 +1317,15 @@ class NixlConnectorWorker:
block_len = self.block_len block_len = self.block_len
return block_len return block_len
def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:
"""
Get the KV transfer stats for the connector.
"""
# Clear stats for next iteration
if not self.xfer_stats.is_empty():
return self.xfer_stats.clone_and_reset()
return None
@contextlib.contextmanager @contextlib.contextmanager
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
@ -1318,3 +1344,39 @@ def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
finally: finally:
if ctx is not None: if ctx is not None:
ctx.destroy(linger=0) ctx.destroy(linger=0)
@dataclass
class NixlKVConnectorStats(KVConnectorStats):
"""Container for transfer performance metrics"""
def __post_init__(self):
if "num_successful_transfers" not in self.data:
self.data["num_successful_transfers"] = 0
def reset(self):
self.data = {"num_successful_transfers": 0}
def record_transfer(self):
# TODO: record actual transfer stats when available
self.data["num_successful_transfers"] += 1
def clone_and_reset(self) -> "NixlKVConnectorStats":
old = copy.copy(self)
self.reset()
return old
def is_empty(self) -> bool:
return self.data["num_successful_transfers"] == 0
def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
if not other.is_empty():
self.data["num_successful_transfers"] += other.data[
"num_successful_transfers"]
return self
def reduce(self) -> dict[str, Union[int, float]]:
# TODO: reduce stats to a single value, calculate latency/throughput
return {
"num_successful_transfers": self.data["num_successful_transfers"]
}

View File

@ -15,6 +15,8 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory) KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
KVConnectorRole) KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorStats)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
@ -869,9 +871,12 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens = scheduler_output.num_scheduled_tokens num_scheduled_tokens = scheduler_output.num_scheduled_tokens
pooler_outputs = model_runner_output.pooler_output pooler_outputs = model_runner_output.pooler_output
num_nans_in_logits = model_runner_output.num_nans_in_logits num_nans_in_logits = model_runner_output.num_nans_in_logits
kv_connector_output = model_runner_output.kv_connector_output
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: Optional[SpecDecodingStats] = None spec_decoding_stats: Optional[SpecDecodingStats] = None
kv_connector_stats = (kv_connector_output.kv_connector_stats
if kv_connector_output else None)
# NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
# the below loop can be a performance bottleneck. We should do our best # the below loop can be a performance bottleneck. We should do our best
@ -1007,7 +1012,8 @@ class Scheduler(SchedulerInterface):
finished_requests=finished_set) finished_requests=finished_set)
finished_req_ids.clear() finished_req_ids.clear()
if (stats := self.make_stats(spec_decoding_stats)) is not None: if (stats := self.make_stats(spec_decoding_stats,
kv_connector_stats)) is not None:
# Return stats to only one of the front-ends. # Return stats to only one of the front-ends.
if (eco := next(iter(engine_core_outputs.values()), None)) is None: if (eco := next(iter(engine_core_outputs.values()), None)) is None:
# We must return the stats even if there are no request # We must return the stats even if there are no request
@ -1172,20 +1178,21 @@ class Scheduler(SchedulerInterface):
def make_stats( def make_stats(
self, self,
spec_decoding_stats: Optional[SpecDecodingStats] = None, spec_decoding_stats: Optional[SpecDecodingStats] = None,
kv_connector_stats: Optional[KVConnectorStats] = None,
) -> Optional[SchedulerStats]: ) -> Optional[SchedulerStats]:
if not self.log_stats: if not self.log_stats:
return None return None
prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats() prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats()
assert prefix_cache_stats is not None assert prefix_cache_stats is not None
return SchedulerStats( return SchedulerStats(num_running_reqs=len(self.running),
num_running_reqs=len(self.running), num_waiting_reqs=len(self.waiting),
num_waiting_reqs=len(self.waiting), kv_cache_usage=self.kv_cache_manager.usage,
kv_cache_usage=self.kv_cache_manager.usage, prefix_cache_stats=prefix_cache_stats,
prefix_cache_stats=prefix_cache_stats, spec_decoding_stats=spec_decoding_stats,
spec_decoding_stats=spec_decoding_stats, num_corrupted_reqs=sum(req.is_output_corrupted
num_corrupted_reqs=sum(req.is_output_corrupted for req in self.running),
for req in self.running), kv_connector_stats=kv_connector_stats.data
) if kv_connector_stats else None)
def make_spec_decoding_stats( def make_spec_decoding_stats(
self, self,

View File

@ -9,6 +9,8 @@ from typing import Callable, Optional, Union
import prometheus_client import prometheus_client
from vllm.config import SupportsMetricsInfo, VllmConfig from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorLogging)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
from vllm.v1.engine import FinishReason from vllm.v1.engine import FinishReason
@ -59,6 +61,8 @@ class LoggingStatLogger(StatLoggerBase):
# TODO: Make the interval configurable. # TODO: Make the interval configurable.
self.prefix_caching_metrics = PrefixCachingMetrics() self.prefix_caching_metrics = PrefixCachingMetrics()
self.spec_decoding_logging = SpecDecodingLogging() self.spec_decoding_logging = SpecDecodingLogging()
kv_tranfer_config = self.vllm_config.kv_transfer_config
self.kv_transfer_logging = KVConnectorLogging(kv_tranfer_config)
self.last_prompt_throughput: float = 0.0 self.last_prompt_throughput: float = 0.0
self.last_generation_throughput: float = 0.0 self.last_generation_throughput: float = 0.0
@ -97,7 +101,8 @@ class LoggingStatLogger(StatLoggerBase):
if scheduler_stats.spec_decoding_stats is not None: if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_logging.observe( self.spec_decoding_logging.observe(
scheduler_stats.spec_decoding_stats) scheduler_stats.spec_decoding_stats)
if kv_connector_stats := scheduler_stats.kv_connector_stats:
self.kv_transfer_logging.observe(kv_connector_stats)
self.last_scheduler_stats = scheduler_stats self.last_scheduler_stats = scheduler_stats
def log(self): def log(self):
@ -136,6 +141,7 @@ class LoggingStatLogger(StatLoggerBase):
self.prefix_caching_metrics.hit_rate * 100, self.prefix_caching_metrics.hit_rate * 100,
) )
self.spec_decoding_logging.log(log_fn=log_fn) self.spec_decoding_logging.log(log_fn=log_fn)
self.kv_transfer_logging.log(log_fn=log_fn)
def log_engine_initialized(self): def log_engine_initialized(self):
if self.vllm_config.cache_config.num_gpu_blocks: if self.vllm_config.cache_config.num_gpu_blocks:

View File

@ -3,7 +3,7 @@
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Any, Optional
from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.spec_decode.metrics import SpecDecodingStats
@ -43,6 +43,7 @@ class SchedulerStats:
default_factory=PrefixCacheStats) default_factory=PrefixCacheStats)
spec_decoding_stats: Optional[SpecDecodingStats] = None spec_decoding_stats: Optional[SpecDecodingStats] = None
kv_connector_stats: Optional[dict[str, Any]] = None
num_corrupted_reqs: int = 0 num_corrupted_reqs: int = 0

View File

@ -3,10 +3,14 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import NamedTuple, Optional from typing import TYPE_CHECKING, NamedTuple, Optional
import torch import torch
if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorStats)
class LogprobsLists(NamedTuple): class LogprobsLists(NamedTuple):
@ -77,6 +81,11 @@ class KVConnectorOutput:
# [req_ids] # [req_ids]
finished_sending: Optional[set[str]] = None finished_sending: Optional[set[str]] = None
finished_recving: Optional[set[str]] = None finished_recving: Optional[set[str]] = None
kv_connector_stats: Optional["KVConnectorStats"] = None
def is_empty(self):
return (not self.finished_sending and not self.finished_recving
and not self.kv_connector_stats)
# ModelRunnerOutput is serialized and sent to the scheduler process. # ModelRunnerOutput is serialized and sent to the scheduler process.

View File

@ -13,6 +13,8 @@ from vllm.distributed.kv_transfer import (ensure_kv_transfer_shutdown,
get_kv_transfer_group, get_kv_transfer_group,
has_kv_transfer_group) has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorStats)
from vllm.forward_context import get_forward_context, set_forward_context from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput, from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput,
@ -119,4 +121,11 @@ class KVConnectorModelRunnerMixin:
output.finished_sending, output.finished_recving = ( output.finished_sending, output.finished_recving = (
kv_connector.get_finished(scheduler_output.finished_req_ids)) kv_connector.get_finished(scheduler_output.finished_req_ids))
kv_connector.clear_connector_metadata() output.kv_connector_stats = KVConnectorModelRunnerMixin.\
get_kv_connector_stats()
@staticmethod
def get_kv_connector_stats() -> Optional[KVConnectorStats]:
if has_kv_transfer_group():
return get_kv_transfer_group().get_kv_connector_stats()
return None