[Bugfix] Fix MultiConnector stats reconstruction across process boundaries (#27366)

Signed-off-by: Kourosh Hakhamaneshi <Kourosh@anyscale.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
kourosh hakhamaneshi 2025-10-24 10:08:05 -07:00 committed by GitHub
parent 699d62e6cf
commit 7e1d697b56
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 421 additions and 5 deletions

View File

@ -4,9 +4,22 @@ import filecmp
import shutil import shutil
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from typing import Any
import pytest
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
MultiConnector,
MultiKVConnectorStats,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlKVConnectorStats,
)
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@ -19,6 +32,27 @@ PROMPTS = [
SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20) SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20)
# Test connector with custom stats for testing MultiConnector
class MockConnectorStats(KVConnectorStats):
"""Mock stats class for testing."""
pass
class MockConnector(KVConnectorBase_V1):
"""Mock connector that implements build_kv_connector_stats for testing."""
@classmethod
def build_kv_connector_stats(
cls, data: dict[str, Any] | None = None
) -> KVConnectorStats | None:
return MockConnectorStats(data=data) if data is not None else None
# Register the mock connector
KVConnectorFactory.register_connector("MockConnector", __name__, MockConnector.__name__)
# Helper function to compare directories recursively # Helper function to compare directories recursively
def _compare_directories(dir1: Path, dir2: Path) -> bool: def _compare_directories(dir1: Path, dir2: Path) -> bool:
"""Compares two directories recursively for identical content.""" """Compares two directories recursively for identical content."""
@ -225,3 +259,337 @@ def test_engine_id_conflict():
assert ids[0] != ids[1], ( assert ids[0] != ids[1], (
f"Engine IDs should be different for different configs. Got {ids}" f"Engine IDs should be different for different configs. Got {ids}"
) )
class TestMultiConnectorStats:
"""Tests for MultiConnector stats reconstruction and operations."""
def test_build_kv_connector_stats_with_none(self):
"""Test that build_kv_connector_stats returns empty stats when given None."""
stats = MultiConnector.build_kv_connector_stats(data=None)
assert stats is not None
assert isinstance(stats, MultiKVConnectorStats)
assert len(stats.data) == 0
assert stats.is_empty()
def test_build_kv_connector_stats_with_empty_dict(self):
"""Test that build_kv_connector_stats returns empty stats with empty dict."""
stats = MultiConnector.build_kv_connector_stats(data={})
assert stats is not None
assert isinstance(stats, MultiKVConnectorStats)
assert len(stats.data) == 0
assert stats.is_empty()
def test_build_kv_connector_stats_reconstructs_nixl_stats(self):
"""Test that NixlConnector stats are properly reconstructed with
correct data."""
serialized_data = {
"NixlConnector": {
"data": {
"transfer_duration": [1.5, 2.3],
"post_duration": [0.1, 0.2],
"bytes_transferred": [1024, 2048],
"num_descriptors": [10, 20],
"num_failed_transfers": [],
"num_failed_notifications": [],
}
}
}
stats = MultiConnector.build_kv_connector_stats(data=serialized_data)
assert "NixlConnector" in stats.data
nixl_stats = stats.data["NixlConnector"]
assert isinstance(nixl_stats, NixlKVConnectorStats)
assert nixl_stats.data["transfer_duration"] == [1.5, 2.3]
assert nixl_stats.data["post_duration"] == [0.1, 0.2]
assert nixl_stats.data["bytes_transferred"] == [1024, 2048]
assert nixl_stats.data["num_descriptors"] == [10, 20]
def test_build_kv_connector_stats_with_multiple_connectors(self):
"""Test reconstruction with multiple connector types that have custom stats."""
serialized_data = {
"NixlConnector": {
"data": {
"transfer_duration": [1.5],
"post_duration": [0.1],
"bytes_transferred": [1024],
"num_descriptors": [10],
"num_failed_transfers": [],
"num_failed_notifications": [],
}
},
"MockConnector": {"data": {"mock_field": [1, 2, 3]}},
}
stats = MultiConnector.build_kv_connector_stats(data=serialized_data)
assert stats is not None
assert isinstance(stats, MultiKVConnectorStats)
# Both connectors should be reconstructed
assert len(stats.data) == 2
assert "NixlConnector" in stats.data
assert "MockConnector" in stats.data
assert isinstance(stats.data["NixlConnector"], NixlKVConnectorStats)
assert isinstance(stats.data["MockConnector"], MockConnectorStats)
# Verify data is preserved
assert stats.data["MockConnector"].data == {"mock_field": [1, 2, 3]}
def test_build_kv_connector_stats_raises_error_for_unknown_connector(self):
"""Test that unknown connectors raise an error."""
serialized_data = {
"UnknownConnector": {"data": {"some_field": [1, 2, 3]}},
"NixlConnector": {
"data": {
"transfer_duration": [1.5],
"post_duration": [0.1],
"bytes_transferred": [1024],
"num_descriptors": [10],
"num_failed_transfers": [],
"num_failed_notifications": [],
}
},
}
with pytest.raises(
ValueError, match="Connector 'UnknownConnector' is not registered."
):
MultiConnector.build_kv_connector_stats(data=serialized_data)
def test_build_kv_connector_stats_with_already_instantiated_objects(self):
"""Test that already-instantiated stats objects are preserved (same process)."""
# This simulates the in-process case where stats are not serialized
nixl_stats = NixlKVConnectorStats(
data={
"transfer_duration": [1.5],
"post_duration": [0.1],
"bytes_transferred": [1024],
"num_descriptors": [10],
"num_failed_transfers": [],
"num_failed_notifications": [],
}
)
mock_stats = MockConnectorStats(data={"mock_field": [1, 2, 3]})
data_with_objects = {
"NixlConnector": nixl_stats,
"MockConnector": mock_stats,
}
stats = MultiConnector.build_kv_connector_stats(data=data_with_objects)
assert stats is not None
assert isinstance(stats, MultiKVConnectorStats)
assert len(stats.data) == 2
# Verify objects are preserved as-is
assert stats.data["NixlConnector"] is nixl_stats
assert stats.data["MockConnector"] is mock_stats
def test_build_kv_connector_stats_with_mixed_objects_and_dicts(self):
"""Test handling mixed already-instantiated and serialized stats."""
# This can happen during transition or partial serialization
nixl_stats = NixlKVConnectorStats(
data={
"transfer_duration": [1.5],
"post_duration": [0.1],
"bytes_transferred": [1024],
"num_descriptors": [10],
"num_failed_transfers": [],
"num_failed_notifications": [],
}
)
mixed_data = {
"NixlConnector": nixl_stats, # Already instantiated
"MockConnector": {"data": {"mock_field": [1, 2, 3]}}, # Serialized
}
stats = MultiConnector.build_kv_connector_stats(data=mixed_data)
assert stats is not None
assert isinstance(stats, MultiKVConnectorStats)
assert len(stats.data) == 2
# Instantiated object preserved
assert stats.data["NixlConnector"] is nixl_stats
# Serialized object reconstructed
assert isinstance(stats.data["MockConnector"], MockConnectorStats)
assert stats.data["MockConnector"].data == {"mock_field": [1, 2, 3]}
def test_build_kv_connector_stats_skips_connectors_without_custom_stats(self):
"""Test that connectors without custom stats (return None) are skipped."""
# SharedStorageConnector doesn't override build_kv_connector_stats,
# so it returns None and should be skipped
serialized_data = {
"NixlConnector": {
"data": {
"transfer_duration": [1.5],
"post_duration": [0.1],
"bytes_transferred": [1024],
"num_descriptors": [10],
"num_failed_transfers": [],
"num_failed_notifications": [],
}
},
"SharedStorageConnector": {"data": {"some_field": [1, 2, 3]}},
}
stats = MultiConnector.build_kv_connector_stats(data=serialized_data)
assert stats is not None
assert isinstance(stats, MultiKVConnectorStats)
# Only NixlConnector should be reconstructed
assert len(stats.data) == 1
assert "NixlConnector" in stats.data
assert isinstance(stats.data["NixlConnector"], NixlKVConnectorStats)
# SharedStorageConnector should be skipped (returns None)
assert "SharedStorageConnector" not in stats.data
def test_build_kv_connector_stats_handles_malformed_data(self):
"""Test that malformed data raises appropriate errors."""
serialized_data = {
"NixlConnector": {"wrong_field": {"transfer_duration": [1.5]}}
}
with pytest.raises(AssertionError, match="Expected a dict with a 'data' field"):
MultiConnector.build_kv_connector_stats(data=serialized_data)
def test_aggregate_same_connector(self):
"""Test aggregating stats from the same connector type."""
stats1 = MultiKVConnectorStats(
data={
"NixlConnector": NixlKVConnectorStats(
data={
"transfer_duration": [1.0],
"post_duration": [0.1],
"bytes_transferred": [1024],
"num_descriptors": [10],
"num_failed_transfers": [],
"num_failed_notifications": [],
}
)
}
)
stats2 = MultiKVConnectorStats(
data={
"NixlConnector": NixlKVConnectorStats(
data={
"transfer_duration": [2.0],
"post_duration": [0.2],
"bytes_transferred": [2048],
"num_descriptors": [20],
"num_failed_transfers": [],
"num_failed_notifications": [],
}
)
}
)
result = stats1.aggregate(stats2)
assert result is stats1 # Should return self
assert "NixlConnector" in result.data
nixl_stats = result.data["NixlConnector"]
assert nixl_stats.data["transfer_duration"] == [1.0, 2.0]
assert nixl_stats.data["post_duration"] == [0.1, 0.2]
assert nixl_stats.data["bytes_transferred"] == [1024, 2048]
assert nixl_stats.data["num_descriptors"] == [10, 20]
def test_aggregate_new_connector(self):
"""Test aggregating stats when a new connector type appears."""
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorStats,
)
stats1 = MultiKVConnectorStats(
data={
"NixlConnector": NixlKVConnectorStats(
data={
"transfer_duration": [1.0],
"post_duration": [0.1],
"bytes_transferred": [1024],
"num_descriptors": [10],
"num_failed_transfers": [],
"num_failed_notifications": [],
}
)
}
)
stats2 = MultiKVConnectorStats(
data={"SharedStorageConnector": KVConnectorStats(data={"field": [1, 2]})}
)
result = stats1.aggregate(stats2)
assert "NixlConnector" in result.data
assert "SharedStorageConnector" in result.data
def test_reduce(self):
"""Test that reduce() correctly reduces all nested connector stats."""
stats = MultiKVConnectorStats(
data={
"NixlConnector": NixlKVConnectorStats(
data={
"transfer_duration": [1.0, 2.0],
"post_duration": [0.1, 0.2],
"bytes_transferred": [1024, 2048],
"num_descriptors": [10, 20],
"num_failed_transfers": [],
"num_failed_notifications": [],
}
)
}
)
reduced = stats.reduce()
assert "NixlConnector" in reduced
assert isinstance(reduced["NixlConnector"], dict)
# Check that the stats were reduced (should have aggregated values)
assert "Num successful transfers" in reduced["NixlConnector"]
assert reduced["NixlConnector"]["Num successful transfers"] == 2
def test_reset(self):
"""Test that reset() resets all nested connector stats."""
stats = MultiKVConnectorStats(
data={
"NixlConnector": NixlKVConnectorStats(
data={
"transfer_duration": [1.0, 2.0],
"post_duration": [0.1, 0.2],
"bytes_transferred": [1024, 2048],
"num_descriptors": [10, 20],
"num_failed_transfers": [],
"num_failed_notifications": [],
}
)
}
)
assert not stats.is_empty()
stats.reset()
# After reset, stats should be empty
assert stats.is_empty()
nixl_stats = stats.data["NixlConnector"]
assert len(nixl_stats.data["transfer_duration"]) == 0
def test_is_empty_with_multiple_connectors(self):
"""Test is_empty() returns correct value with multiple connectors."""
# All empty
stats = MultiKVConnectorStats(
data={
"NixlConnector": NixlKVConnectorStats(data={}),
}
)
# Initialize empty stats
stats.data["NixlConnector"].reset()
assert stats.is_empty()
# One non-empty
stats.data["NixlConnector"].data["transfer_duration"].append(1.0)
assert not stats.is_empty()

