[NIXL][Misc] Expose metrics from NIXL for logging to CLI (#25388)

Signed-off-by: NickLucche <nlucches@redhat.com>
Nicolò Lucchesi 2025-10-03 12:47:59 +02:00 committed by GitHub
parent 0e93ac0b3a
commit 48f309029a
4 changed files with 127 additions and 28 deletions

View File

@@ -1,2 +1,2 @@
 lmcache
-nixl >= 0.5.1 # Required for disaggregated prefill
+nixl >= 0.6.0 # Required for disaggregated prefill

View File

@@ -57,6 +57,26 @@ def clear_kv_transfer():
     ensure_kv_transfer_shutdown()
 
 
+def get_default_xfer_telemetry(xferDurationS: float = 1,
+                               postDurationS: float = 1,
+                               totalBytes: int = 1,
+                               descCount: int = 1) -> dict:
+
+    class AttributeDict(dict):
+        __slots__ = ()
+        __getattr__ = dict.__getitem__
+        __setattr__ = dict.__setitem__  # type: ignore[assignment]
+
+    # We can't instantiate nixlXferTelemetry because it's read only and
+    # ray env does not have NIXL, so we must fake it
+    return AttributeDict(
+        xferDuration=xferDurationS * 1e6,  # in us
+        postDuration=postDurationS * 1e6,  # in us
+        totalBytes=totalBytes,
+        descCount=descCount,
+    )
+
+
 class FakeNixlWrapper:
     """Mock implementation of NixlWrapper for testing.
@@ -132,6 +152,9 @@ class FakeNixlWrapper:
     def transfer(self, handle: int) -> str:
         return "PROC"
 
+    def get_xfer_telemetry(self, handle: int) -> dict:
+        return get_default_xfer_telemetry()
+
     ############################################################
     # Follow are for changing the behavior during testing.
     ############################################################
@@ -169,6 +192,11 @@ nixl_agent = FakeNixlWrapper
         with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
             f.write(stub)
 
+        # Mock nixlXferTelemetry class
+        pkg_root2 = os.path.join(td, "nixl", "_bindings")
+        os.makedirs(pkg_root2, exist_ok=True)
+        with open(os.path.join(pkg_root2, "__init__.py"), "w") as f:
+            f.write("class nixlXferTelemetry: pass")
+
         # touch parent package
         open(os.path.join(td, "nixl", "__init__.py"), "w").close()
         yield td
@@ -575,7 +603,7 @@ def test_kv_connector_stats(dist_init):
     # Verify stats values are recorded
     assert not stats_after_transfer.is_empty()
-    assert stats_after_transfer.data["num_successful_transfers"] == 1
+    assert stats_after_transfer.num_successful_transfers == 1
 
     # Verify stats are reset after retrieval
     stats_after_reset = connector.get_kv_connector_stats()
@@ -599,16 +627,21 @@ def test_kv_connector_stats_aggregation():
 
     # Record different transfers on each worker
     # Worker 1: 2 transfers
-    worker1_stats.record_transfer()
-    worker1_stats.record_transfer()
+    stats = get_default_xfer_telemetry()
+    worker1_stats.record_transfer(stats)
+    worker1_stats.record_transfer(stats)
 
     # Worker 2: 1 transfer
-    worker2_stats.record_transfer()
+    worker2_stats.record_transfer(stats)
 
     # Worker 3: 3 transfers
-    worker3_stats.record_transfer()
-    worker3_stats.record_transfer()
-    worker3_stats.record_transfer()
+    stats = get_default_xfer_telemetry(xferDurationS=2,
+                                       postDurationS=2,
+                                       totalBytes=2,
+                                       descCount=2)
+    worker3_stats.record_transfer(stats)
+    worker3_stats.record_transfer(stats)
+    worker3_stats.record_transfer(stats)
 
     # Create ModelRunnerOutput instances for each worker
     worker_outputs = []
@@ -636,7 +669,12 @@ def test_kv_connector_stats_aggregation():
         aggregated_output.kv_connector_output.kv_connector_stats
     assert isinstance(kv_connector_stats, NixlKVConnectorStats)
     # Number of total transfers across all workers.
-    assert kv_connector_stats.data["num_successful_transfers"] == 6
+    assert kv_connector_stats.num_successful_transfers == 6
+    # Logging proc, call reduce() to get CLI-friendly stats.
+    cli_stats = kv_connector_stats.reduce()
+    assert cli_stats["Avg xfer time (ms)"] == 1500.0
+    assert cli_stats["Avg post time (ms)"] == 1500.0
+    assert cli_stats["Avg number of descriptors"] == 1.5
 
 
 def test_multi_kv_connector_stats_aggregation():
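The expected values above follow directly from the fake telemetry fed in: six transfers in total, three recorded with 1 s durations and one descriptor (workers 1 and 2) and three with 2 s durations and two descriptors (worker 3). A quick sanity check of that arithmetic (plain Python, not part of the diff):

# Six recorded transfers: three at 1 s / 1 descriptor, three at 2 s / 2 descriptors.
durations_s = [1, 1, 1, 2, 2, 2]
descriptors = [1, 1, 1, 2, 2, 2]
assert sum(durations_s) / len(durations_s) * 1e3 == 1500.0  # Avg xfer time (ms)
assert sum(descriptors) / len(descriptors) == 1.5           # Avg number of descriptors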
@@ -649,6 +687,7 @@ def test_multi_kv_connector_stats_aggregation():
     from dataclasses import dataclass
 
+    # Mock a KVConnectorStats class for testing aggregation over connectors.
     @dataclass
     class FooKVConnectorStats(KVConnectorStats):
@@ -676,7 +715,7 @@ def test_multi_kv_connector_stats_aggregation():
         if nixl_count > 0:
             nixl_stats = NixlKVConnectorStats()
             for _ in range(nixl_count):
-                nixl_stats.record_transfer()
+                nixl_stats.record_transfer(get_default_xfer_telemetry())
             data["NixlConnector"] = nixl_stats
         if foo_count > 0:
             foo_stats = FooKVConnectorStats()
@@ -712,8 +751,10 @@ def test_multi_kv_connector_stats_aggregation():
     assert isinstance(kv_connector_stats, MultiKVConnectorStats)
 
     # Validate per-connector totals across workers
-    assert kv_connector_stats["NixlConnector"].data[
-        "num_successful_transfers"] == 5
+    assert isinstance(kv_connector_stats["NixlConnector"],
+                      NixlKVConnectorStats)
+    assert kv_connector_stats["NixlConnector"].num_successful_transfers == 5
+    assert isinstance(kv_connector_stats["FooConnector"], FooKVConnectorStats)
     assert kv_connector_stats["FooConnector"].data["num_foo_transfers"] == 6
@@ -755,6 +796,8 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
         "working_dir": working_dir,  # ship fake nixl package
         "env_vars": {
             "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout),
+            # TODO: for ray to carry over, remove once we set
+            "NIXL_TELEMETRY_ENABLE": "1",
         },
     }
     ray.init(runtime_env=runtime_env)

View File

@@ -4,6 +4,7 @@ import contextlib
 import copy
 import logging
 import math
+import os
 import queue
 import threading
 import time
@@ -54,10 +55,12 @@ logger = init_logger(__name__)
 # Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
 try:
     from nixl._api import nixl_agent as NixlWrapper
+    from nixl._bindings import nixlXferTelemetry
     logger.info("NIXL is available")
 except ImportError:
     logger.warning("NIXL is not available")
     NixlWrapper = None
+    nixlXferTelemetry = None
 
 try:
     from nixl._api import nixl_agent_config
@@ -476,6 +479,9 @@ class NixlConnectorWorker:
         self.nixl_backends = \
             vllm_config.kv_transfer_config.get_from_extra_config(
                 "backends", ["UCX"])
+        # TODO temporary, once nixl allows for telemetry flag in config
+        # (next release), we can remove this env var.
+        os.environ["NIXL_TELEMETRY_ENABLE"] = "1"
 
         # Agent.
         non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
         if nixl_agent_config is None:
@@ -1175,9 +1181,10 @@ class NixlConnectorWorker:
         for handle, _xfer_stime in handles:
             xfer_state = self.nixl_wrapper.check_xfer_state(handle)
             if xfer_state == "DONE":
+                # Get telemetry from NIXL
+                res = self.nixl_wrapper.get_xfer_telemetry(handle)
+                self.xfer_stats.record_transfer(res)
                 self.nixl_wrapper.release_xfer_handle(handle)
-                # TODO (NickLucche) Get from NIXL telemetry once integrated
-                self.xfer_stats.record_transfer()
             elif xfer_state == "PROC":
                 in_progress = True
                 continue
@@ -1449,15 +1456,25 @@ class NixlKVConnectorStats(KVConnectorStats):
     """Container for transfer performance metrics"""
 
     def __post_init__(self):
-        if "num_successful_transfers" not in self.data:
-            self.data["num_successful_transfers"] = 0
+        if not self.data:
+            # Empty container init, no data is passed in.
+            self.reset()
 
     def reset(self):
-        self.data = {"num_successful_transfers": 0}
+        # Must be serializable
+        self.data: dict[str, list[float]] = {
+            "transfer_duration": [],
+            "post_duration": [],
+            "bytes_transferred": [],
+            "num_descriptors": [],
+        }
 
-    def record_transfer(self):
-        # TODO: record actual transfer stats when available
-        self.data["num_successful_transfers"] += 1
+    def record_transfer(self, res: nixlXferTelemetry):
+        # Keep metrics units consistent with rest of the code: time us->s
+        self.data["transfer_duration"].append(res.xferDuration / 1e6)
+        self.data["post_duration"].append(res.postDuration / 1e6)
+        self.data["bytes_transferred"].append(res.totalBytes)
+        self.data["num_descriptors"].append(res.descCount)
 
     def clone_and_reset(self) -> "NixlKVConnectorStats":
         old = copy.copy(self)
@@ -1465,16 +1482,55 @@
         return old
 
     def is_empty(self) -> bool:
-        return self.data["num_successful_transfers"] == 0
+        return self.num_successful_transfers == 0
 
     def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
         if not other.is_empty():
-            self.data["num_successful_transfers"] += other.data[
-                "num_successful_transfers"]
+            for k, v in other.data.items():
+                accumulator = self.data[k]
+                assert isinstance(accumulator, list)
+                accumulator.extend(v)
         return self
 
     def reduce(self) -> dict[str, Union[int, float]]:
-        # TODO: reduce stats to a single value, calculate latency/throughput
-        return {
-            "num_successful_transfers": self.data["num_successful_transfers"]
-        }
+        # Compute compact representative stats suitable for CLI logging
+        if self.is_empty():
+            return {
+                "Num successful transfers": 0,
+                "Avg xfer time (ms)": 0,
+                "P90 xfer time (ms)": 0,
+                "Avg post time (ms)": 0,
+                "P90 post time (ms)": 0,
+                "Avg MB per transfer": 0,
+                "Throughput (MB/s)": 0,
+                "Avg number of descriptors": 0,
+            }
+
+        xfer_time = np.asarray(self.data["transfer_duration"])
+        post_time = np.asarray(self.data["post_duration"])
+        # Convert to MB for CLI logging.
+        mb = np.asarray(self.data["bytes_transferred"]) / 2**20
+        descs = np.asarray(self.data["num_descriptors"], dtype=np.uint32)
+        n = len(descs)
+        assert n == self.num_successful_transfers
+        total_mb = mb.sum()
+        avg_mb = total_mb / n
+        total_time_seconds = xfer_time.sum()
+        throughput_mb_s = total_mb / total_time_seconds
+        return {
+            "Num successful transfers": n,
+            "Avg xfer time (ms)": round(xfer_time.mean() * 1e3, 3),
+            "P90 xfer time (ms)": round(np.percentile(xfer_time, 90) * 1e3, 3),
+            "Avg post time (ms)": round(post_time.mean() * 1e3, 3),
+            "P90 post time (ms)": round(np.percentile(post_time, 90) * 1e3, 3),
+            "Avg MB per transfer": round(avg_mb, 3),
+            "Throughput (MB/s)": round(throughput_mb_s, 3),
+            "Avg number of descriptors": round(descs.mean(), 1),
+        }
+
+    @property
+    def num_successful_transfers(self) -> int:
+        return len(self.data["transfer_duration"])
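Taken together, the flow is: the worker records one telemetry sample per completed transfer, samples are aggregated across workers, and reduce() collapses them into the CLI-friendly dict that KVConnectorLogging prints. A rough usage sketch under the assumption that NixlKVConnectorStats is imported from this module (not part of the diff; the telemetry object is faked with SimpleNamespace):

from types import SimpleNamespace

# Fake telemetry sample with the fields read by record_transfer()
# (xferDuration/postDuration in microseconds, as NIXL reports them).
sample = SimpleNamespace(xferDuration=2_000_000, postDuration=500,
                         totalBytes=64 * 2**20, descCount=8)

stats = NixlKVConnectorStats()   # empty container, __post_init__ calls reset()
stats.record_transfer(sample)    # durations stored in seconds
other = NixlKVConnectorStats()
other.record_transfer(sample)
stats.aggregate(other)           # per-metric lists are extended across workers

print(stats.reduce())
# Roughly: {'Num successful transfers': 2, 'Avg xfer time (ms)': 2000.0, ...,
#           'Avg MB per transfer': 64.0, 'Throughput (MB/s)': 32.0, ...}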

View File

@@ -62,7 +62,7 @@ class LoggingStatLogger(StatLoggerBase):
         self.prefix_caching_metrics = PrefixCachingMetrics()
         self.spec_decoding_logging = SpecDecodingLogging()
         kv_tranfer_config = self.vllm_config.kv_transfer_config
-        self.kv_transfer_logging = KVConnectorLogging(kv_tranfer_config)
+        self.kv_connector_logging = KVConnectorLogging(kv_tranfer_config)
         self.last_prompt_throughput: float = 0.0
         self.last_generation_throughput: float = 0.0
@@ -101,7 +101,7 @@ class LoggingStatLogger(StatLoggerBase):
                 self.spec_decoding_logging.observe(
                     scheduler_stats.spec_decoding_stats)
             if kv_connector_stats := scheduler_stats.kv_connector_stats:
-                self.kv_transfer_logging.observe(kv_connector_stats)
+                self.kv_connector_logging.observe(kv_connector_stats)
 
             self.last_scheduler_stats = scheduler_stats
 
     def log(self):
@@ -140,7 +140,7 @@ class LoggingStatLogger(StatLoggerBase):
             self.prefix_caching_metrics.hit_rate * 100,
         )
         self.spec_decoding_logging.log(log_fn=log_fn)
-        self.kv_transfer_logging.log(log_fn=log_fn)
+        self.kv_connector_logging.log(log_fn=log_fn)
 
     def log_engine_initialized(self):
         if self.vllm_config.cache_config.num_gpu_blocks: