mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 01:55:27 +08:00
[NIXL][Misc] Expose metrics from NIXL for logging to CLI (#25388)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
0e93ac0b3a
commit
48f309029a
@ -1,2 +1,2 @@
|
||||
lmcache
|
||||
nixl >= 0.5.1 # Required for disaggregated prefill
|
||||
nixl >= 0.6.0 # Required for disaggregated prefill
|
||||
|
||||
@ -57,6 +57,26 @@ def clear_kv_transfer():
|
||||
ensure_kv_transfer_shutdown()
|
||||
|
||||
|
||||
def get_default_xfer_telemetry(xferDurationS: float = 1,
                               postDurationS: float = 1,
                               totalBytes: int = 1,
                               descCount: int = 1) -> dict:
    """Build a fake transfer-telemetry record for tests.

    The returned mapping exposes the fields ``xferDuration``,
    ``postDuration``, ``totalBytes`` and ``descCount`` both as dict keys and
    as attributes, so it can stand in for a ``nixlXferTelemetry`` object.

    Args:
        xferDurationS: Transfer duration in seconds; stored in microseconds.
        postDurationS: Post duration in seconds; stored in microseconds.
        totalBytes: Number of bytes reported for the transfer.
        descCount: Number of descriptors reported for the transfer.

    Returns:
        A dict subclass whose items are also reachable via attribute access.
    """

    class AttributeDict(dict):
        """dict whose items can be read and written as attributes."""
        __slots__ = ()

        def __getattr__(self, name):
            # Translate the lookup failure into AttributeError: a bare
            # ``dict.__getitem__`` would leak KeyError, which breaks
            # ``hasattr()`` and ``getattr(obj, name, default)``.
            try:
                return self[name]
            except KeyError:
                raise AttributeError(name) from None

        __setattr__ = dict.__setitem__  # type: ignore[assignment]

    # We can't instantiate nixlXferTelemetry because it's read only and
    # ray env does not have NIXL, so we must fake it
    return AttributeDict(
        xferDuration=xferDurationS * 1e6,  # in us
        postDuration=postDurationS * 1e6,  # in us
        totalBytes=totalBytes,
        descCount=descCount,
    )
|
||||
|
||||
|
||||
class FakeNixlWrapper:
|
||||
"""Mock implementation of NixlWrapper for testing.
|
||||
|
||||
@ -132,6 +152,9 @@ class FakeNixlWrapper:
|
||||
def transfer(self, handle: int) -> str:
|
||||
return "PROC"
|
||||
|
||||
def get_xfer_telemetry(self, handle: int) -> dict:
|
||||
return get_default_xfer_telemetry()
|
||||
|
||||
############################################################
|
||||
# The following are for changing the behavior during testing.
|
||||
############################################################
|
||||
@ -169,6 +192,11 @@ nixl_agent = FakeNixlWrapper
|
||||
with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
|
||||
f.write(stub)
|
||||
|
||||
# Mock nixlXferTelemetry class
|
||||
pkg_root2 = os.path.join(td, "nixl", "_bindings")
|
||||
os.makedirs(pkg_root2, exist_ok=True)
|
||||
with open(os.path.join(pkg_root2, "__init__.py"), "w") as f:
|
||||
f.write("class nixlXferTelemetry: pass")
|
||||
# touch parent package
|
||||
open(os.path.join(td, "nixl", "__init__.py"), "w").close()
|
||||
yield td
|
||||
@ -575,7 +603,7 @@ def test_kv_connector_stats(dist_init):
|
||||
|
||||
# Verify stats values are recorded
|
||||
assert not stats_after_transfer.is_empty()
|
||||
assert stats_after_transfer.data["num_successful_transfers"] == 1
|
||||
assert stats_after_transfer.num_successful_transfers == 1
|
||||
|
||||
# Verify stats are reset after retrieval
|
||||
stats_after_reset = connector.get_kv_connector_stats()
|
||||
@ -599,16 +627,21 @@ def test_kv_connector_stats_aggregation():
|
||||
|
||||
# Record different transfers on each worker
|
||||
# Worker 1: 2 transfers
|
||||
worker1_stats.record_transfer()
|
||||
worker1_stats.record_transfer()
|
||||
stats = get_default_xfer_telemetry()
|
||||
worker1_stats.record_transfer(stats)
|
||||
worker1_stats.record_transfer(stats)
|
||||
|
||||
# Worker 2: 1 transfer
|
||||
worker2_stats.record_transfer()
|
||||
worker2_stats.record_transfer(stats)
|
||||
|
||||
# Worker 3: 3 transfers
|
||||
worker3_stats.record_transfer()
|
||||
worker3_stats.record_transfer()
|
||||
worker3_stats.record_transfer()
|
||||
stats = get_default_xfer_telemetry(xferDurationS=2,
|
||||
postDurationS=2,
|
||||
totalBytes=2,
|
||||
descCount=2)
|
||||
worker3_stats.record_transfer(stats)
|
||||
worker3_stats.record_transfer(stats)
|
||||
worker3_stats.record_transfer(stats)
|
||||
|
||||
# Create ModelRunnerOutput instances for each worker
|
||||
worker_outputs = []
|
||||
@ -636,7 +669,12 @@ def test_kv_connector_stats_aggregation():
|
||||
aggregated_output.kv_connector_output.kv_connector_stats
|
||||
assert isinstance(kv_connector_stats, NixlKVConnectorStats)
|
||||
# Number of total transfers across all workers.
|
||||
assert kv_connector_stats.data["num_successful_transfers"] == 6
|
||||
assert kv_connector_stats.num_successful_transfers == 6
|
||||
# Logging proc, call reduce() to get CLI-friendly stats.
|
||||
cli_stats = kv_connector_stats.reduce()
|
||||
assert cli_stats["Avg xfer time (ms)"] == 1500.0
|
||||
assert cli_stats["Avg post time (ms)"] == 1500.0
|
||||
assert cli_stats["Avg number of descriptors"] == 1.5
|
||||
|
||||
|
||||
def test_multi_kv_connector_stats_aggregation():
|
||||
@ -649,6 +687,7 @@ def test_multi_kv_connector_stats_aggregation():
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
# Mock a KVConnectorStats class for testing aggregation over connectors.
|
||||
@dataclass
|
||||
class FooKVConnectorStats(KVConnectorStats):
|
||||
|
||||
@ -676,7 +715,7 @@ def test_multi_kv_connector_stats_aggregation():
|
||||
if nixl_count > 0:
|
||||
nixl_stats = NixlKVConnectorStats()
|
||||
for _ in range(nixl_count):
|
||||
nixl_stats.record_transfer()
|
||||
nixl_stats.record_transfer(get_default_xfer_telemetry())
|
||||
data["NixlConnector"] = nixl_stats
|
||||
if foo_count > 0:
|
||||
foo_stats = FooKVConnectorStats()
|
||||
@ -712,8 +751,10 @@ def test_multi_kv_connector_stats_aggregation():
|
||||
assert isinstance(kv_connector_stats, MultiKVConnectorStats)
|
||||
|
||||
# Validate per-connector totals across workers
|
||||
assert kv_connector_stats["NixlConnector"].data[
|
||||
"num_successful_transfers"] == 5
|
||||
assert isinstance(kv_connector_stats["NixlConnector"],
|
||||
NixlKVConnectorStats)
|
||||
assert kv_connector_stats["NixlConnector"].num_successful_transfers == 5
|
||||
assert isinstance(kv_connector_stats["FooConnector"], FooKVConnectorStats)
|
||||
assert kv_connector_stats["FooConnector"].data["num_foo_transfers"] == 6
|
||||
|
||||
|
||||
@ -755,6 +796,8 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
|
||||
"working_dir": working_dir, # ship fake nixl package
|
||||
"env_vars": {
|
||||
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout),
|
||||
# TODO: set here so ray carries it over; remove once nixl allows a telemetry flag in its config.
|
||||
"NIXL_TELEMETRY_ENABLE": "1",
|
||||
},
|
||||
}
|
||||
ray.init(runtime_env=runtime_env)
|
||||
|
||||
@ -4,6 +4,7 @@ import contextlib
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
@ -54,10 +55,12 @@ logger = init_logger(__name__)
|
||||
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
|
||||
try:
|
||||
from nixl._api import nixl_agent as NixlWrapper
|
||||
from nixl._bindings import nixlXferTelemetry
|
||||
logger.info("NIXL is available")
|
||||
except ImportError:
|
||||
logger.warning("NIXL is not available")
|
||||
NixlWrapper = None
|
||||
nixlXferTelemetry = None
|
||||
|
||||
try:
|
||||
from nixl._api import nixl_agent_config
|
||||
@ -476,6 +479,9 @@ class NixlConnectorWorker:
|
||||
self.nixl_backends = \
|
||||
vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"backends", ["UCX"])
|
||||
# TODO temporary, once nixl allows for telemetry flag in config
|
||||
# (next release), we can remove this env var.
|
||||
os.environ["NIXL_TELEMETRY_ENABLE"] = "1"
|
||||
# Agent.
|
||||
non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
|
||||
if nixl_agent_config is None:
|
||||
@ -1175,9 +1181,10 @@ class NixlConnectorWorker:
|
||||
for handle, _xfer_stime in handles:
|
||||
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
|
||||
if xfer_state == "DONE":
|
||||
# Get telemetry from NIXL
|
||||
res = self.nixl_wrapper.get_xfer_telemetry(handle)
|
||||
self.xfer_stats.record_transfer(res)
|
||||
self.nixl_wrapper.release_xfer_handle(handle)
|
||||
# TODO (NickLucche) Get from NIXL telemetry once integrated
|
||||
self.xfer_stats.record_transfer()
|
||||
elif xfer_state == "PROC":
|
||||
in_progress = True
|
||||
continue
|
||||
@ -1449,15 +1456,25 @@ class NixlKVConnectorStats(KVConnectorStats):
|
||||
"""Container for transfer performance metrics"""
|
||||
|
||||
def __post_init__(self):
|
||||
if "num_successful_transfers" not in self.data:
|
||||
self.data["num_successful_transfers"] = 0
|
||||
if not self.data:
|
||||
# Empty container init, no data is passed in.
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.data = {"num_successful_transfers": 0}
|
||||
# Must be serializable
|
||||
self.data: dict[str, list[float]] = {
|
||||
"transfer_duration": [],
|
||||
"post_duration": [],
|
||||
"bytes_transferred": [],
|
||||
"num_descriptors": [],
|
||||
}
|
||||
|
||||
def record_transfer(self):
|
||||
# TODO: record actual transfer stats when available
|
||||
self.data["num_successful_transfers"] += 1
|
||||
def record_transfer(self, res: nixlXferTelemetry):
|
||||
# Keep metrics units consistent with rest of the code: time us->s
|
||||
self.data["transfer_duration"].append(res.xferDuration / 1e6)
|
||||
self.data["post_duration"].append(res.postDuration / 1e6)
|
||||
self.data["bytes_transferred"].append(res.totalBytes)
|
||||
self.data["num_descriptors"].append(res.descCount)
|
||||
|
||||
def clone_and_reset(self) -> "NixlKVConnectorStats":
|
||||
old = copy.copy(self)
|
||||
@ -1465,16 +1482,55 @@ class NixlKVConnectorStats(KVConnectorStats):
|
||||
return old
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return self.data["num_successful_transfers"] == 0
|
||||
return self.num_successful_transfers == 0
|
||||
|
||||
def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
|
||||
if not other.is_empty():
|
||||
self.data["num_successful_transfers"] += other.data[
|
||||
"num_successful_transfers"]
|
||||
for k, v in other.data.items():
|
||||
accumulator = self.data[k]
|
||||
assert isinstance(accumulator, list)
|
||||
accumulator.extend(v)
|
||||
return self
|
||||
|
||||
def reduce(self) -> dict[str, Union[int, float]]:
    """Collapse the recorded per-transfer samples into CLI-friendly stats.

    Reads the ``transfer_duration`` and ``post_duration`` lists (seconds),
    ``bytes_transferred`` (bytes) and ``num_descriptors`` from ``self.data``.

    Returns:
        A dict with the number of transfers, average and P90 transfer/post
        times in milliseconds, average MB per transfer, throughput in MB/s
        and the average descriptor count. Every value is 0 when no transfer
        has been recorded.
    """
    # Compute compact representative stats suitable for CLI logging
    if self.is_empty():
        return {
            "Num successful transfers": 0,
            "Avg xfer time (ms)": 0,
            "P90 xfer time (ms)": 0,
            "Avg post time (ms)": 0,
            "P90 post time (ms)": 0,
            "Avg MB per transfer": 0,
            "Throughput (MB/s)": 0,
            "Avg number of descriptors": 0,
        }

    xfer_time = np.asarray(self.data["transfer_duration"])
    post_time = np.asarray(self.data["post_duration"])
    # Convert to MB for CLI logging.
    mb = np.asarray(self.data["bytes_transferred"]) / 2**20
    descs = np.asarray(self.data["num_descriptors"], dtype=np.uint32)
    n = len(descs)
    assert n == self.num_successful_transfers

    total_mb = mb.sum()
    avg_mb = total_mb / n

    total_time_seconds = xfer_time.sum()
    # Guard the division: if every recorded duration is zero the quotient
    # would be nan/inf (plus a numpy RuntimeWarning) in the log line.
    if total_time_seconds > 0:
        throughput_mb_s = total_mb / total_time_seconds
    else:
        throughput_mb_s = 0.0

    return {
        "Num successful transfers": n,
        "Avg xfer time (ms)": round(xfer_time.mean() * 1e3, 3),
        "P90 xfer time (ms)": round(np.percentile(xfer_time, 90) * 1e3, 3),
        "Avg post time (ms)": round(post_time.mean() * 1e3, 3),
        "P90 post time (ms)": round(np.percentile(post_time, 90) * 1e3, 3),
        "Avg MB per transfer": round(avg_mb, 3),
        "Throughput (MB/s)": round(throughput_mb_s, 3),
        "Avg number of descriptors": round(descs.mean(), 1),
    }
|
||||
|
||||
@property
|
||||
def num_successful_transfers(self) -> int:
|
||||
return len(self.data["transfer_duration"])
|
||||
@ -62,7 +62,7 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
self.prefix_caching_metrics = PrefixCachingMetrics()
|
||||
self.spec_decoding_logging = SpecDecodingLogging()
|
||||
kv_tranfer_config = self.vllm_config.kv_transfer_config
|
||||
self.kv_transfer_logging = KVConnectorLogging(kv_tranfer_config)
|
||||
self.kv_connector_logging = KVConnectorLogging(kv_tranfer_config)
|
||||
self.last_prompt_throughput: float = 0.0
|
||||
self.last_generation_throughput: float = 0.0
|
||||
|
||||
@ -101,7 +101,7 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
self.spec_decoding_logging.observe(
|
||||
scheduler_stats.spec_decoding_stats)
|
||||
if kv_connector_stats := scheduler_stats.kv_connector_stats:
|
||||
self.kv_transfer_logging.observe(kv_connector_stats)
|
||||
self.kv_connector_logging.observe(kv_connector_stats)
|
||||
self.last_scheduler_stats = scheduler_stats
|
||||
|
||||
def log(self):
|
||||
@ -140,7 +140,7 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
self.prefix_caching_metrics.hit_rate * 100,
|
||||
)
|
||||
self.spec_decoding_logging.log(log_fn=log_fn)
|
||||
self.kv_transfer_logging.log(log_fn=log_fn)
|
||||
self.kv_connector_logging.log(log_fn=log_fn)
|
||||
|
||||
def log_engine_initialized(self):
|
||||
if self.vllm_config.cache_config.num_gpu_blocks:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user