View File

@ -66,6 +66,24 @@ class KVConnectorFactory:
# We build separately to enforce strict separation # We build separately to enforce strict separation
return connector_cls(config, role) return connector_cls(config, role)
@classmethod
def get_connector_class_by_name(
cls, connector_name: str
) -> type[KVConnectorBaseType]:
"""Get a registered connector class by name.
Raises ValueError if the connector is not registered.
Args:
connector_name: Name of the registered connector.
Returns:
The connector class.
"""
if connector_name not in cls._registry:
raise ValueError(f"Connector '{connector_name}' is not registered.")
return cls._registry[connector_name]()
@classmethod @classmethod
def get_connector_class( def get_connector_class(
cls, kv_transfer_config: "KVTransferConfig" cls, kv_transfer_config: "KVTransferConfig"

View File

@ -324,11 +324,41 @@ class MultiConnector(KVConnectorBase_V1):
def build_kv_connector_stats( def build_kv_connector_stats(
cls, data: dict[str, Any] | None = None cls, data: dict[str, Any] | None = None
) -> KVConnectorStats | None: ) -> KVConnectorStats | None:
return ( if data is None:
MultiKVConnectorStats(data=data) return MultiKVConnectorStats()
if data is not None
else MultiKVConnectorStats() # data is a dict mapping connector name to their stats data.
) # The stats data can be either:
# 1. Already-instantiated KVConnectorStats objects (same process)
# 2. Serialized dicts (cross-process after serialization)
# We need to reconstruct proper KVConnectorStats objects from dicts
reconstructed_data = {}
for connector_name, stats_value in data.items():
# If already a KVConnectorStats object, use it directly
if isinstance(stats_value, KVConnectorStats):
reconstructed_data[connector_name] = stats_value
continue
# Otherwise, reconstruct from serialized dict
# Get the connector class to reconstruct its stats
connector_cls = KVConnectorFactory.get_connector_class_by_name(
connector_name
)
# stats_value is the serialized dataclass which contains {'data': {...}}
# We need to extract the inner 'data' field to avoid double-nesting
assert isinstance(stats_value, dict) and "data" in stats_value, (
f"Expected a dict with a 'data' field, got {stats_value}"
)
inner_data = stats_value["data"]
# Use the connector's build_kv_connector_stats to reconstruct
if reconstructed_stats := connector_cls.build_kv_connector_stats(
data=inner_data
):
reconstructed_data[connector_name] = reconstructed_stats
return MultiKVConnectorStats(data=reconstructed_data)
def get_kv_connector_stats(self) -> MultiKVConnectorStats | None: def get_kv_connector_stats(self) -> MultiKVConnectorStats | None:
# Group connector stats by connector type. # Group connector stats by connector type.