[P/D][Nixl] Introduce KVTransferMetrics and aggregation strategy (#22188)

Signed-off-by: NickLucche <nlucches@redhat.com>
Nicolò Lucchesi 2025-09-19 13:09:14 +02:00 committed by GitHub
parent 058525b997
commit a3d087adec
11 changed files with 525 additions and 25 deletions


@ -18,12 +18,18 @@ import torch
from vllm import LLM
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorStats)
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
MultiKVConnectorStats)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
NixlConnectorWorker, NixlKVConnectorStats)
from vllm.forward_context import ForwardContext
from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from .utils import create_request, create_scheduler, create_vllm_config
@ -475,6 +481,209 @@ class TestNixlHandshake:
# NOTE: resource cleanup in the mp backend is a bit finicky, so the order of the
# tests here matters. Run the ray-backed test first so it cleans up its
# resources, then run the rest of the tests.
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
def test_kv_connector_stats(dist_init):
"""Test that KV transfer stats are properly recorded and retrieved."""
vllm_config = create_vllm_config()
# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
connector.engine_id,
hand_shake_latency=0)
# Verify that xfer_stats starts empty
initial_stats = connector.get_kv_connector_stats()
assert initial_stats is None
# Create transfer metadata
request_id = "test_req_for_stats"
metadata = NixlConnectorMetadata()
metadata.add_new_req(request_id=request_id,
local_block_ids=[1, 2, 3],
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id":
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,
})
connector.bind_connector_metadata(metadata)
# Start the transfer
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
)
connector.start_load_kv(dummy_ctx)
# Verify stats are recorded after transfer is complete
max_iterations = 2
# Clear metadata before start_load_kv to prevent reprocessing same request
connector.bind_connector_metadata(NixlConnectorMetadata())
for _ in range(max_iterations):
# Need to call start_load_kv to process completed handshakes
connector.start_load_kv(dummy_ctx)
_, done_recving = connector.get_finished(finished_req_ids=set())
if len(done_recving) > 0 and request_id in done_recving:
break
time.sleep(
0.1) # Small delay to allow background handshake to complete
else:
raise AssertionError("Transfer did not complete within expected iterations")
# Now check that stats were recorded
stats_after_transfer = connector.get_kv_connector_stats()
assert isinstance(stats_after_transfer, NixlKVConnectorStats)
# Verify stats values are recorded
assert not stats_after_transfer.is_empty()
assert stats_after_transfer.data["num_successful_transfers"] == 1
# Verify stats are reset after retrieval
stats_after_reset = connector.get_kv_connector_stats()
assert stats_after_reset is None
def test_kv_connector_stats_aggregation():
"""
Test KV transfer stats aggregation across TP ranks using
KVOutputAggregator (used by MultiprocExecutor).
"""
# Create KVOutputAggregator for 3 workers (simulating TP=3), same thing
# done in MultiprocExecutor.execute_model
aggregator = KVOutputAggregator(world_size=3)
# Create stats for multiple workers with different transfer patterns
worker1_stats = NixlKVConnectorStats()
worker2_stats = NixlKVConnectorStats()
worker3_stats = NixlKVConnectorStats()
# Record different transfers on each worker
# Worker 1: 2 transfers
worker1_stats.record_transfer()
worker1_stats.record_transfer()
# Worker 2: 1 transfer
worker2_stats.record_transfer()
# Worker 3: 3 transfers
worker3_stats.record_transfer()
worker3_stats.record_transfer()
worker3_stats.record_transfer()
# Create ModelRunnerOutput instances for each worker
worker_outputs = []
for i, worker_stats in enumerate(
[worker1_stats, worker2_stats, worker3_stats]):
output = ModelRunnerOutput(
req_ids=[f"req_{i}"],
req_id_to_index={f"req_{i}": 0},
sampled_token_ids=[[123]], # dummy token
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
kv_connector_output=KVConnectorOutput(
finished_sending=set([f"req_{i}_send"])
if i < 2 else None, # Workers 0,1 finished sending
finished_recving=set([f"req_{i}_recv"])
if i > 0 else None, # Workers 1,2 finished receiving
kv_connector_stats=worker_stats,
))
worker_outputs.append(output)
# Use the real aggregation mechanism (like MultiprocExecutor.execute_model)
aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
kv_connector_stats = \
aggregated_output.kv_connector_output.kv_connector_stats
assert isinstance(kv_connector_stats, NixlKVConnectorStats)
# Number of total transfers across all workers.
assert kv_connector_stats.data["num_successful_transfers"] == 6
def test_multi_kv_connector_stats_aggregation():
"""
Test MultiKVConnectorStats aggregation across TP ranks using
KVOutputAggregator (used by MultiprocExecutor).
"""
aggregator = KVOutputAggregator(world_size=3)
from dataclasses import dataclass
@dataclass
class FooKVConnectorStats(KVConnectorStats):
def reset(self):
self.data = {"num_foo_transfers": 0}
def record_transfer(self):
if "num_foo_transfers" not in self.data:
self.data["num_foo_transfers"] = 0
self.data["num_foo_transfers"] += 1
def is_empty(self) -> bool:
return self.data["num_foo_transfers"] == 0
def aggregate(self,
other: "FooKVConnectorStats") -> "FooKVConnectorStats":
if not other.is_empty():
self.data["num_foo_transfers"] += other.data[
"num_foo_transfers"]
return self
def make_multi_stats(nixl_count: int,
foo_count: int) -> MultiKVConnectorStats:
data: dict[str, KVConnectorStats] = {}
if nixl_count > 0:
nixl_stats = NixlKVConnectorStats()
for _ in range(nixl_count):
nixl_stats.record_transfer()
data["NixlConnector"] = nixl_stats
if foo_count > 0:
foo_stats = FooKVConnectorStats()
for _ in range(foo_count):
foo_stats.record_transfer()
data["FooConnector"] = foo_stats
return MultiKVConnectorStats(data=data)
# Create heterogeneous stats across 3 workers
worker_patterns = [(2, 1), (3, 0), (0, 5)] # (Nixl, Foo)
worker_outputs: list[ModelRunnerOutput] = []
for i, (nixl, foo) in enumerate(worker_patterns):
stats = make_multi_stats(nixl, foo)
output = ModelRunnerOutput(
req_ids=[f"req_{i}"],
req_id_to_index={f"req_{i}": 0},
sampled_token_ids=[[123]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
kv_connector_output=KVConnectorOutput(
finished_sending=set([f"req_{i}_send"]) if i < 2 else None,
finished_recving=set([f"req_{i}_recv"]) if i > 0 else None,
kv_connector_stats=stats,
),
)
worker_outputs.append(output)
aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
kv_connector_stats = \
aggregated_output.kv_connector_output.kv_connector_stats
assert isinstance(kv_connector_stats, MultiKVConnectorStats)
# Validate per-connector totals across workers
assert kv_connector_stats["NixlConnector"].data[
"num_successful_transfers"] == 5
assert kv_connector_stats["FooConnector"].data["num_foo_transfers"] == 6
@pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",


@ -129,7 +129,7 @@ class KVOutputAggregator:
def aggregate(self,
outputs: list[ModelRunnerOutput],
output_rank: int = 0) -> ModelRunnerOutput:
# Aggregate kv_connector_output from all workers
def update_finished_set(req_ids: Optional[set[str]],
remaining_count_dict: dict[str, int],
@ -142,8 +142,9 @@ class KVOutputAggregator:
finished_sending = set[str]()
finished_recving = set[str]()
aggregated_kv_connector_stats = None
for model_runner_output in outputs:
output = model_runner_output.kv_connector_output
if not output:
continue
update_finished_set(output.finished_sending,
@ -151,12 +152,26 @@ class KVOutputAggregator:
update_finished_set(output.finished_recving,
self._recv_remaining_count, finished_recving)
# Aggregate kv_connector_stats from all workers.
if aggregated_kv_connector_stats is None:
# Use the first worker's kv_connector_stats as accumulator.
aggregated_kv_connector_stats = output.kv_connector_stats
elif kv_connector_stats := output.kv_connector_stats:
assert isinstance(aggregated_kv_connector_stats,
type(kv_connector_stats))
aggregated_kv_connector_stats = \
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
# select output of the worker specified by output_rank
output = outputs[output_rank]
output.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending or None,
finished_recving=finished_recving or None,
kv_connector_stats=aggregated_kv_connector_stats or None,
)
return output
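The loop above folds per-rank stats into a single object: the first worker's stats become the accumulator, and every later non-empty stats object is merged in via `aggregate()`. A standalone sketch of that fold, with a toy `_ToyStats` class standing in for a real `KVConnectorStats` subclass (illustrative only, not part of vLLM):

from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class _ToyStats:
    # Stand-in for a KVConnectorStats subclass; only the bits the fold needs.
    data: dict[str, Any] = field(default_factory=lambda: {"transfers": 0})

    def is_empty(self) -> bool:
        return self.data["transfers"] == 0

    def aggregate(self, other: "_ToyStats") -> "_ToyStats":
        self.data["transfers"] += other.data["transfers"]
        return self


def fold_worker_stats(
        per_rank: list[Optional[_ToyStats]]) -> Optional[_ToyStats]:
    acc: Optional[_ToyStats] = None
    for stats in per_rank:
        if stats is None or stats.is_empty():
            continue
        # The first non-empty object doubles as the accumulator.
        acc = stats if acc is None else acc.aggregate(stats)
    return acc


# Three TP ranks reporting 2, 0 and 3 transfers -> 5 in total.
ranks = [_ToyStats({"transfers": 2}), None, _ToyStats({"transfers": 3})]
assert fold_worker_stats(ranks).data["transfers"] == 5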


@ -49,6 +49,8 @@ if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorStats)
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
@ -235,6 +237,12 @@ class KVConnectorBase_V1(ABC):
"""
return None
def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
"""
Get the KV connector stats collected during the last interval.
"""
return None
# ==============================
# Scheduler-side methods
# ==============================
@ -365,4 +373,16 @@ class KVConnectorBase_V1(ABC):
int: expected sending or receiving completion count.
"""
return None
@classmethod
def build_kv_connector_stats(
cls,
data: Optional[dict[str,
Any]] = None) -> Optional["KVConnectorStats"]:
"""
KVConnectorStats resolution method. This method allows dynamically
registered connectors to return their own KVConnectorStats object,
which can implement custom aggregation logic on the data dict.
"""
return None
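For third-party connectors, the two hooks above are the whole integration surface. A hedged sketch of what a custom connector might add; `MyConnector`, `MyConnectorStats` and the `bytes_sent` metric are illustrative names only, and just the hook signatures come from the base class:

from dataclasses import dataclass
from typing import Any, Optional, Union

from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats


@dataclass
class MyConnectorStats(KVConnectorStats):
    def reset(self):
        self.data = {"bytes_sent": 0}

    def is_empty(self) -> bool:
        return not self.data.get("bytes_sent")

    def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
        self.data["bytes_sent"] = (self.data.get("bytes_sent", 0) +
                                   other.data.get("bytes_sent", 0))
        return self

    def reduce(self) -> dict[str, Union[int, float]]:
        return {"bytes_sent": self.data.get("bytes_sent", 0)}


class MyConnector:  # in practice this would derive from KVConnectorBase_V1
    def __init__(self):
        self._stats = MyConnectorStats(data={"bytes_sent": 0})

    @classmethod
    def build_kv_connector_stats(
            cls,
            data: Optional[dict[str, Any]] = None) -> Optional[KVConnectorStats]:
        # Rebuild a typed stats object from the plain dict sent to the logger.
        return MyConnectorStats(data=data if data is not None else {})

    def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:
        # Hand back the interval's stats, or None when nothing was recorded.
        return None if self._stats.is_empty() else self._stats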


@ -0,0 +1,100 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import Any, Optional, Union
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_transfer_state import (
has_kv_transfer_group)
from vllm.logger import init_logger
logger = init_logger(__name__)
@dataclass
class KVConnectorStats:
"""
Base class for KV Connector Stats, a container for transfer performance
metrics or otherwise important telemetry from the connector.
All sub-classes need to be serializable as stats are sent from worker to
logger process.
"""
data: dict[str, Any] = field(default_factory=dict)
def reset(self):
"""Reset the stats, clear the state."""
raise NotImplementedError
def aggregate(self, other: "KVConnectorStats") -> "KVConnectorStats":
"""
Aggregate stats with another `KVConnectorStats` object.
"""
raise NotImplementedError
def reduce(self) -> dict[str, Union[int, float]]:
"""
Reduce the observations collected during a time interval to one or
more representative values (e.g. avg/median/sum of the series).
This is meant to be called by the logger to produce a summary of the
stats for the last time interval.
"""
raise NotImplementedError
def is_empty(self) -> bool:
"""Return True if the stats are empty."""
raise NotImplementedError
class KVConnectorLogging:
def __init__(self, kv_tranfer_config: KVTransferConfig):
# This should be called on frontend process.
assert not has_kv_transfer_group()
# Instantiate the connector's stats class.
if kv_tranfer_config and kv_tranfer_config.kv_connector:
self.connector_cls = KVConnectorFactory.get_connector_class(
kv_tranfer_config)
self.reset()
def reset(self):
self.transfer_stats_accumulator: Optional[KVConnectorStats] = None
def observe(self, transfer_stats_data: dict[str, Any]):
# Should not be called when a KVConnector is not configured.
assert self.connector_cls is not None
# Called periodically when connector syncs with the scheduler.
# Note that this is not the same as the logging interval.
# We expect transfer_stats_data to be aggregated across all workers and
# consist of observations from a single connector or a MultiConnector.
transfer_stats = self.connector_cls.build_kv_connector_stats(
transfer_stats_data)
if transfer_stats is None:
logger.warning_once(
"The connector %s is collecting stats but "
"does not implement the "
"`build_kv_connector_stats` method. "
"Stats will not be logged.", self.connector_cls)
return
if self.transfer_stats_accumulator is None:
self.transfer_stats_accumulator = transfer_stats
else:
# Accumulate last interval stats.
self.transfer_stats_accumulator = \
self.transfer_stats_accumulator.aggregate(transfer_stats)
def log(self, log_fn=logger.info):
"""Log transfer metrics periodically, similar to throughput logging"""
if (self.transfer_stats_accumulator
and not self.transfer_stats_accumulator.is_empty()):
# Produce a single cumulative stats object for the last time
# interval from the recorded observations.
xfer_metrics = self.transfer_stats_accumulator.reduce()
xfer_metrics_str = ", ".join(f"{k}={v}"
for k, v in xfer_metrics.items())
log_fn("KV Transfer metrics: %s", xfer_metrics_str)
# Reset metrics for next interval
self.reset()
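The intended lifecycle, sketched standalone with a toy stats class (`_IntervalStats`, illustrative only): each scheduler sync contributes one data dict via `observe()`, intervals are merged with `aggregate()`, and the logging tick calls `reduce()` to emit one summary line before resetting.

from dataclasses import dataclass, field
from typing import Any, Union


@dataclass
class _IntervalStats:
    data: dict[str, Any] = field(default_factory=lambda: {"num_transfers": 0})

    def is_empty(self) -> bool:
        return self.data["num_transfers"] == 0

    def aggregate(self, other: "_IntervalStats") -> "_IntervalStats":
        self.data["num_transfers"] += other.data["num_transfers"]
        return self

    def reduce(self) -> dict[str, Union[int, float]]:
        return {"num_transfers": self.data["num_transfers"]}


accumulator = None
# Two scheduler syncs land between two logging ticks.
for observation in ({"num_transfers": 3}, {"num_transfers": 2}):
    stats = _IntervalStats(data=observation)
    accumulator = stats if accumulator is None else accumulator.aggregate(stats)

if accumulator is not None and not accumulator.is_empty():
    summary = ", ".join(f"{k}={v}" for k, v in accumulator.reduce().items())
    print("KV Transfer metrics:", summary)  # -> num_transfers=5
    accumulator = None  # reset for the next interval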


@ -9,19 +9,21 @@ import torch
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorStats)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_events import KVCacheEvent
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
logger = init_logger(__name__)
@ -33,6 +35,43 @@ class MultiKVConnectorMetadata(KVConnectorMetadata):
extra_async_saves: Optional[dict[str, int]] = None
@dataclass
class MultiKVConnectorStats(KVConnectorStats):
"""
Maintain a dict of KVConnectorStats objects, one for each connector.
This is used to aggregate the stats from all connectors separately.
"""
def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
for connector_id, stats in other.data.items():
if connector_id not in self.data:
self[connector_id] = stats
else:
assert isinstance(stats, type(self.data[connector_id]))
self[connector_id] = self[connector_id].aggregate(stats)
return self
def reset(self):
for stats in self.data.values():
stats.reset()
def reduce(self) -> dict[str, Any]:
# TODO (NickLucche) Adjust for logging on separate lines
return {
connector_id: stats.reduce()
for connector_id, stats in self.data.items()
}
def is_empty(self) -> bool:
return all(stats.is_empty() for stats in self.data.values())
def __getitem__(self, connector_id: str) -> KVConnectorStats:
return self.data[connector_id]
def __setitem__(self, connector_id: str, stats: KVConnectorStats):
self.data[connector_id] = stats
class MultiConnector(KVConnectorBase_V1):
"""
A wrapper for using multiple KVConnectors at the same time.
@ -46,6 +85,7 @@ class MultiConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._connectors: list[KVConnectorBase_V1] = []
self._ktc_kv_transfer_config = []
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"connectors")
assert ktcs is not None
@ -57,6 +97,7 @@ class MultiConnector(KVConnectorBase_V1):
**ktc, engine_id=engine_id)
self._connectors.append(
KVConnectorFactory.create_connector(temp_config, role))
self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config)
# A mapping from request id to the index of the connector chosen to
# load the request from (if any).
@ -227,7 +268,7 @@ class MultiConnector(KVConnectorBase_V1):
return async_saves > 0, kv_txfer_params
def take_events(self) -> Iterable["KVCacheEvent"]:
for c in self._connectors:
yield from c.take_events()
@ -264,3 +305,24 @@ class MultiConnector(KVConnectorBase_V1):
f"({', '.join(layouts) })."
f"All connectors must use the same layout.")
return next(iter(layouts), None)
@classmethod
def build_kv_connector_stats(
cls,
data: Optional[dict[str,
Any]] = None) -> Optional[KVConnectorStats]:
return MultiKVConnectorStats(data=data) if data is not None \
else MultiKVConnectorStats()
def get_kv_connector_stats(self) -> Optional[MultiKVConnectorStats]:
# Group connector stats by connector type.
stats_by_connector: Optional[MultiKVConnectorStats] = None
for c in self._connectors:
stats = c.get_kv_connector_stats()
if stats is None:
continue
if stats_by_connector is None:
# Lazy init to allow optional return value.
stats_by_connector = MultiKVConnectorStats()
stats_by_connector[c.__class__.__name__] = stats
return stats_by_connector
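A small usage sketch of the per-connector grouping, assuming a vLLM install: two `MultiKVConnectorStats` snapshots are folded per connector class name (the helper `nixl_stats` exists only for the example).

from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
    MultiKVConnectorStats)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
    NixlKVConnectorStats)


def nixl_stats(n: int) -> NixlKVConnectorStats:
    stats = NixlKVConnectorStats()
    for _ in range(n):
        stats.record_transfer()
    return stats


first = MultiKVConnectorStats(data={"NixlConnector": nixl_stats(2)})
second = MultiKVConnectorStats(data={"NixlConnector": nixl_stats(3)})

merged = first.aggregate(second)
assert merged["NixlConnector"].data["num_successful_transfers"] == 5
# reduce() nests per-connector summaries under the connector class name.
print(merged.reduce())  # {'NixlConnector': {'num_successful_transfers': 5}}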


@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import copy
import logging
import math
import queue
@ -11,7 +12,7 @@ from collections import defaultdict
from collections.abc import Iterator
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Union
import msgspec
import numpy as np
@ -23,6 +24,8 @@ from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
CopyBlocksOp, KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorStats)
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
get_tp_group)
@ -33,7 +36,6 @@ from vllm.platforms import _Backend, current_platform
from vllm.utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@ -206,6 +208,18 @@ class NixlConnector(KVConnectorBase_V1):
assert self.connector_worker is not None
return self.connector_worker.get_finished()
def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:
assert self.connector_worker is not None
return self.connector_worker.get_kv_connector_stats()
@classmethod
def build_kv_connector_stats(
cls,
data: Optional[dict[str,
Any]] = None) -> Optional[KVConnectorStats]:
return NixlKVConnectorStats(data=data) if data is not None \
else NixlKVConnectorStats()
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
assert self.connector_worker is not None
@ -377,6 +391,7 @@ class NixlConnectorScheduler:
Once a request is finished, determine whether request blocks
should be freed now or will be sent asynchronously and freed later.
"""
from vllm.v1.request import RequestStatus
params = request.kv_transfer_params
logger.debug(
@ -550,6 +565,7 @@ class NixlConnectorWorker:
# With heterogeneous TP, P must wait for all assigned D TP workers to
# finish reading before safely freeing the blocks.
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
self.xfer_stats = NixlKVConnectorStats()
def __del__(self):
"""Cleanup background threads on destruction."""
@ -1097,6 +1113,8 @@ class NixlConnectorWorker:
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
if xfer_state == "DONE":
self.nixl_wrapper.release_xfer_handle(handle)
# TODO (NickLucche) Get from NIXL telemetry once integrated
self.xfer_stats.record_transfer()
elif xfer_state == "PROC":
in_progress = True
continue
@ -1248,7 +1266,6 @@ class NixlConnectorWorker:
self.nixl_wrapper.transfer(handle)
# Use handle to check completion in future step().
# TODO (NickLucche) surface xfer elapsed time
self._recving_transfers[request_id].append(
(handle, time.perf_counter()))
@ -1300,6 +1317,15 @@ class NixlConnectorWorker:
block_len = self.block_len
return block_len
def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:
"""
Get the KV transfer stats for the connector.
"""
# Clear stats for next iteration
if not self.xfer_stats.is_empty():
return self.xfer_stats.clone_and_reset()
return None
@contextlib.contextmanager
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
@ -1318,3 +1344,39 @@ def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
finally:
if ctx is not None:
ctx.destroy(linger=0)
@dataclass
class NixlKVConnectorStats(KVConnectorStats):
"""Container for transfer performance metrics"""
def __post_init__(self):
if "num_successful_transfers" not in self.data:
self.data["num_successful_transfers"] = 0
def reset(self):
self.data = {"num_successful_transfers": 0}
def record_transfer(self):
# TODO: record actual transfer stats when available
self.data["num_successful_transfers"] += 1
def clone_and_reset(self) -> "NixlKVConnectorStats":
old = copy.copy(self)
self.reset()
return old
def is_empty(self) -> bool:
return self.data["num_successful_transfers"] == 0
def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
if not other.is_empty():
self.data["num_successful_transfers"] += other.data[
"num_successful_transfers"]
return self
def reduce(self) -> dict[str, Union[int, float]]:
# TODO: reduce stats to a single value, calculate latency/throughput
return {
"num_successful_transfers": self.data["num_successful_transfers"]
}
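Usage sketch of the container above, assuming a vLLM install: the worker bumps the counter once per completed transfer, and `clone_and_reset()` hands off a snapshot so the next interval starts from zero.

from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
    NixlKVConnectorStats)

stats = NixlKVConnectorStats()
stats.record_transfer()
stats.record_transfer()

snapshot = stats.clone_and_reset()
assert snapshot.reduce() == {"num_successful_transfers": 2}
assert stats.is_empty()  # live counter was reset for the next interval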


@ -15,6 +15,8 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorStats)
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
@ -869,9 +871,12 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
pooler_outputs = model_runner_output.pooler_output
num_nans_in_logits = model_runner_output.num_nans_in_logits
kv_connector_output = model_runner_output.kv_connector_output
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: Optional[SpecDecodingStats] = None
kv_connector_stats = (kv_connector_output.kv_connector_stats
if kv_connector_output else None)
# NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
# the below loop can be a performance bottleneck. We should do our best
@ -1007,7 +1012,8 @@ class Scheduler(SchedulerInterface):
finished_requests=finished_set)
finished_req_ids.clear()
if (stats := self.make_stats(spec_decoding_stats,
kv_connector_stats)) is not None:
# Return stats to only one of the front-ends.
if (eco := next(iter(engine_core_outputs.values()), None)) is None:
# We must return the stats even if there are no request
@ -1172,20 +1178,21 @@ class Scheduler(SchedulerInterface):
def make_stats(
self,
spec_decoding_stats: Optional[SpecDecodingStats] = None,
kv_connector_stats: Optional[KVConnectorStats] = None,
) -> Optional[SchedulerStats]:
if not self.log_stats:
return None
prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats()
assert prefix_cache_stats is not None
return SchedulerStats(num_running_reqs=len(self.running),
num_waiting_reqs=len(self.waiting),
kv_cache_usage=self.kv_cache_manager.usage,
prefix_cache_stats=prefix_cache_stats,
spec_decoding_stats=spec_decoding_stats,
num_corrupted_reqs=sum(req.is_output_corrupted
for req in self.running),
kv_connector_stats=kv_connector_stats.data
if kv_connector_stats else None)
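The handoff sketched end to end, assuming a vLLM install: the scheduler forwards only the plain `.data` dict on `SchedulerStats`, and the frontend logger rebuilds a typed stats object from it via the connector's `build_kv_connector_stats()`.

from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
    NixlConnector, NixlKVConnectorStats)

worker_stats = NixlKVConnectorStats()
worker_stats.record_transfer()

# What make_stats() attaches to SchedulerStats.kv_connector_stats:
payload = worker_stats.data

# What KVConnectorLogging.observe() does on the frontend with that payload:
rebuilt = NixlConnector.build_kv_connector_stats(payload)
assert isinstance(rebuilt, NixlKVConnectorStats)
assert rebuilt.reduce() == {"num_successful_transfers": 1}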
def make_spec_decoding_stats(
self,


@ -9,6 +9,8 @@ from typing import Callable, Optional, Union
import prometheus_client
from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorLogging)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
from vllm.v1.engine import FinishReason
@ -59,6 +61,8 @@ class LoggingStatLogger(StatLoggerBase):
# TODO: Make the interval configurable.
self.prefix_caching_metrics = PrefixCachingMetrics()
self.spec_decoding_logging = SpecDecodingLogging()
kv_tranfer_config = self.vllm_config.kv_transfer_config
self.kv_transfer_logging = KVConnectorLogging(kv_tranfer_config)
self.last_prompt_throughput: float = 0.0
self.last_generation_throughput: float = 0.0
@ -97,7 +101,8 @@ class LoggingStatLogger(StatLoggerBase):
if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_logging.observe(
scheduler_stats.spec_decoding_stats)
if kv_connector_stats := scheduler_stats.kv_connector_stats:
self.kv_transfer_logging.observe(kv_connector_stats)
self.last_scheduler_stats = scheduler_stats
def log(self):
@ -136,6 +141,7 @@ class LoggingStatLogger(StatLoggerBase):
self.prefix_caching_metrics.hit_rate * 100,
)
self.spec_decoding_logging.log(log_fn=log_fn)
self.kv_transfer_logging.log(log_fn=log_fn)
def log_engine_initialized(self):
if self.vllm_config.cache_config.num_gpu_blocks:


@ -3,7 +3,7 @@
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional
from vllm.v1.spec_decode.metrics import SpecDecodingStats
@ -43,6 +43,7 @@ class SchedulerStats:
default_factory=PrefixCacheStats)
spec_decoding_stats: Optional[SpecDecodingStats] = None
kv_connector_stats: Optional[dict[str, Any]] = None
num_corrupted_reqs: int = 0


@ -3,10 +3,14 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, NamedTuple, Optional
import torch
if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorStats)
class LogprobsLists(NamedTuple):
@ -77,6 +81,11 @@ class KVConnectorOutput:
# [req_ids]
finished_sending: Optional[set[str]] = None
finished_recving: Optional[set[str]] = None
kv_connector_stats: Optional["KVConnectorStats"] = None
def is_empty(self):
return (not self.finished_sending and not self.finished_recving
and not self.kv_connector_stats)
# ModelRunnerOutput is serialized and sent to the scheduler process.


@ -13,6 +13,8 @@ from vllm.distributed.kv_transfer import (ensure_kv_transfer_shutdown,
get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorStats)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput,
@ -119,4 +121,11 @@ class KVConnectorModelRunnerMixin:
output.finished_sending, output.finished_recving = (
kv_connector.get_finished(scheduler_output.finished_req_ids))
kv_connector.clear_connector_metadata()
output.kv_connector_stats = KVConnectorModelRunnerMixin.\
get_kv_connector_stats()
@staticmethod
def get_kv_connector_stats() -> Optional[KVConnectorStats]:
if has_kv_transfer_group():
return get_kv_transfer_group().get_kv_connector_stats()
return None