diff --git a/requirements/kv_connectors.txt b/requirements/kv_connectors.txt index 3b610e0d9736..b1f3269cd381 100644 --- a/requirements/kv_connectors.txt +++ b/requirements/kv_connectors.txt @@ -1,2 +1,2 @@ lmcache -nixl >= 0.5.1 # Required for disaggregated prefill +nixl >= 0.6.0 # Required for disaggregated prefill diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 578bf02eb519..21953b5533ec 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -57,6 +57,26 @@ def clear_kv_transfer(): ensure_kv_transfer_shutdown() +def get_default_xfer_telemetry(xferDurationS: float = 1, + postDurationS: float = 1, + totalBytes: int = 1, + descCount: int = 1) -> dict: + + class AttributeDict(dict): + __slots__ = () + __getattr__ = dict.__getitem__ + __setattr__ = dict.__setitem__ # type: ignore[assignment] + + # We can't instantiate nixlXferTelemetry because it's read only and + # ray env does not have NIXL, so we must fake it + return AttributeDict( + xferDuration=xferDurationS * 1e6, # in us + postDuration=postDurationS * 1e6, # in us + totalBytes=totalBytes, + descCount=descCount, + ) + + class FakeNixlWrapper: """Mock implementation of NixlWrapper for testing. @@ -132,6 +152,9 @@ class FakeNixlWrapper: def transfer(self, handle: int) -> str: return "PROC" + def get_xfer_telemetry(self, handle: int) -> dict: + return get_default_xfer_telemetry() + ############################################################ # Follow are for changing the behavior during testing. 
############################################################ @@ -169,6 +192,11 @@ nixl_agent = FakeNixlWrapper with open(os.path.join(pkg_root, "__init__.py"), "w") as f: f.write(stub) + # Mock nixlXferTelemetry class + pkg_root2 = os.path.join(td, "nixl", "_bindings") + os.makedirs(pkg_root2, exist_ok=True) + with open(os.path.join(pkg_root2, "__init__.py"), "w") as f: + f.write("class nixlXferTelemetry: pass") # touch parent package open(os.path.join(td, "nixl", "__init__.py"), "w").close() yield td @@ -575,7 +603,7 @@ def test_kv_connector_stats(dist_init): # Verify stats values are recorded assert not stats_after_transfer.is_empty() - assert stats_after_transfer.data["num_successful_transfers"] == 1 + assert stats_after_transfer.num_successful_transfers == 1 # Verify stats are reset after retrieval stats_after_reset = connector.get_kv_connector_stats() @@ -599,16 +627,21 @@ def test_kv_connector_stats_aggregation(): # Record different transfers on each worker # Worker 1: 2 transfers - worker1_stats.record_transfer() - worker1_stats.record_transfer() + stats = get_default_xfer_telemetry() + worker1_stats.record_transfer(stats) + worker1_stats.record_transfer(stats) # Worker 2: 1 transfer - worker2_stats.record_transfer() + worker2_stats.record_transfer(stats) # Worker 3: 3 transfers - worker3_stats.record_transfer() - worker3_stats.record_transfer() - worker3_stats.record_transfer() + stats = get_default_xfer_telemetry(xferDurationS=2, + postDurationS=2, + totalBytes=2, + descCount=2) + worker3_stats.record_transfer(stats) + worker3_stats.record_transfer(stats) + worker3_stats.record_transfer(stats) # Create ModelRunnerOutput instances for each worker worker_outputs = [] @@ -636,7 +669,12 @@ def test_kv_connector_stats_aggregation(): aggregated_output.kv_connector_output.kv_connector_stats assert isinstance(kv_connector_stats, NixlKVConnectorStats) # Number of total transfers across all workers. 
- assert kv_connector_stats.data["num_successful_transfers"] == 6 + assert kv_connector_stats.num_successful_transfers == 6 + # Logging proc, call reduce() to get CLI-friendly stats. + cli_stats = kv_connector_stats.reduce() + assert cli_stats["Avg xfer time (ms)"] == 1500.0 + assert cli_stats["Avg post time (ms)"] == 1500.0 + assert cli_stats["Avg number of descriptors"] == 1.5 def test_multi_kv_connector_stats_aggregation(): @@ -649,6 +687,7 @@ def test_multi_kv_connector_stats_aggregation(): from dataclasses import dataclass + # Mock a KVConnectorStats class for testing aggregation over connectors. @dataclass class FooKVConnectorStats(KVConnectorStats): @@ -676,7 +715,7 @@ def test_multi_kv_connector_stats_aggregation(): if nixl_count > 0: nixl_stats = NixlKVConnectorStats() for _ in range(nixl_count): - nixl_stats.record_transfer() + nixl_stats.record_transfer(get_default_xfer_telemetry()) data["NixlConnector"] = nixl_stats if foo_count > 0: foo_stats = FooKVConnectorStats() @@ -712,8 +751,10 @@ def test_multi_kv_connector_stats_aggregation(): assert isinstance(kv_connector_stats, MultiKVConnectorStats) # Validate per-connector totals across workers - assert kv_connector_stats["NixlConnector"].data[ - "num_successful_transfers"] == 5 + assert isinstance(kv_connector_stats["NixlConnector"], + NixlKVConnectorStats) + assert kv_connector_stats["NixlConnector"].num_successful_transfers == 5 + assert isinstance(kv_connector_stats["FooConnector"], FooKVConnectorStats) assert kv_connector_stats["FooConnector"].data["num_foo_transfers"] == 6 @@ -755,6 +796,8 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend): "working_dir": working_dir, # ship fake nixl package "env_vars": { "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout), + # TODO: needed so ray carries the env var over to workers; remove + "NIXL_TELEMETRY_ENABLE": "1", }, } ray.init(runtime_env=runtime_env) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 4706c5130899..fdfcc39666ad 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -4,6 +4,7 @@ import contextlib import copy import logging import math +import os import queue import threading import time @@ -54,10 +55,12 @@ logger = init_logger(__name__) # Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used try: from nixl._api import nixl_agent as NixlWrapper + from nixl._bindings import nixlXferTelemetry logger.info("NIXL is available") except ImportError: logger.warning("NIXL is not available") NixlWrapper = None + nixlXferTelemetry = None try: from nixl._api import nixl_agent_config @@ -476,6 +479,9 @@ class NixlConnectorWorker: self.nixl_backends = \ vllm_config.kv_transfer_config.get_from_extra_config( "backends", ["UCX"]) + # TODO temporary, once nixl allows for telemetry flag in config + # (next release), we can remove this env var. + os.environ["NIXL_TELEMETRY_ENABLE"] = "1" # Agent. non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"] if nixl_agent_config is None: @@ -1175,9 +1181,10 @@ class NixlConnectorWorker: for handle, _xfer_stime in handles: xfer_state = self.nixl_wrapper.check_xfer_state(handle) if xfer_state == "DONE": + # Get telemetry from NIXL + res = self.nixl_wrapper.get_xfer_telemetry(handle) + self.xfer_stats.record_transfer(res) self.nixl_wrapper.release_xfer_handle(handle) - # TODO (NickLucche) Get from NIXL telemetry once integrated - self.xfer_stats.record_transfer() elif xfer_state == "PROC": in_progress = True continue @@ -1449,15 +1456,25 @@ class NixlKVConnectorStats(KVConnectorStats): """Container for transfer performance metrics""" def __post_init__(self): - if "num_successful_transfers" not in self.data: - self.data["num_successful_transfers"] = 0 + if not self.data: + # Empty container init, no data is passed in. 
+ self.reset() def reset(self): - self.data = {"num_successful_transfers": 0} + # Must be serializable + self.data: dict[str, list[float]] = { + "transfer_duration": [], + "post_duration": [], + "bytes_transferred": [], + "num_descriptors": [], + } - def record_transfer(self): - # TODO: record actual transfer stats when available - self.data["num_successful_transfers"] += 1 + def record_transfer(self, res: nixlXferTelemetry): + # Keep metrics units consistent with rest of the code: time us->s + self.data["transfer_duration"].append(res.xferDuration / 1e6) + self.data["post_duration"].append(res.postDuration / 1e6) + self.data["bytes_transferred"].append(res.totalBytes) + self.data["num_descriptors"].append(res.descCount) def clone_and_reset(self) -> "NixlKVConnectorStats": old = copy.copy(self) @@ -1465,16 +1482,55 @@ class NixlKVConnectorStats(KVConnectorStats): return old def is_empty(self) -> bool: - return self.data["num_successful_transfers"] == 0 + return self.num_successful_transfers == 0 def aggregate(self, other: KVConnectorStats) -> KVConnectorStats: if not other.is_empty(): - self.data["num_successful_transfers"] += other.data[ - "num_successful_transfers"] + for k, v in other.data.items(): + accumulator = self.data[k] + assert isinstance(accumulator, list) + accumulator.extend(v) return self def reduce(self) -> dict[str, Union[int, float]]: - # TODO: reduce stats to a single value, calculate latency/throughput + # Compute compact representative stats suitable for CLI logging + if self.is_empty(): + return { + "Num successful transfers": 0, + "Avg xfer time (ms)": 0, + "P90 xfer time (ms)": 0, + "Avg post time (ms)": 0, + "P90 post time (ms)": 0, + "Avg MB per transfer": 0, + "Throughput (MB/s)": 0, + "Avg number of descriptors": 0, + } + + xfer_time = np.asarray(self.data["transfer_duration"]) + post_time = np.asarray(self.data["post_duration"]) + # Convert to MB for CLI logging. 
+ mb = np.asarray(self.data["bytes_transferred"]) / 2**20 + descs = np.asarray(self.data["num_descriptors"], dtype=np.uint32) + n = len(descs) + assert n == self.num_successful_transfers + + total_mb = mb.sum() + avg_mb = total_mb / n + + total_time_seconds = xfer_time.sum() + throughput_mb_s = total_mb / total_time_seconds + return { - "num_successful_transfers": self.data["num_successful_transfers"] + "Num successful transfers": n, + "Avg xfer time (ms)": round(xfer_time.mean() * 1e3, 3), + "P90 xfer time (ms)": round(np.percentile(xfer_time, 90) * 1e3, 3), + "Avg post time (ms)": round(post_time.mean() * 1e3, 3), + "P90 post time (ms)": round(np.percentile(post_time, 90) * 1e3, 3), + "Avg MB per transfer": round(avg_mb, 3), + "Throughput (MB/s)": round(throughput_mb_s, 3), + "Avg number of descriptors": round(descs.mean(), 1), } + + @property + def num_successful_transfers(self) -> int: + return len(self.data["transfer_duration"]) \ No newline at end of file diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index d68d111c67ca..ef95f03e8882 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -62,7 +62,7 @@ class LoggingStatLogger(StatLoggerBase): self.prefix_caching_metrics = PrefixCachingMetrics() self.spec_decoding_logging = SpecDecodingLogging() kv_tranfer_config = self.vllm_config.kv_transfer_config - self.kv_transfer_logging = KVConnectorLogging(kv_tranfer_config) + self.kv_connector_logging = KVConnectorLogging(kv_tranfer_config) self.last_prompt_throughput: float = 0.0 self.last_generation_throughput: float = 0.0 @@ -101,7 +101,7 @@ class LoggingStatLogger(StatLoggerBase): self.spec_decoding_logging.observe( scheduler_stats.spec_decoding_stats) if kv_connector_stats := scheduler_stats.kv_connector_stats: - self.kv_transfer_logging.observe(kv_connector_stats) + self.kv_connector_logging.observe(kv_connector_stats) self.last_scheduler_stats = scheduler_stats def log(self): @@ -140,7 +140,7 @@ class 
LoggingStatLogger(StatLoggerBase): self.prefix_caching_metrics.hit_rate * 100, ) self.spec_decoding_logging.log(log_fn=log_fn) - self.kv_transfer_logging.log(log_fn=log_fn) + self.kv_connector_logging.log(log_fn=log_fn) def log_engine_initialized(self): if self.vllm_config.cache_config.num_gpu_blocks: