[P/D][Nixl] Introduce KVTransferMetrics and aggregation strategy (#22188)
Signed-off-by: NickLucche <nlucches@redhat.com>
Parent: 058525b997
Commit: a3d087adec
@@ -18,12 +18,18 @@ import torch

 from vllm import LLM
 from vllm.config import KVTransferConfig
 from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
+from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
+    KVConnectorStats)
+from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
+    MultiKVConnectorStats)
 from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
     KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
-    NixlConnectorWorker)
+    NixlConnectorWorker, NixlKVConnectorStats)
 from vllm.forward_context import ForwardContext
 from vllm.sampling_params import SamplingParams
 from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
 from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput

 from .utils import create_request, create_scheduler, create_vllm_config
@@ -475,6 +481,209 @@ class TestNixlHandshake:

 # NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
 # we run the tests here is important. First run ray, it will clean up the
 # resources, then the rest of the tests.
+@patch(
+    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+    FakeNixlWrapper)
+def test_kv_connector_stats(dist_init):
+    """Test that KV transfer stats are properly recorded and retrieved."""
+    vllm_config = create_vllm_config()
+
+    # Test worker role in decode server.
+    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+    connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
+                                                         connector.engine_id,
+                                                         hand_shake_latency=0)
+
+    # Verify that xfer_stats starts empty.
+    initial_stats = connector.get_kv_connector_stats()
+    assert initial_stats is None
+
+    # Create transfer metadata.
+    request_id = "test_req_for_stats"
+    metadata = NixlConnectorMetadata()
+    metadata.add_new_req(request_id=request_id,
+                         local_block_ids=[1, 2, 3],
+                         kv_transfer_params={
+                             "remote_block_ids": [4, 5, 6],
+                             "remote_engine_id":
+                             FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
+                             "remote_host": "localhost",
+                             "remote_port": 1234,
+                             "remote_tp_size": 1,
+                         })
+    connector.bind_connector_metadata(metadata)
+
+    # Start the transfer.
+    dummy_ctx = ForwardContext(
+        no_compile_layers={},
+        attn_metadata={},
+        virtual_engine=0,
+    )
+    connector.start_load_kv(dummy_ctx)
+
+    # Verify stats are recorded once the transfer completes.
+    max_iterations = 2
+    # Clear metadata before start_load_kv to prevent reprocessing the same
+    # request.
+    connector.bind_connector_metadata(NixlConnectorMetadata())
+    for _ in range(max_iterations):
+        # Need to call start_load_kv to process completed handshakes.
+        connector.start_load_kv(dummy_ctx)
+        _, done_recving = connector.get_finished(finished_req_ids=set())
+        if len(done_recving) > 0 and request_id in done_recving:
+            break
+        # Small delay to allow the background handshake to complete.
+        time.sleep(0.1)
+    else:
+        raise AssertionError(
+            "Transfer did not complete within expected iterations")
+
+    # Now check that stats were recorded.
+    stats_after_transfer = connector.get_kv_connector_stats()
+    assert isinstance(stats_after_transfer, NixlKVConnectorStats)
+
+    # Verify stats values are recorded.
+    assert not stats_after_transfer.is_empty()
+    assert stats_after_transfer.data["num_successful_transfers"] == 1
+
+    # Verify stats are reset after retrieval.
+    stats_after_reset = connector.get_kv_connector_stats()
+    assert stats_after_reset is None
+
+
+def test_kv_connector_stats_aggregation():
+    """
+    Test KV transfer stats aggregation across TP ranks using
+    KVOutputAggregator (used by MultiprocExecutor).
+    """
+
+    # Create a KVOutputAggregator for 3 workers (simulating TP=3), the same
+    # thing done in MultiprocExecutor.execute_model.
+    aggregator = KVOutputAggregator(world_size=3)
+
+    # Create stats for multiple workers with different transfer patterns.
+    worker1_stats = NixlKVConnectorStats()
+    worker2_stats = NixlKVConnectorStats()
+    worker3_stats = NixlKVConnectorStats()
+
+    # Record different transfers on each worker.
+    # Worker 1: 2 transfers
+    worker1_stats.record_transfer()
+    worker1_stats.record_transfer()
+
+    # Worker 2: 1 transfer
+    worker2_stats.record_transfer()
+
+    # Worker 3: 3 transfers
+    worker3_stats.record_transfer()
+    worker3_stats.record_transfer()
+    worker3_stats.record_transfer()
+
+    # Create ModelRunnerOutput instances for each worker.
+    worker_outputs = []
+    for i, worker_stats in enumerate(
+            [worker1_stats, worker2_stats, worker3_stats]):
+        output = ModelRunnerOutput(
+            req_ids=[f"req_{i}"],
+            req_id_to_index={f"req_{i}": 0},
+            sampled_token_ids=[[123]],  # dummy token
+            logprobs=None,
+            prompt_logprobs_dict={},
+            pooler_output=[None],
+            kv_connector_output=KVConnectorOutput(
+                finished_sending=set([f"req_{i}_send"])
+                if i < 2 else None,  # Workers 0,1 finished sending
+                finished_recving=set([f"req_{i}_recv"])
+                if i > 0 else None,  # Workers 1,2 finished receiving
+                kv_connector_stats=worker_stats,
+            ))
+        worker_outputs.append(output)
+
+    # Use the real aggregation mechanism, as MultiprocExecutor.execute_model
+    # does.
+    aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
+    kv_connector_stats = \
+        aggregated_output.kv_connector_output.kv_connector_stats
+    assert isinstance(kv_connector_stats, NixlKVConnectorStats)
+    # Number of total transfers across all workers.
+    assert kv_connector_stats.data["num_successful_transfers"] == 6
+
+
+def test_multi_kv_connector_stats_aggregation():
+    """
+    Test MultiKVConnectorStats aggregation across TP ranks using
+    KVOutputAggregator (used by MultiprocExecutor).
+    """
+
+    aggregator = KVOutputAggregator(world_size=3)
+
+    from dataclasses import dataclass
+
+    @dataclass
+    class FooKVConnectorStats(KVConnectorStats):
+
+        def reset(self):
+            self.data = {"num_foo_transfers": 0}
+
+        def record_transfer(self):
+            if "num_foo_transfers" not in self.data:
+                self.data["num_foo_transfers"] = 0
+            self.data["num_foo_transfers"] += 1
+
+        def is_empty(self) -> bool:
+            return self.data["num_foo_transfers"] == 0
+
+        def aggregate(self,
+                      other: "FooKVConnectorStats") -> "FooKVConnectorStats":
+            if not other.is_empty():
+                self.data["num_foo_transfers"] += other.data[
+                    "num_foo_transfers"]
+            return self
+
+    def make_multi_stats(nixl_count: int,
+                         foo_count: int) -> MultiKVConnectorStats:
+        data: dict[str, KVConnectorStats] = {}
+        if nixl_count > 0:
+            nixl_stats = NixlKVConnectorStats()
+            for _ in range(nixl_count):
+                nixl_stats.record_transfer()
+            data["NixlConnector"] = nixl_stats
+        if foo_count > 0:
+            foo_stats = FooKVConnectorStats()
+            for _ in range(foo_count):
+                foo_stats.record_transfer()
+            data["FooConnector"] = foo_stats
+        return MultiKVConnectorStats(data=data)
+
+    # Create heterogeneous stats across 3 workers.
+    worker_patterns = [(2, 1), (3, 0), (0, 5)]  # (Nixl, Foo)
+
+    worker_outputs: list[ModelRunnerOutput] = []
+    for i, (nixl, foo) in enumerate(worker_patterns):
+        stats = make_multi_stats(nixl, foo)
+        output = ModelRunnerOutput(
+            req_ids=[f"req_{i}"],
+            req_id_to_index={f"req_{i}": 0},
+            sampled_token_ids=[[123]],
+            logprobs=None,
+            prompt_logprobs_dict={},
+            pooler_output=[None],
+            kv_connector_output=KVConnectorOutput(
+                finished_sending=set([f"req_{i}_send"]) if i < 2 else None,
+                finished_recving=set([f"req_{i}_recv"]) if i > 0 else None,
+                kv_connector_stats=stats,
+            ),
+        )
+        worker_outputs.append(output)
+
+    aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
+    kv_connector_stats = \
+        aggregated_output.kv_connector_output.kv_connector_stats
+    assert isinstance(kv_connector_stats, MultiKVConnectorStats)
+
+    # Validate per-connector totals across workers.
+    assert kv_connector_stats["NixlConnector"].data[
+        "num_successful_transfers"] == 5
+    assert kv_connector_stats["FooConnector"].data["num_foo_transfers"] == 6
+
+
 @pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
 @patch(
     "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
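The three tests above cover the full stats path: recording on a worker, homogeneous aggregation across TP ranks, and heterogeneous aggregation through MultiKVConnectorStats. Assuming the upstream test layout, they can be run in isolation with something like:

    pytest tests/v1/kv_connector/unit/test_nixl_connector.py -k "connector_stats" -v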
@@ -129,7 +129,7 @@ class KVOutputAggregator:
     def aggregate(self,
                   outputs: list[ModelRunnerOutput],
                   output_rank: int = 0) -> ModelRunnerOutput:
-        # aggregate kv_connector_output from all workers
+        # Aggregate kv_connector_output from all workers

         def update_finished_set(req_ids: Optional[set[str]],
                                 remaining_count_dict: dict[str, int],
@@ -142,8 +142,9 @@ class KVOutputAggregator:

         finished_sending = set[str]()
         finished_recving = set[str]()
-        for output in outputs:
-            output = output.kv_connector_output
+        aggregated_kv_connector_stats = None
+        for model_runner_output in outputs:
+            output = model_runner_output.kv_connector_output
             if not output:
                 continue
             update_finished_set(output.finished_sending,
@@ -151,12 +152,26 @@ class KVOutputAggregator:
             update_finished_set(output.finished_recving,
                                 self._recv_remaining_count, finished_recving)

+            # Aggregate kv_connector_stats from all workers.
+            if aggregated_kv_connector_stats is None:
+                # Use the first worker's kv_connector_stats as accumulator.
+                aggregated_kv_connector_stats = output.kv_connector_stats
+            elif kv_connector_stats := output.kv_connector_stats:
+                assert isinstance(aggregated_kv_connector_stats,
+                                  type(kv_connector_stats))
+                aggregated_kv_connector_stats = \
+                    aggregated_kv_connector_stats.aggregate(kv_connector_stats)
+
         # select output of the worker specified by output_rank
         output = outputs[output_rank]

         output.kv_connector_output = KVConnectorOutput(
             finished_sending=finished_sending or None,
             finished_recving=finished_recving or None,
+            kv_connector_stats=aggregated_kv_connector_stats or None,
         )

         return output
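For context, `update_finished_set` implements a per-request countdown: a request id is only reported as finished once all `world_size` workers have reported it. A minimal standalone sketch of that counting logic (the names below are illustrative, not part of this diff):

from collections import defaultdict

world_size = 3
remaining = defaultdict(lambda: world_size)  # hypothetical per-request counters
finished: set[str] = set()

def report_finished(req_id: str) -> None:
    # One worker reported req_id; surface it only on the final report.
    remaining[req_id] -= 1
    if remaining[req_id] == 0:
        finished.add(req_id)
        del remaining[req_id]

for _ in range(world_size):
    report_finished("req_0")
assert finished == {"req_0"}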
@@ -49,6 +49,8 @@ if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
     from vllm.config import VllmConfig
     from vllm.distributed.kv_events import KVCacheEvent
+    from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
+        KVConnectorStats)
     from vllm.forward_context import ForwardContext
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
     from vllm.v1.request import Request
@@ -235,6 +237,12 @@ class KVConnectorBase_V1(ABC):
         """
         return None

+    def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
+        """
+        Get the KV connector stats collected during the last interval.
+        """
+        return None
+
     # ==============================
     # Scheduler-side methods
     # ==============================
@@ -365,4 +373,16 @@ class KVConnectorBase_V1(ABC):
             int: expected sending or receiving completion count.
         """
         return None
+
+    @classmethod
+    def build_kv_connector_stats(
+            cls,
+            data: Optional[dict[str,
+                                Any]] = None) -> Optional["KVConnectorStats"]:
+        """
+        KVConnectorStats resolution method. This method allows dynamically
+        registered connectors to return their own KVConnectorStats object,
+        which can implement custom aggregation logic on the data dict.
+        """
+        return None
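Together, `get_kv_connector_stats` (worker side) and `build_kv_connector_stats` (frontend side) are the two hooks a connector overrides to participate in metrics. A minimal sketch of a custom subclass, built around a hypothetical `MyStats` container that is not part of this diff:

from dataclasses import dataclass

from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
    KVConnectorStats)


@dataclass
class MyStats(KVConnectorStats):
    """Hypothetical stats container tracking a single counter."""

    def __post_init__(self):
        self.data.setdefault("num_transfers", 0)

    def reset(self):
        self.data = {"num_transfers": 0}

    def is_empty(self) -> bool:
        return self.data["num_transfers"] == 0

    def aggregate(self, other: "MyStats") -> "MyStats":
        self.data["num_transfers"] += other.data["num_transfers"]
        return self

    def reduce(self):
        # Single counter, so the interval summary is the data itself.
        return dict(self.data)

# A connector would return MyStats snapshots from get_kv_connector_stats()
# on the worker, and rebuild them on the frontend via:
#     @classmethod
#     def build_kv_connector_stats(cls, data=None):
#         return MyStats(data=data) if data is not None else MyStats()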
vllm/distributed/kv_transfer/kv_connector/v1/metrics.py (new file)
@@ -0,0 +1,100 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import dataclass, field
+from typing import Any, Optional, Union
+
+from vllm.config.kv_transfer import KVTransferConfig
+from vllm.distributed.kv_transfer.kv_connector.factory import (
+    KVConnectorFactory)
+from vllm.distributed.kv_transfer.kv_transfer_state import (
+    has_kv_transfer_group)
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class KVConnectorStats:
+    """
+    Base class for KV Connector Stats, a container for transfer performance
+    metrics or otherwise important telemetry from the connector.
+    All sub-classes need to be serializable, as stats are sent from the
+    worker to the logger process.
+    """
+    data: dict[str, Any] = field(default_factory=dict)
+
+    def reset(self):
+        """Reset the stats, clear the state."""
+        raise NotImplementedError
+
+    def aggregate(self, other: "KVConnectorStats") -> "KVConnectorStats":
+        """
+        Aggregate stats with another `KVConnectorStats` object.
+        """
+        raise NotImplementedError
+
+    def reduce(self) -> dict[str, Union[int, float]]:
+        """
+        Reduce the observations collected during a time interval to one or
+        more representative values (eg avg/median/sum of the series).
+        This is meant to be called by the logger to produce a summary of the
+        stats for the last time interval.
+        """
+        raise NotImplementedError
+
+    def is_empty(self) -> bool:
+        """Return True if the stats are empty."""
+        raise NotImplementedError
+
+
+class KVConnectorLogging:
+
+    def __init__(self, kv_tranfer_config: KVTransferConfig):
+        # This should be called on the frontend process.
+        assert not has_kv_transfer_group()
+        # Instantiate the connector's stats class.
+        if kv_tranfer_config and kv_tranfer_config.kv_connector:
+            self.connector_cls = KVConnectorFactory.get_connector_class(
+                kv_tranfer_config)
+        self.reset()
+
+    def reset(self):
+        self.transfer_stats_accumulator: Optional[KVConnectorStats] = None
+
+    def observe(self, transfer_stats_data: dict[str, Any]):
+        # Should not be called when a KVConnector is not configured.
+        assert self.connector_cls is not None
+        # Called periodically when the connector syncs with the scheduler.
+        # Note that this is not the same as the logging interval.
+        # We expect transfer_stats_data to be aggregated across all workers
+        # and to consist of observations from a single connector or a
+        # MultiConnector.
+        transfer_stats = self.connector_cls.build_kv_connector_stats(
+            transfer_stats_data)
+        if transfer_stats is None:
+            logger.warning_once(
+                "The connector %s is collecting stats but "
+                "does not implement the "
+                "`build_kv_connector_stats` method. "
+                "Stats will not be logged.", self.connector_cls)
+            return
+
+        if self.transfer_stats_accumulator is None:
+            self.transfer_stats_accumulator = transfer_stats
+        else:
+            # Accumulate last interval stats.
+            self.transfer_stats_accumulator = \
+                self.transfer_stats_accumulator.aggregate(transfer_stats)
+
+    def log(self, log_fn=logger.info):
+        """Log transfer metrics periodically, similar to throughput logging."""
+        if (self.transfer_stats_accumulator
+                and not self.transfer_stats_accumulator.is_empty()):
+            # Produce a single cumulative stats object for the last time
+            # interval from the recorded observations.
+            xfer_metrics = self.transfer_stats_accumulator.reduce()
+            xfer_metrics_str = ", ".join(f"{k}={v}"
+                                         for k, v in xfer_metrics.items())
+            log_fn("KV Transfer metrics: %s", xfer_metrics_str)

+            # Reset metrics for the next interval.
+            self.reset()
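The intended call pattern is observe-per-scheduler-sync and log-per-interval. A small sketch of how a stat logger might drive this class, with made-up payloads and assuming `vllm_config` is an engine config whose connector implements `build_kv_connector_stats` (e.g. NixlConnector):

# Hypothetical driver mirroring what LoggingStatLogger does.
kv_logging = KVConnectorLogging(vllm_config.kv_transfer_config)

# Each observation is the aggregated `data` dict shipped via SchedulerStats.
kv_logging.observe({"num_successful_transfers": 2})
kv_logging.observe({"num_successful_transfers": 3})

# Emits "KV Transfer metrics: num_successful_transfers=5" and resets
# the accumulator for the next logging interval.
kv_logging.log()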
@@ -9,19 +9,21 @@ import torch

 from vllm.config import VllmConfig
 from vllm.config.kv_transfer import KVTransferConfig
-from vllm.distributed.kv_events import KVCacheEvent
 from vllm.distributed.kv_transfer.kv_connector.factory import (
     KVConnectorFactory)
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
+from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
+    KVConnectorStats)
 from vllm.logger import init_logger
-from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.outputs import KVConnectorOutput

 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
+    from vllm.distributed.kv_events import KVCacheEvent
     from vllm.forward_context import ForwardContext
+    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
     from vllm.v1.request import Request

 logger = init_logger(__name__)
@@ -33,6 +35,43 @@ class MultiKVConnectorMetadata(KVConnectorMetadata):
     extra_async_saves: Optional[dict[str, int]] = None


+@dataclass
+class MultiKVConnectorStats(KVConnectorStats):
+    """
+    Maintain a dict of KVConnectorStats objects, one for each connector.
+    This is used to aggregate the stats from all connectors separately.
+    """
+
+    def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
+        for connector_id, stats in other.data.items():
+            if connector_id not in self.data:
+                self[connector_id] = stats
+            else:
+                assert isinstance(stats, type(self.data[connector_id]))
+                self[connector_id] = self[connector_id].aggregate(stats)
+        return self
+
+    def reset(self):
+        for stats in self.data.values():
+            stats.reset()
+
+    def reduce(self) -> dict[str, Any]:
+        # TODO (NickLucche) Adjust for logging on separate lines
+        return {
+            connector_id: stats.reduce()
+            for connector_id, stats in self.data.items()
+        }
+
+    def is_empty(self) -> bool:
+        return all(stats.is_empty() for stats in self.data.values())
+
+    def __getitem__(self, connector_id: str) -> KVConnectorStats:
+        return self.data[connector_id]
+
+    def __setitem__(self, connector_id: str, stats: KVConnectorStats):
+        self.data[connector_id] = stats
+
+
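Aggregation is keyed by connector class name, so ranks that saw different connector mixes still merge cleanly: missing keys are adopted wholesale, matching keys are aggregated pairwise. A quick sketch using `NixlKVConnectorStats` from this PR:

# Two per-rank MultiKVConnectorStats objects carrying Nixl traffic only.
a = NixlKVConnectorStats()
a.record_transfer()
rank_a = MultiKVConnectorStats(data={"NixlConnector": a})

b = NixlKVConnectorStats()
b.record_transfer()
b.record_transfer()
rank_b = MultiKVConnectorStats(data={"NixlConnector": b})

# Matching keys aggregate pairwise: 1 + 2 transfers.
merged = rank_a.aggregate(rank_b)
assert merged["NixlConnector"].data["num_successful_transfers"] == 3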
 class MultiConnector(KVConnectorBase_V1):
     """
     A wrapper for using multiple KVConnectors at the same time.
@@ -46,6 +85,7 @@ class MultiConnector(KVConnectorBase_V1):
     def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         super().__init__(vllm_config=vllm_config, role=role)
         self._connectors: list[KVConnectorBase_V1] = []
+        self._ktc_kv_transfer_config = []
         ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
             "connectors")
         assert ktcs is not None
@@ -57,6 +97,7 @@ class MultiConnector(KVConnectorBase_V1):
                 **ktc, engine_id=engine_id)
             self._connectors.append(
                 KVConnectorFactory.create_connector(temp_config, role))
+            self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config)

         # A mapping from request id to the index of the connector chosen to
         # load the request from (if any).
@@ -227,7 +268,7 @@ class MultiConnector(KVConnectorBase_V1):

         return async_saves > 0, kv_txfer_params

-    def take_events(self) -> Iterable[KVCacheEvent]:
+    def take_events(self) -> Iterable["KVCacheEvent"]:
         for c in self._connectors:
             yield from c.take_events()

@@ -264,3 +305,24 @@ class MultiConnector(KVConnectorBase_V1):
                 f"({', '.join(layouts) })."
                 f"All connectors must use the same layout.")
         return next(iter(layouts), None)
+
+    @classmethod
+    def build_kv_connector_stats(
+            cls,
+            data: Optional[dict[str,
+                                Any]] = None) -> Optional[KVConnectorStats]:
+        return MultiKVConnectorStats(data=data) if data is not None \
+            else MultiKVConnectorStats()
+
+    def get_kv_connector_stats(self) -> Optional[MultiKVConnectorStats]:
+        # Group connector stats by connector type.
+        stats_by_connector: Optional[MultiKVConnectorStats] = None
+        for c in self._connectors:
+            stats = c.get_kv_connector_stats()
+            if stats is None:
+                continue
+            if stats_by_connector is None:
+                # Lazy init to allow an optional return value.
+                stats_by_connector = MultiKVConnectorStats()
+            stats_by_connector[c.__class__.__name__] = stats
+        return stats_by_connector
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import contextlib
+import copy
 import logging
 import math
 import queue
@@ -11,7 +12,7 @@ from collections import defaultdict
 from collections.abc import Iterator
 from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, Union

 import msgspec
 import numpy as np
@@ -23,6 +24,8 @@ from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     CopyBlocksOp, KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
+from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
+    KVConnectorStats)
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
     get_tp_group)
@@ -33,7 +36,6 @@ from vllm.platforms import _Backend, current_platform
 from vllm.utils import make_zmq_path, make_zmq_socket
 from vllm.v1.attention.backends.utils import get_kv_cache_layout
 from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.request import RequestStatus

 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
@@ -206,6 +208,18 @@ class NixlConnector(KVConnectorBase_V1):
         assert self.connector_worker is not None
         return self.connector_worker.get_finished()

+    def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:
+        assert self.connector_worker is not None
+        return self.connector_worker.get_kv_connector_stats()
+
+    @classmethod
+    def build_kv_connector_stats(
+            cls,
+            data: Optional[dict[str,
+                                Any]] = None) -> Optional[KVConnectorStats]:
+        return NixlKVConnectorStats(data=data) if data is not None \
+            else NixlKVConnectorStats()
+
     def start_load_kv(self, forward_context: "ForwardContext",
                       **kwargs) -> None:
         assert self.connector_worker is not None
@@ -377,6 +391,7 @@ class NixlConnectorScheduler:
         Once a request is finished, determine whether request blocks
         should be freed now or will be sent asynchronously and freed later.
         """
+        from vllm.v1.request import RequestStatus

         params = request.kv_transfer_params
         logger.debug(
@@ -550,6 +565,7 @@ class NixlConnectorWorker:
         # With heterogeneous TP, P must wait for all assigned D TP workers to
         # finish reading before safely freeing the blocks.
         self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
+        self.xfer_stats = NixlKVConnectorStats()

     def __del__(self):
         """Cleanup background threads on destruction."""
@@ -1097,6 +1113,8 @@ class NixlConnectorWorker:
             xfer_state = self.nixl_wrapper.check_xfer_state(handle)
             if xfer_state == "DONE":
                 self.nixl_wrapper.release_xfer_handle(handle)
+                # TODO (NickLucche) Get from NIXL telemetry once integrated
+                self.xfer_stats.record_transfer()
             elif xfer_state == "PROC":
                 in_progress = True
                 continue
@@ -1248,7 +1266,6 @@ class NixlConnectorWorker:
             self.nixl_wrapper.transfer(handle)

             # Use handle to check completion in future step().
-            # TODO (NickLucche) surface xfer elapsed time
             self._recving_transfers[request_id].append(
                 (handle, time.perf_counter()))

@@ -1300,6 +1317,15 @@ class NixlConnectorWorker:
             block_len = self.block_len
         return block_len

+    def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:
+        """
+        Get the KV transfer stats for the connector.
+        """
+        # Clear stats for the next iteration.
+        if not self.xfer_stats.is_empty():
+            return self.xfer_stats.clone_and_reset()
+        return None
+

 @contextlib.contextmanager
 def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
@@ -1318,3 +1344,39 @@ def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
     finally:
         if ctx is not None:
             ctx.destroy(linger=0)
+
+
+@dataclass
+class NixlKVConnectorStats(KVConnectorStats):
+    """Container for transfer performance metrics."""
+
+    def __post_init__(self):
+        if "num_successful_transfers" not in self.data:
+            self.data["num_successful_transfers"] = 0
+
+    def reset(self):
+        self.data = {"num_successful_transfers": 0}
+
+    def record_transfer(self):
+        # TODO: record actual transfer stats when available
+        self.data["num_successful_transfers"] += 1
+
+    def clone_and_reset(self) -> "NixlKVConnectorStats":
+        old = copy.copy(self)
+        self.reset()
+        return old
+
+    def is_empty(self) -> bool:
+        return self.data["num_successful_transfers"] == 0
+
+    def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
+        if not other.is_empty():
+            self.data["num_successful_transfers"] += other.data[
+                "num_successful_transfers"]
+        return self
+
+    def reduce(self) -> dict[str, Union[int, float]]:
+        # TODO: reduce stats to a single value, calculate latency/throughput
+        return {
+            "num_successful_transfers": self.data["num_successful_transfers"]
+        }
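`clone_and_reset` is what makes `get_kv_connector_stats` interval-based: the caller gets a snapshot while the live counter starts over. Note that the shallow `copy.copy` is safe here only because `reset` rebinds `self.data` to a fresh dict rather than mutating the old one. A short usage sketch:

stats = NixlKVConnectorStats()
stats.record_transfer()
stats.record_transfer()

snapshot = stats.clone_and_reset()
# The snapshot keeps the old data dict; the live object starts a new interval.
assert snapshot.data["num_successful_transfers"] == 2
assert stats.is_empty()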
@@ -15,6 +15,8 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
     KVConnectorFactory)
 from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
                                                           KVConnectorRole)
+from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
+    KVConnectorStats)
 from vllm.logger import init_logger
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
@@ -869,9 +871,12 @@ class Scheduler(SchedulerInterface):
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
         pooler_outputs = model_runner_output.pooler_output
         num_nans_in_logits = model_runner_output.num_nans_in_logits
+        kv_connector_output = model_runner_output.kv_connector_output

         outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
         spec_decoding_stats: Optional[SpecDecodingStats] = None
+        kv_connector_stats = (kv_connector_output.kv_connector_stats
+                              if kv_connector_output else None)

         # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
         # the below loop can be a performance bottleneck. We should do our best
@@ -1007,7 +1012,8 @@ class Scheduler(SchedulerInterface):
                 finished_requests=finished_set)
             finished_req_ids.clear()

-        if (stats := self.make_stats(spec_decoding_stats)) is not None:
+        if (stats := self.make_stats(spec_decoding_stats,
+                                     kv_connector_stats)) is not None:
             # Return stats to only one of the front-ends.
             if (eco := next(iter(engine_core_outputs.values()), None)) is None:
                 # We must return the stats even if there are no request
@@ -1172,20 +1178,21 @@ class Scheduler(SchedulerInterface):
     def make_stats(
         self,
         spec_decoding_stats: Optional[SpecDecodingStats] = None,
+        kv_connector_stats: Optional[KVConnectorStats] = None,
     ) -> Optional[SchedulerStats]:
         if not self.log_stats:
             return None
         prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats()
         assert prefix_cache_stats is not None
-        return SchedulerStats(
-            num_running_reqs=len(self.running),
-            num_waiting_reqs=len(self.waiting),
-            kv_cache_usage=self.kv_cache_manager.usage,
-            prefix_cache_stats=prefix_cache_stats,
-            spec_decoding_stats=spec_decoding_stats,
-            num_corrupted_reqs=sum(req.is_output_corrupted
-                                   for req in self.running),
-        )
+        return SchedulerStats(num_running_reqs=len(self.running),
+                              num_waiting_reqs=len(self.waiting),
+                              kv_cache_usage=self.kv_cache_manager.usage,
+                              prefix_cache_stats=prefix_cache_stats,
+                              spec_decoding_stats=spec_decoding_stats,
+                              num_corrupted_reqs=sum(req.is_output_corrupted
+                                                     for req in self.running),
+                              kv_connector_stats=kv_connector_stats.data
+                              if kv_connector_stats else None)

     def make_spec_decoding_stats(
         self,
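End to end, the stats cross the process boundary as a plain dict: the scheduler stores `kv_connector_stats.data` on `SchedulerStats`, and the frontend rebuilds a typed object from it with `build_kv_connector_stats`. A condensed sketch of that round trip with a made-up value:

# Scheduler side: only the serializable payload travels to the frontend.
payload = NixlKVConnectorStats(data={"num_successful_transfers": 4}).data

# Frontend side: the connector class rehydrates its typed stats object.
rebuilt = NixlConnector.build_kv_connector_stats(payload)
assert isinstance(rebuilt, NixlKVConnectorStats)
assert rebuilt.reduce() == {"num_successful_transfers": 4}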
@@ -9,6 +9,8 @@ from typing import Callable, Optional, Union
 import prometheus_client

 from vllm.config import SupportsMetricsInfo, VllmConfig
+from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
+    KVConnectorLogging)
 from vllm.logger import init_logger
 from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
 from vllm.v1.engine import FinishReason
@@ -59,6 +61,8 @@ class LoggingStatLogger(StatLoggerBase):
         # TODO: Make the interval configurable.
         self.prefix_caching_metrics = PrefixCachingMetrics()
         self.spec_decoding_logging = SpecDecodingLogging()
+        kv_tranfer_config = self.vllm_config.kv_transfer_config
+        self.kv_transfer_logging = KVConnectorLogging(kv_tranfer_config)
         self.last_prompt_throughput: float = 0.0
         self.last_generation_throughput: float = 0.0

@@ -97,7 +101,8 @@ class LoggingStatLogger(StatLoggerBase):
         if scheduler_stats.spec_decoding_stats is not None:
             self.spec_decoding_logging.observe(
                 scheduler_stats.spec_decoding_stats)
+        if kv_connector_stats := scheduler_stats.kv_connector_stats:
+            self.kv_transfer_logging.observe(kv_connector_stats)
         self.last_scheduler_stats = scheduler_stats

     def log(self):
@@ -136,6 +141,7 @@ class LoggingStatLogger(StatLoggerBase):
             self.prefix_caching_metrics.hit_rate * 100,
         )
         self.spec_decoding_logging.log(log_fn=log_fn)
+        self.kv_transfer_logging.log(log_fn=log_fn)

     def log_engine_initialized(self):
         if self.vllm_config.cache_config.num_gpu_blocks:
@@ -3,7 +3,7 @@

 import time
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional

 from vllm.v1.spec_decode.metrics import SpecDecodingStats

@@ -43,6 +43,7 @@ class SchedulerStats:
         default_factory=PrefixCacheStats)

     spec_decoding_stats: Optional[SpecDecodingStats] = None
+    kv_connector_stats: Optional[dict[str, Any]] = None

     num_corrupted_reqs: int = 0
@@ -3,10 +3,14 @@

 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import NamedTuple, Optional
+from typing import TYPE_CHECKING, NamedTuple, Optional

 import torch

+if TYPE_CHECKING:
+    from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
+        KVConnectorStats)
+

 class LogprobsLists(NamedTuple):

@@ -77,6 +81,11 @@ class KVConnectorOutput:
     # [req_ids]
     finished_sending: Optional[set[str]] = None
     finished_recving: Optional[set[str]] = None
+    kv_connector_stats: Optional["KVConnectorStats"] = None
+
+    def is_empty(self):
+        return (not self.finished_sending and not self.finished_recving
+                and not self.kv_connector_stats)


 # ModelRunnerOutput is serialized and sent to the scheduler process.
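With the new field, an output that carries only stats is no longer considered empty, so stats keep flowing even on steps with no finished transfers. For instance:

out = KVConnectorOutput(
    kv_connector_stats=NixlKVConnectorStats(
        data={"num_successful_transfers": 1}))
assert not out.is_empty()

# With none of the three fields set, the output is empty.
assert KVConnectorOutput().is_empty()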
@@ -13,6 +13,8 @@ from vllm.distributed.kv_transfer import (ensure_kv_transfer_shutdown,
                                           get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
+from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
+    KVConnectorStats)
 from vllm.forward_context import get_forward_context, set_forward_context
 from vllm.logger import init_logger
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput,
@@ -119,4 +121,11 @@ class KVConnectorModelRunnerMixin:
             output.finished_sending, output.finished_recving = (
                 kv_connector.get_finished(scheduler_output.finished_req_ids))

         kv_connector.clear_connector_metadata()
+        output.kv_connector_stats = KVConnectorModelRunnerMixin.\
+            get_kv_connector_stats()
+
+    @staticmethod
+    def get_kv_connector_stats() -> Optional[KVConnectorStats]:
+        if has_kv_transfer_group():
+            return get_kv_transfer_group().get_kv_connector_stats()
+        return None