mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 01:55:27 +08:00
[NIXL][Misc] Expose metrics from NIXL for logging to CLI (#25388)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
0e93ac0b3a
commit
48f309029a
@ -1,2 +1,2 @@
|
|||||||
lmcache
|
lmcache
|
||||||
nixl >= 0.5.1 # Required for disaggregated prefill
|
nixl >= 0.6.0 # Required for disaggregated prefill
|
||||||
|
|||||||
@ -57,6 +57,26 @@ def clear_kv_transfer():
|
|||||||
ensure_kv_transfer_shutdown()
|
ensure_kv_transfer_shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_xfer_telemetry(xferDurationS: float = 1,
                               postDurationS: float = 1,
                               totalBytes: int = 1,
                               descCount: int = 1) -> dict:
    """Build a fake NIXL transfer-telemetry record for tests.

    The returned object is a ``dict`` whose keys can also be read and
    written as attributes, so test code can access ``res.xferDuration``
    and the other fields by attribute.

    Args:
        xferDurationS: Transfer duration in seconds; stored under
            ``xferDuration`` converted to microseconds.
        postDurationS: Post duration in seconds; stored under
            ``postDuration`` converted to microseconds.
        totalBytes: Value stored unchanged under ``totalBytes``.
        descCount: Value stored unchanged under ``descCount``.

    Returns:
        A dict subclass with keys ``xferDuration``, ``postDuration``,
        ``totalBytes`` and ``descCount``.
    """

    class AttributeDict(dict):
        """dict whose items are also accessible as attributes."""
        __slots__ = ()

        def __getattr__(self, name):
            # A missing key must surface as AttributeError, not KeyError:
            # hasattr(), getattr(obj, name, default) and other attribute
            # probing only treat AttributeError as "not present".
            try:
                return self[name]
            except KeyError:
                raise AttributeError(name) from None

        __setattr__ = dict.__setitem__  # type: ignore[assignment]

    # We can't instantiate nixlXferTelemetry because it's read only and
    # ray env does not have NIXL, so we must fake it
    return AttributeDict(
        xferDuration=xferDurationS * 1e6,  # in us
        postDuration=postDurationS * 1e6,  # in us
        totalBytes=totalBytes,
        descCount=descCount,
    )
|
||||||
|
|
||||||
|
|
||||||
class FakeNixlWrapper:
|
class FakeNixlWrapper:
|
||||||
"""Mock implementation of NixlWrapper for testing.
|
"""Mock implementation of NixlWrapper for testing.
|
||||||
|
|
||||||
@ -132,6 +152,9 @@ class FakeNixlWrapper:
|
|||||||
def transfer(self, handle: int) -> str:
|
def transfer(self, handle: int) -> str:
|
||||||
return "PROC"
|
return "PROC"
|
||||||
|
|
||||||
|
def get_xfer_telemetry(self, handle: int) -> dict:
    """Return canned telemetry for a transfer; ``handle`` is not used."""
    telemetry = get_default_xfer_telemetry()
    return telemetry
|
||||||
|
|
||||||
############################################################
|
############################################################
|
||||||
# The following are for changing the behavior during testing.
|
# The following are for changing the behavior during testing.
|
||||||
############################################################
|
############################################################
|
||||||
@ -169,6 +192,11 @@ nixl_agent = FakeNixlWrapper
|
|||||||
with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
|
with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
|
||||||
f.write(stub)
|
f.write(stub)
|
||||||
|
|
||||||
|
# Mock nixlXferTelemetry class
|
||||||
|
pkg_root2 = os.path.join(td, "nixl", "_bindings")
|
||||||
|
os.makedirs(pkg_root2, exist_ok=True)
|
||||||
|
with open(os.path.join(pkg_root2, "__init__.py"), "w") as f:
|
||||||
|
f.write("class nixlXferTelemetry: pass")
|
||||||
# touch parent package
|
# touch parent package
|
||||||
open(os.path.join(td, "nixl", "__init__.py"), "w").close()
|
open(os.path.join(td, "nixl", "__init__.py"), "w").close()
|
||||||
yield td
|
yield td
|
||||||
@ -575,7 +603,7 @@ def test_kv_connector_stats(dist_init):
|
|||||||
|
|
||||||
# Verify stats values are recorded
|
# Verify stats values are recorded
|
||||||
assert not stats_after_transfer.is_empty()
|
assert not stats_after_transfer.is_empty()
|
||||||
assert stats_after_transfer.data["num_successful_transfers"] == 1
|
assert stats_after_transfer.num_successful_transfers == 1
|
||||||
|
|
||||||
# Verify stats are reset after retrieval
|
# Verify stats are reset after retrieval
|
||||||
stats_after_reset = connector.get_kv_connector_stats()
|
stats_after_reset = connector.get_kv_connector_stats()
|
||||||
@ -599,16 +627,21 @@ def test_kv_connector_stats_aggregation():
|
|||||||
|
|
||||||
# Record different transfers on each worker
|
# Record different transfers on each worker
|
||||||
# Worker 1: 2 transfers
|
# Worker 1: 2 transfers
|
||||||
worker1_stats.record_transfer()
|
stats = get_default_xfer_telemetry()
|
||||||
worker1_stats.record_transfer()
|
worker1_stats.record_transfer(stats)
|
||||||
|
worker1_stats.record_transfer(stats)
|
||||||
|
|
||||||
# Worker 2: 1 transfer
|
# Worker 2: 1 transfer
|
||||||
worker2_stats.record_transfer()
|
worker2_stats.record_transfer(stats)
|
||||||
|
|
||||||
# Worker 3: 3 transfers
|
# Worker 3: 3 transfers
|
||||||
worker3_stats.record_transfer()
|
stats = get_default_xfer_telemetry(xferDurationS=2,
|
||||||
worker3_stats.record_transfer()
|
postDurationS=2,
|
||||||
worker3_stats.record_transfer()
|
totalBytes=2,
|
||||||
|
descCount=2)
|
||||||
|
worker3_stats.record_transfer(stats)
|
||||||
|
worker3_stats.record_transfer(stats)
|
||||||
|
worker3_stats.record_transfer(stats)
|
||||||
|
|
||||||
# Create ModelRunnerOutput instances for each worker
|
# Create ModelRunnerOutput instances for each worker
|
||||||
worker_outputs = []
|
worker_outputs = []
|
||||||
@ -636,7 +669,12 @@ def test_kv_connector_stats_aggregation():
|
|||||||
aggregated_output.kv_connector_output.kv_connector_stats
|
aggregated_output.kv_connector_output.kv_connector_stats
|
||||||
assert isinstance(kv_connector_stats, NixlKVConnectorStats)
|
assert isinstance(kv_connector_stats, NixlKVConnectorStats)
|
||||||
# Number of total transfers across all workers.
|
# Number of total transfers across all workers.
|
||||||
assert kv_connector_stats.data["num_successful_transfers"] == 6
|
assert kv_connector_stats.num_successful_transfers == 6
|
||||||
|
# Logging proc, call reduce() to get CLI-friendly stats.
|
||||||
|
cli_stats = kv_connector_stats.reduce()
|
||||||
|
assert cli_stats["Avg xfer time (ms)"] == 1500.0
|
||||||
|
assert cli_stats["Avg post time (ms)"] == 1500.0
|
||||||
|
assert cli_stats["Avg number of descriptors"] == 1.5
|
||||||
|
|
||||||
|
|
||||||
def test_multi_kv_connector_stats_aggregation():
|
def test_multi_kv_connector_stats_aggregation():
|
||||||
@ -649,6 +687,7 @@ def test_multi_kv_connector_stats_aggregation():
|
|||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
# Mock a KVConnectorStats class for testing aggregation over connectors.
|
||||||
@dataclass
|
@dataclass
|
||||||
class FooKVConnectorStats(KVConnectorStats):
|
class FooKVConnectorStats(KVConnectorStats):
|
||||||
|
|
||||||
@ -676,7 +715,7 @@ def test_multi_kv_connector_stats_aggregation():
|
|||||||
if nixl_count > 0:
|
if nixl_count > 0:
|
||||||
nixl_stats = NixlKVConnectorStats()
|
nixl_stats = NixlKVConnectorStats()
|
||||||
for _ in range(nixl_count):
|
for _ in range(nixl_count):
|
||||||
nixl_stats.record_transfer()
|
nixl_stats.record_transfer(get_default_xfer_telemetry())
|
||||||
data["NixlConnector"] = nixl_stats
|
data["NixlConnector"] = nixl_stats
|
||||||
if foo_count > 0:
|
if foo_count > 0:
|
||||||
foo_stats = FooKVConnectorStats()
|
foo_stats = FooKVConnectorStats()
|
||||||
@ -712,8 +751,10 @@ def test_multi_kv_connector_stats_aggregation():
|
|||||||
assert isinstance(kv_connector_stats, MultiKVConnectorStats)
|
assert isinstance(kv_connector_stats, MultiKVConnectorStats)
|
||||||
|
|
||||||
# Validate per-connector totals across workers
|
# Validate per-connector totals across workers
|
||||||
assert kv_connector_stats["NixlConnector"].data[
|
assert isinstance(kv_connector_stats["NixlConnector"],
|
||||||
"num_successful_transfers"] == 5
|
NixlKVConnectorStats)
|
||||||
|
assert kv_connector_stats["NixlConnector"].num_successful_transfers == 5
|
||||||
|
assert isinstance(kv_connector_stats["FooConnector"], FooKVConnectorStats)
|
||||||
assert kv_connector_stats["FooConnector"].data["num_foo_transfers"] == 6
|
assert kv_connector_stats["FooConnector"].data["num_foo_transfers"] == 6
|
||||||
|
|
||||||
|
|
||||||
@ -755,6 +796,8 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
|
|||||||
"working_dir": working_dir, # ship fake nixl package
|
"working_dir": working_dir, # ship fake nixl package
|
||||||
"env_vars": {
|
"env_vars": {
|
||||||
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout),
|
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout),
|
||||||
|
# TODO: for ray to carry over, remove once the telemetry flag can be set through the NIXL agent config
|
||||||
|
"NIXL_TELEMETRY_ENABLE": "1",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
ray.init(runtime_env=runtime_env)
|
ray.init(runtime_env=runtime_env)
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import contextlib
|
|||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
import queue
|
import queue
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
@ -54,10 +55,12 @@ logger = init_logger(__name__)
|
|||||||
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
|
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
|
||||||
try:
|
try:
|
||||||
from nixl._api import nixl_agent as NixlWrapper
|
from nixl._api import nixl_agent as NixlWrapper
|
||||||
|
from nixl._bindings import nixlXferTelemetry
|
||||||
logger.info("NIXL is available")
|
logger.info("NIXL is available")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning("NIXL is not available")
|
logger.warning("NIXL is not available")
|
||||||
NixlWrapper = None
|
NixlWrapper = None
|
||||||
|
nixlXferTelemetry = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from nixl._api import nixl_agent_config
|
from nixl._api import nixl_agent_config
|
||||||
@ -476,6 +479,9 @@ class NixlConnectorWorker:
|
|||||||
self.nixl_backends = \
|
self.nixl_backends = \
|
||||||
vllm_config.kv_transfer_config.get_from_extra_config(
|
vllm_config.kv_transfer_config.get_from_extra_config(
|
||||||
"backends", ["UCX"])
|
"backends", ["UCX"])
|
||||||
|
# TODO temporary, once nixl allows for telemetry flag in config
|
||||||
|
# (next release), we can remove this env var.
|
||||||
|
os.environ["NIXL_TELEMETRY_ENABLE"] = "1"
|
||||||
# Agent.
|
# Agent.
|
||||||
non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
|
non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
|
||||||
if nixl_agent_config is None:
|
if nixl_agent_config is None:
|
||||||
@ -1175,9 +1181,10 @@ class NixlConnectorWorker:
|
|||||||
for handle, _xfer_stime in handles:
|
for handle, _xfer_stime in handles:
|
||||||
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
|
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
|
||||||
if xfer_state == "DONE":
|
if xfer_state == "DONE":
|
||||||
|
# Get telemetry from NIXL
|
||||||
|
res = self.nixl_wrapper.get_xfer_telemetry(handle)
|
||||||
|
self.xfer_stats.record_transfer(res)
|
||||||
self.nixl_wrapper.release_xfer_handle(handle)
|
self.nixl_wrapper.release_xfer_handle(handle)
|
||||||
# TODO (NickLucche) Get from NIXL telemetry once integrated
|
|
||||||
self.xfer_stats.record_transfer()
|
|
||||||
elif xfer_state == "PROC":
|
elif xfer_state == "PROC":
|
||||||
in_progress = True
|
in_progress = True
|
||||||
continue
|
continue
|
||||||
@ -1449,15 +1456,25 @@ class NixlKVConnectorStats(KVConnectorStats):
|
|||||||
"""Container for transfer performance metrics"""
|
"""Container for transfer performance metrics"""
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if "num_successful_transfers" not in self.data:
|
if not self.data:
|
||||||
self.data["num_successful_transfers"] = 0
|
# Empty container init, no data is passed in.
|
||||||
|
self.reset()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.data = {"num_successful_transfers": 0}
|
# Must be serializable
|
||||||
|
self.data: dict[str, list[float]] = {
|
||||||
|
"transfer_duration": [],
|
||||||
|
"post_duration": [],
|
||||||
|
"bytes_transferred": [],
|
||||||
|
"num_descriptors": [],
|
||||||
|
}
|
||||||
|
|
||||||
def record_transfer(self):
|
def record_transfer(self, res: nixlXferTelemetry):
|
||||||
# TODO: record actual transfer stats when available
|
# Keep metrics units consistent with rest of the code: time us->s
|
||||||
self.data["num_successful_transfers"] += 1
|
self.data["transfer_duration"].append(res.xferDuration / 1e6)
|
||||||
|
self.data["post_duration"].append(res.postDuration / 1e6)
|
||||||
|
self.data["bytes_transferred"].append(res.totalBytes)
|
||||||
|
self.data["num_descriptors"].append(res.descCount)
|
||||||
|
|
||||||
def clone_and_reset(self) -> "NixlKVConnectorStats":
|
def clone_and_reset(self) -> "NixlKVConnectorStats":
|
||||||
old = copy.copy(self)
|
old = copy.copy(self)
|
||||||
@ -1465,16 +1482,55 @@ class NixlKVConnectorStats(KVConnectorStats):
|
|||||||
return old
|
return old
|
||||||
|
|
||||||
def is_empty(self) -> bool:
|
def is_empty(self) -> bool:
|
||||||
return self.data["num_successful_transfers"] == 0
|
return self.num_successful_transfers == 0
|
||||||
|
|
||||||
def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
|
def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
|
||||||
if not other.is_empty():
|
if not other.is_empty():
|
||||||
self.data["num_successful_transfers"] += other.data[
|
for k, v in other.data.items():
|
||||||
"num_successful_transfers"]
|
accumulator = self.data[k]
|
||||||
|
assert isinstance(accumulator, list)
|
||||||
|
accumulator.extend(v)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def reduce(self) -> dict[str, Union[int, float]]:
|
def reduce(self) -> dict[str, Union[int, float]]:
|
||||||
# TODO: reduce stats to a single value, calculate latency/throughput
|
# Compute compact representative stats suitable for CLI logging
|
||||||
|
if self.is_empty():
|
||||||
|
return {
|
||||||
|
"Num successful transfers": 0,
|
||||||
|
"Avg xfer time (ms)": 0,
|
||||||
|
"P90 xfer time (ms)": 0,
|
||||||
|
"Avg post time (ms)": 0,
|
||||||
|
"P90 post time (ms)": 0,
|
||||||
|
"Avg MB per transfer": 0,
|
||||||
|
"Throughput (MB/s)": 0,
|
||||||
|
"Avg number of descriptors": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
xfer_time = np.asarray(self.data["transfer_duration"])
|
||||||
|
post_time = np.asarray(self.data["post_duration"])
|
||||||
|
# Convert to MB for CLI logging.
|
||||||
|
mb = np.asarray(self.data["bytes_transferred"]) / 2**20
|
||||||
|
descs = np.asarray(self.data["num_descriptors"], dtype=np.uint32)
|
||||||
|
n = len(descs)
|
||||||
|
assert n == self.num_successful_transfers
|
||||||
|
|
||||||
|
total_mb = mb.sum()
|
||||||
|
avg_mb = total_mb / n
|
||||||
|
|
||||||
|
total_time_seconds = xfer_time.sum()
|
||||||
|
throughput_mb_s = total_mb / total_time_seconds
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"num_successful_transfers": self.data["num_successful_transfers"]
|
"Num successful transfers": n,
|
||||||
|
"Avg xfer time (ms)": round(xfer_time.mean() * 1e3, 3),
|
||||||
|
"P90 xfer time (ms)": round(np.percentile(xfer_time, 90) * 1e3, 3),
|
||||||
|
"Avg post time (ms)": round(post_time.mean() * 1e3, 3),
|
||||||
|
"P90 post time (ms)": round(np.percentile(post_time, 90) * 1e3, 3),
|
||||||
|
"Avg MB per transfer": round(avg_mb, 3),
|
||||||
|
"Throughput (MB/s)": round(throughput_mb_s, 3),
|
||||||
|
"Avg number of descriptors": round(descs.mean(), 1),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@property
def num_successful_transfers(self) -> int:
    """Count of recorded transfers, i.e. how many duration samples are stored."""
    durations = self.data["transfer_duration"]
    return len(durations)
|
||||||
@ -62,7 +62,7 @@ class LoggingStatLogger(StatLoggerBase):
|
|||||||
self.prefix_caching_metrics = PrefixCachingMetrics()
|
self.prefix_caching_metrics = PrefixCachingMetrics()
|
||||||
self.spec_decoding_logging = SpecDecodingLogging()
|
self.spec_decoding_logging = SpecDecodingLogging()
|
||||||
kv_tranfer_config = self.vllm_config.kv_transfer_config
|
kv_tranfer_config = self.vllm_config.kv_transfer_config
|
||||||
self.kv_transfer_logging = KVConnectorLogging(kv_tranfer_config)
|
self.kv_connector_logging = KVConnectorLogging(kv_tranfer_config)
|
||||||
self.last_prompt_throughput: float = 0.0
|
self.last_prompt_throughput: float = 0.0
|
||||||
self.last_generation_throughput: float = 0.0
|
self.last_generation_throughput: float = 0.0
|
||||||
|
|
||||||
@ -101,7 +101,7 @@ class LoggingStatLogger(StatLoggerBase):
|
|||||||
self.spec_decoding_logging.observe(
|
self.spec_decoding_logging.observe(
|
||||||
scheduler_stats.spec_decoding_stats)
|
scheduler_stats.spec_decoding_stats)
|
||||||
if kv_connector_stats := scheduler_stats.kv_connector_stats:
|
if kv_connector_stats := scheduler_stats.kv_connector_stats:
|
||||||
self.kv_transfer_logging.observe(kv_connector_stats)
|
self.kv_connector_logging.observe(kv_connector_stats)
|
||||||
self.last_scheduler_stats = scheduler_stats
|
self.last_scheduler_stats = scheduler_stats
|
||||||
|
|
||||||
def log(self):
|
def log(self):
|
||||||
@ -140,7 +140,7 @@ class LoggingStatLogger(StatLoggerBase):
|
|||||||
self.prefix_caching_metrics.hit_rate * 100,
|
self.prefix_caching_metrics.hit_rate * 100,
|
||||||
)
|
)
|
||||||
self.spec_decoding_logging.log(log_fn=log_fn)
|
self.spec_decoding_logging.log(log_fn=log_fn)
|
||||||
self.kv_transfer_logging.log(log_fn=log_fn)
|
self.kv_connector_logging.log(log_fn=log_fn)
|
||||||
|
|
||||||
def log_engine_initialized(self):
|
def log_engine_initialized(self):
|
||||||
if self.vllm_config.cache_config.num_gpu_blocks:
|
if self.vllm_config.cache_config.num_gpu_blocks:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user