mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-20 02:55:35 +08:00
[Bugfix] Fix MultiConnector stats reconstruction across process boundaries (#27366)
Signed-off-by: Kourosh Hakhamaneshi <Kourosh@anyscale.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
parent
699d62e6cf
commit
7e1d697b56
@ -4,9 +4,22 @@ import filecmp
|
|||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.config import KVTransferConfig
|
from vllm.config import KVTransferConfig
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
|
||||||
|
MultiConnector,
|
||||||
|
MultiKVConnectorStats,
|
||||||
|
)
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
|
||||||
|
NixlKVConnectorStats,
|
||||||
|
)
|
||||||
|
|
||||||
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
||||||
|
|
||||||
@ -19,6 +32,27 @@ PROMPTS = [
|
|||||||
SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20)
|
SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20)
|
||||||
|
|
||||||
|
|
||||||
|
# Test connector with custom stats for testing MultiConnector
|
||||||
|
class MockConnectorStats(KVConnectorStats):
    """Stand-in stats type used by the MultiConnector stats tests.

    Adds nothing on top of ``KVConnectorStats``; it exists so tests can check
    that reconstruction yields this exact subclass.
    """
|
||||||
|
|
||||||
|
|
||||||
|
class MockConnector(KVConnectorBase_V1):
    """Test-only connector that provides ``build_kv_connector_stats``."""

    @classmethod
    def build_kv_connector_stats(
        cls, data: dict[str, Any] | None = None
    ) -> KVConnectorStats | None:
        """Wrap ``data`` in ``MockConnectorStats``; return ``None`` without data."""
        if data is None:
            return None
        return MockConnectorStats(data=data)
|
||||||
|
|
||||||
|
|
||||||
|
# Register the mock connector
|
||||||
|
# Registered under the same "MockConnector" key the stats tests pass in.
KVConnectorFactory.register_connector(
    "MockConnector",
    __name__,
    MockConnector.__name__,
)
|
||||||
|
|
||||||
|
|
||||||
# Helper function to compare directories recursively
|
# Helper function to compare directories recursively
|
||||||
def _compare_directories(dir1: Path, dir2: Path) -> bool:
|
def _compare_directories(dir1: Path, dir2: Path) -> bool:
|
||||||
"""Compares two directories recursively for identical content."""
|
"""Compares two directories recursively for identical content."""
|
||||||
@ -225,3 +259,337 @@ def test_engine_id_conflict():
|
|||||||
assert ids[0] != ids[1], (
|
assert ids[0] != ids[1], (
|
||||||
f"Engine IDs should be different for different configs. Got {ids}"
|
f"Engine IDs should be different for different configs. Got {ids}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMultiConnectorStats:
    """Tests for MultiConnector stats reconstruction and operations."""

    @staticmethod
    def _nixl_data(
        transfer_duration: tuple[float, ...] = (1.5,),
        post_duration: tuple[float, ...] = (0.1,),
        bytes_transferred: tuple[int, ...] = (1024,),
        num_descriptors: tuple[int, ...] = (10,),
    ) -> dict[str, list]:
        """Build the six-key stats payload used for NixlConnector in these tests.

        Fresh lists are created on every call so that tests which mutate
        stats in place (aggregate, reset, append) never share state.

        Args:
            transfer_duration: Values for the "transfer_duration" series.
            post_duration: Values for the "post_duration" series.
            bytes_transferred: Values for the "bytes_transferred" series.
            num_descriptors: Values for the "num_descriptors" series.

        Returns:
            A dict with the four series above as lists plus empty
            "num_failed_transfers" and "num_failed_notifications" lists.
        """
        return {
            "transfer_duration": list(transfer_duration),
            "post_duration": list(post_duration),
            "bytes_transferred": list(bytes_transferred),
            "num_descriptors": list(num_descriptors),
            "num_failed_transfers": [],
            "num_failed_notifications": [],
        }

    def test_build_kv_connector_stats_with_none(self):
        """Test that build_kv_connector_stats returns empty stats when given None."""
        stats = MultiConnector.build_kv_connector_stats(data=None)

        assert stats is not None
        assert isinstance(stats, MultiKVConnectorStats)
        assert len(stats.data) == 0
        assert stats.is_empty()

    def test_build_kv_connector_stats_with_empty_dict(self):
        """Test that build_kv_connector_stats returns empty stats with empty dict."""
        stats = MultiConnector.build_kv_connector_stats(data={})

        assert stats is not None
        assert isinstance(stats, MultiKVConnectorStats)
        assert len(stats.data) == 0
        assert stats.is_empty()

    def test_build_kv_connector_stats_reconstructs_nixl_stats(self):
        """Test that NixlConnector stats are properly reconstructed with
        correct data."""
        serialized_data = {
            "NixlConnector": {
                "data": self._nixl_data(
                    transfer_duration=(1.5, 2.3),
                    post_duration=(0.1, 0.2),
                    bytes_transferred=(1024, 2048),
                    num_descriptors=(10, 20),
                )
            }
        }

        stats = MultiConnector.build_kv_connector_stats(data=serialized_data)

        assert isinstance(stats, MultiKVConnectorStats)
        assert "NixlConnector" in stats.data
        nixl_stats = stats.data["NixlConnector"]
        assert isinstance(nixl_stats, NixlKVConnectorStats)
        assert nixl_stats.data["transfer_duration"] == [1.5, 2.3]
        assert nixl_stats.data["post_duration"] == [0.1, 0.2]
        assert nixl_stats.data["bytes_transferred"] == [1024, 2048]
        assert nixl_stats.data["num_descriptors"] == [10, 20]

    def test_build_kv_connector_stats_with_multiple_connectors(self):
        """Test reconstruction with multiple connector types that have custom stats."""
        serialized_data = {
            "NixlConnector": {"data": self._nixl_data()},
            "MockConnector": {"data": {"mock_field": [1, 2, 3]}},
        }

        stats = MultiConnector.build_kv_connector_stats(data=serialized_data)

        assert stats is not None
        assert isinstance(stats, MultiKVConnectorStats)
        # Both connectors should be reconstructed
        assert len(stats.data) == 2
        assert "NixlConnector" in stats.data
        assert "MockConnector" in stats.data
        assert isinstance(stats.data["NixlConnector"], NixlKVConnectorStats)
        assert isinstance(stats.data["MockConnector"], MockConnectorStats)
        # Verify data is preserved
        assert stats.data["MockConnector"].data == {"mock_field": [1, 2, 3]}

    def test_build_kv_connector_stats_raises_error_for_unknown_connector(self):
        """Test that unknown connectors raise an error."""
        # The unknown entry comes first so the lookup failure is hit before
        # any valid entry is processed.
        serialized_data = {
            "UnknownConnector": {"data": {"some_field": [1, 2, 3]}},
            "NixlConnector": {"data": self._nixl_data()},
        }

        with pytest.raises(
            ValueError, match="Connector 'UnknownConnector' is not registered."
        ):
            MultiConnector.build_kv_connector_stats(data=serialized_data)

    def test_build_kv_connector_stats_with_already_instantiated_objects(self):
        """Test that already-instantiated stats objects are preserved (same process)."""
        # This simulates the in-process case where stats are not serialized
        nixl_stats = NixlKVConnectorStats(data=self._nixl_data())
        mock_stats = MockConnectorStats(data={"mock_field": [1, 2, 3]})

        data_with_objects = {
            "NixlConnector": nixl_stats,
            "MockConnector": mock_stats,
        }

        stats = MultiConnector.build_kv_connector_stats(data=data_with_objects)

        assert stats is not None
        assert isinstance(stats, MultiKVConnectorStats)
        assert len(stats.data) == 2
        # Verify objects are preserved as-is
        assert stats.data["NixlConnector"] is nixl_stats
        assert stats.data["MockConnector"] is mock_stats

    def test_build_kv_connector_stats_with_mixed_objects_and_dicts(self):
        """Test handling mixed already-instantiated and serialized stats."""
        # This can happen during transition or partial serialization
        nixl_stats = NixlKVConnectorStats(data=self._nixl_data())

        mixed_data = {
            "NixlConnector": nixl_stats,  # Already instantiated
            "MockConnector": {"data": {"mock_field": [1, 2, 3]}},  # Serialized
        }

        stats = MultiConnector.build_kv_connector_stats(data=mixed_data)

        assert stats is not None
        assert isinstance(stats, MultiKVConnectorStats)
        assert len(stats.data) == 2
        # Instantiated object preserved
        assert stats.data["NixlConnector"] is nixl_stats
        # Serialized object reconstructed
        assert isinstance(stats.data["MockConnector"], MockConnectorStats)
        assert stats.data["MockConnector"].data == {"mock_field": [1, 2, 3]}

    def test_build_kv_connector_stats_skips_connectors_without_custom_stats(self):
        """Test that connectors without custom stats (return None) are skipped."""
        # SharedStorageConnector doesn't override build_kv_connector_stats,
        # so it returns None and should be skipped
        serialized_data = {
            "NixlConnector": {"data": self._nixl_data()},
            "SharedStorageConnector": {"data": {"some_field": [1, 2, 3]}},
        }

        stats = MultiConnector.build_kv_connector_stats(data=serialized_data)

        assert stats is not None
        assert isinstance(stats, MultiKVConnectorStats)
        # Only NixlConnector should be reconstructed
        assert len(stats.data) == 1
        assert "NixlConnector" in stats.data
        assert isinstance(stats.data["NixlConnector"], NixlKVConnectorStats)
        # SharedStorageConnector should be skipped (returns None)
        assert "SharedStorageConnector" not in stats.data

    def test_build_kv_connector_stats_handles_malformed_data(self):
        """Test that malformed data raises appropriate errors."""
        # The payload sits under "wrong_field" instead of the expected "data".
        serialized_data = {
            "NixlConnector": {"wrong_field": {"transfer_duration": [1.5]}}
        }

        with pytest.raises(AssertionError, match="Expected a dict with a 'data' field"):
            MultiConnector.build_kv_connector_stats(data=serialized_data)

    def test_aggregate_same_connector(self):
        """Test aggregating stats from the same connector type."""
        stats1 = MultiKVConnectorStats(
            data={
                "NixlConnector": NixlKVConnectorStats(
                    data=self._nixl_data(
                        transfer_duration=(1.0,),
                        post_duration=(0.1,),
                        bytes_transferred=(1024,),
                        num_descriptors=(10,),
                    )
                )
            }
        )

        stats2 = MultiKVConnectorStats(
            data={
                "NixlConnector": NixlKVConnectorStats(
                    data=self._nixl_data(
                        transfer_duration=(2.0,),
                        post_duration=(0.2,),
                        bytes_transferred=(2048,),
                        num_descriptors=(20,),
                    )
                )
            }
        )

        result = stats1.aggregate(stats2)

        assert result is stats1  # Should return self
        assert "NixlConnector" in result.data
        nixl_stats = result.data["NixlConnector"]
        assert nixl_stats.data["transfer_duration"] == [1.0, 2.0]
        assert nixl_stats.data["post_duration"] == [0.1, 0.2]
        assert nixl_stats.data["bytes_transferred"] == [1024, 2048]
        assert nixl_stats.data["num_descriptors"] == [10, 20]

    def test_aggregate_new_connector(self):
        """Test aggregating stats when a new connector type appears."""
        stats1 = MultiKVConnectorStats(
            data={
                "NixlConnector": NixlKVConnectorStats(
                    data=self._nixl_data(
                        transfer_duration=(1.0,),
                        post_duration=(0.1,),
                        bytes_transferred=(1024,),
                        num_descriptors=(10,),
                    )
                )
            }
        )

        stats2 = MultiKVConnectorStats(
            data={"SharedStorageConnector": KVConnectorStats(data={"field": [1, 2]})}
        )

        result = stats1.aggregate(stats2)

        assert "NixlConnector" in result.data
        assert "SharedStorageConnector" in result.data

    def test_reduce(self):
        """Test that reduce() correctly reduces all nested connector stats."""
        stats = MultiKVConnectorStats(
            data={
                "NixlConnector": NixlKVConnectorStats(
                    data=self._nixl_data(
                        transfer_duration=(1.0, 2.0),
                        post_duration=(0.1, 0.2),
                        bytes_transferred=(1024, 2048),
                        num_descriptors=(10, 20),
                    )
                )
            }
        )

        reduced = stats.reduce()

        assert "NixlConnector" in reduced
        assert isinstance(reduced["NixlConnector"], dict)
        # Check that the stats were reduced (should have aggregated values)
        assert "Num successful transfers" in reduced["NixlConnector"]
        assert reduced["NixlConnector"]["Num successful transfers"] == 2

    def test_reset(self):
        """Test that reset() resets all nested connector stats."""
        stats = MultiKVConnectorStats(
            data={
                "NixlConnector": NixlKVConnectorStats(
                    data=self._nixl_data(
                        transfer_duration=(1.0, 2.0),
                        post_duration=(0.1, 0.2),
                        bytes_transferred=(1024, 2048),
                        num_descriptors=(10, 20),
                    )
                )
            }
        )

        assert not stats.is_empty()

        stats.reset()

        # After reset, stats should be empty
        assert stats.is_empty()
        nixl_stats = stats.data["NixlConnector"]
        assert len(nixl_stats.data["transfer_duration"]) == 0

    def test_is_empty_with_multiple_connectors(self):
        """Test is_empty() returns correct value with multiple connectors."""
        # All empty
        stats = MultiKVConnectorStats(
            data={
                "NixlConnector": NixlKVConnectorStats(data={}),
            }
        )
        # Initialize empty stats
        stats.data["NixlConnector"].reset()
        assert stats.is_empty()

        # One non-empty
        stats.data["NixlConnector"].data["transfer_duration"].append(1.0)
        assert not stats.is_empty()
||||||
|
|||||||
@ -66,6 +66,24 @@ class KVConnectorFactory:
|
|||||||
# We build separately to enforce strict separation
|
# We build separately to enforce strict separation
|
||||||
return connector_cls(config, role)
|
return connector_cls(config, role)
|
||||||
|
|
||||||
|
@classmethod
def get_connector_class_by_name(
    cls, connector_name: str
) -> type[KVConnectorBaseType]:
    """Resolve a registered connector name to its connector class.

    Args:
        connector_name: Key under which the connector was registered.

    Returns:
        The result of calling the registry entry stored for that name.

    Raises:
        ValueError: If the registry has no entry for ``connector_name``.
    """
    registry = cls._registry
    if connector_name in registry:
        # Registry values are zero-argument callables; invoking one yields
        # the connector class.
        load_connector_cls = registry[connector_name]
        return load_connector_cls()
    raise ValueError(f"Connector '{connector_name}' is not registered.")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_connector_class(
|
def get_connector_class(
|
||||||
cls, kv_transfer_config: "KVTransferConfig"
|
cls, kv_transfer_config: "KVTransferConfig"
|
||||||
|
|||||||
@ -324,11 +324,41 @@ class MultiConnector(KVConnectorBase_V1):
|
|||||||
def build_kv_connector_stats(
    cls, data: dict[str, Any] | None = None
) -> KVConnectorStats | None:
    """Rebuild a ``MultiKVConnectorStats`` from per-connector stats.

    Args:
        data: Mapping of connector name to that connector's stats. Each
            value is either a ``KVConnectorStats`` instance (kept as-is)
            or a dict carrying the payload under a ``"data"`` key, as
            produced when the stats dataclass is serialized.

    Returns:
        An empty ``MultiKVConnectorStats`` when ``data`` is ``None``;
        otherwise one holding an entry for every connector whose stats
        were kept or rebuilt. Connectors whose own
        ``build_kv_connector_stats`` returns a falsy value are left out.

    Raises:
        ValueError: Propagated from the factory lookup when a name that
            needs reconstruction is not registered.
        AssertionError: If a value that is not a ``KVConnectorStats`` is
            also not a dict containing a ``"data"`` key.
    """
    if data is None:
        return MultiKVConnectorStats()

    rebuilt: dict[str, KVConnectorStats] = {}
    for name, value in data.items():
        if isinstance(value, KVConnectorStats):
            # Same-process case: the object was never serialized.
            rebuilt[name] = value
        else:
            # Cross-process case: resolve the connector class first so an
            # unregistered name fails before the payload shape is checked.
            connector_cls = KVConnectorFactory.get_connector_class_by_name(name)

            # Unwrap the inner "data" payload so the rebuilt stats object
            # is not double-nested.
            assert isinstance(value, dict) and "data" in value, (
                f"Expected a dict with a 'data' field, got {value}"
            )
            stats_obj = connector_cls.build_kv_connector_stats(data=value["data"])
            if stats_obj:
                rebuilt[name] = stats_obj

    return MultiKVConnectorStats(data=rebuilt)
|
||||||
|
|
||||||
def get_kv_connector_stats(self) -> MultiKVConnectorStats | None:
|
def get_kv_connector_stats(self) -> MultiKVConnectorStats | None:
|
||||||
# Group connector stats by connector type.
|
# Group connector stats by connector type.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user