diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py index 74ae3ca9a8633..1c1ac915c758e 100644 --- a/tests/v1/kv_connector/unit/test_multi_connector.py +++ b/tests/v1/kv_connector/unit/test_multi_connector.py @@ -4,9 +4,22 @@ import filecmp import shutil import tempfile from pathlib import Path +from typing import Any + +import pytest from vllm import LLM, SamplingParams from vllm.config import KVTransferConfig +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1 +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats +from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( + MultiConnector, + MultiKVConnectorStats, +) +from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + NixlKVConnectorStats, +) MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" @@ -19,6 +32,27 @@ PROMPTS = [ SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20) +# Test connector with custom stats for testing MultiConnector +class MockConnectorStats(KVConnectorStats): + """Mock stats class for testing.""" + + pass + + +class MockConnector(KVConnectorBase_V1): + """Mock connector that implements build_kv_connector_stats for testing.""" + + @classmethod + def build_kv_connector_stats( + cls, data: dict[str, Any] | None = None + ) -> KVConnectorStats | None: + return MockConnectorStats(data=data) if data is not None else None + + +# Register the mock connector +KVConnectorFactory.register_connector("MockConnector", __name__, MockConnector.__name__) + + # Helper function to compare directories recursively def _compare_directories(dir1: Path, dir2: Path) -> bool: """Compares two directories recursively for identical content.""" @@ -225,3 +259,337 @@ def test_engine_id_conflict(): assert ids[0] != ids[1], ( f"Engine IDs should be different for different configs. Got {ids}" ) + + +class TestMultiConnectorStats: + """Tests for MultiConnector stats reconstruction and operations.""" + + def test_build_kv_connector_stats_with_none(self): + """Test that build_kv_connector_stats returns empty stats when given None.""" + stats = MultiConnector.build_kv_connector_stats(data=None) + + assert stats is not None + assert isinstance(stats, MultiKVConnectorStats) + assert len(stats.data) == 0 + assert stats.is_empty() + + def test_build_kv_connector_stats_with_empty_dict(self): + """Test that build_kv_connector_stats returns empty stats with empty dict.""" + stats = MultiConnector.build_kv_connector_stats(data={}) + + assert stats is not None + assert isinstance(stats, MultiKVConnectorStats) + assert len(stats.data) == 0 + assert stats.is_empty() + + def test_build_kv_connector_stats_reconstructs_nixl_stats(self): + """Test that NixlConnector stats are properly reconstructed with + correct data.""" + serialized_data = { + "NixlConnector": { + "data": { + "transfer_duration": [1.5, 2.3], + "post_duration": [0.1, 0.2], + "bytes_transferred": [1024, 2048], + "num_descriptors": [10, 20], + "num_failed_transfers": [], + "num_failed_notifications": [], + } + } + } + + stats = MultiConnector.build_kv_connector_stats(data=serialized_data) + + assert "NixlConnector" in stats.data + nixl_stats = stats.data["NixlConnector"] + assert isinstance(nixl_stats, NixlKVConnectorStats) + assert nixl_stats.data["transfer_duration"] == [1.5, 2.3] + assert nixl_stats.data["post_duration"] == [0.1, 0.2] + assert nixl_stats.data["bytes_transferred"] == [1024, 2048] + assert nixl_stats.data["num_descriptors"] == [10, 20] + + def test_build_kv_connector_stats_with_multiple_connectors(self): + """Test reconstruction with multiple connector types that have custom stats.""" + serialized_data = { + "NixlConnector": { + "data": { + "transfer_duration": [1.5], + "post_duration": [0.1], + "bytes_transferred": [1024], + "num_descriptors": [10], + "num_failed_transfers": [], + "num_failed_notifications": [], + } + }, + "MockConnector": {"data": {"mock_field": [1, 2, 3]}}, + } + + stats = MultiConnector.build_kv_connector_stats(data=serialized_data) + + assert stats is not None + assert isinstance(stats, MultiKVConnectorStats) + # Both connectors should be reconstructed + assert len(stats.data) == 2 + assert "NixlConnector" in stats.data + assert "MockConnector" in stats.data + assert isinstance(stats.data["NixlConnector"], NixlKVConnectorStats) + assert isinstance(stats.data["MockConnector"], MockConnectorStats) + # Verify data is preserved + assert stats.data["MockConnector"].data == {"mock_field": [1, 2, 3]} + + def test_build_kv_connector_stats_raises_error_for_unknown_connector(self): + """Test that unknown connectors raise an error.""" + serialized_data = { + "UnknownConnector": {"data": {"some_field": [1, 2, 3]}}, + "NixlConnector": { + "data": { + "transfer_duration": [1.5], + "post_duration": [0.1], + "bytes_transferred": [1024], + "num_descriptors": [10], + "num_failed_transfers": [], + "num_failed_notifications": [], + } + }, + } + + with pytest.raises( + ValueError, match="Connector 'UnknownConnector' is not registered." + ): + MultiConnector.build_kv_connector_stats(data=serialized_data) + + def test_build_kv_connector_stats_with_already_instantiated_objects(self): + """Test that already-instantiated stats objects are preserved (same process).""" + # This simulates the in-process case where stats are not serialized + nixl_stats = NixlKVConnectorStats( + data={ + "transfer_duration": [1.5], + "post_duration": [0.1], + "bytes_transferred": [1024], + "num_descriptors": [10], + "num_failed_transfers": [], + "num_failed_notifications": [], + } + ) + mock_stats = MockConnectorStats(data={"mock_field": [1, 2, 3]}) + + data_with_objects = { + "NixlConnector": nixl_stats, + "MockConnector": mock_stats, + } + + stats = MultiConnector.build_kv_connector_stats(data=data_with_objects) + + assert stats is not None + assert isinstance(stats, MultiKVConnectorStats) + assert len(stats.data) == 2 + # Verify objects are preserved as-is + assert stats.data["NixlConnector"] is nixl_stats + assert stats.data["MockConnector"] is mock_stats + + def test_build_kv_connector_stats_with_mixed_objects_and_dicts(self): + """Test handling mixed already-instantiated and serialized stats.""" + # This can happen during transition or partial serialization + nixl_stats = NixlKVConnectorStats( + data={ + "transfer_duration": [1.5], + "post_duration": [0.1], + "bytes_transferred": [1024], + "num_descriptors": [10], + "num_failed_transfers": [], + "num_failed_notifications": [], + } + ) + + mixed_data = { + "NixlConnector": nixl_stats, # Already instantiated + "MockConnector": {"data": {"mock_field": [1, 2, 3]}}, # Serialized + } + + stats = MultiConnector.build_kv_connector_stats(data=mixed_data) + + assert stats is not None + assert isinstance(stats, MultiKVConnectorStats) + assert len(stats.data) == 2 + # Instantiated object preserved + assert stats.data["NixlConnector"] is nixl_stats + # Serialized object reconstructed + assert isinstance(stats.data["MockConnector"], MockConnectorStats) + assert stats.data["MockConnector"].data == {"mock_field": [1, 2, 3]} + + def test_build_kv_connector_stats_skips_connectors_without_custom_stats(self): + """Test that connectors without custom stats (return None) are skipped.""" + # SharedStorageConnector doesn't override build_kv_connector_stats, + # so it returns None and should be skipped + serialized_data = { + "NixlConnector": { + "data": { + "transfer_duration": [1.5], + "post_duration": [0.1], + "bytes_transferred": [1024], + "num_descriptors": [10], + "num_failed_transfers": [], + "num_failed_notifications": [], + } + }, + "SharedStorageConnector": {"data": {"some_field": [1, 2, 3]}}, + } + + stats = MultiConnector.build_kv_connector_stats(data=serialized_data) + + assert stats is not None + assert isinstance(stats, MultiKVConnectorStats) + # Only NixlConnector should be reconstructed + assert len(stats.data) == 1 + assert "NixlConnector" in stats.data + assert isinstance(stats.data["NixlConnector"], NixlKVConnectorStats) + # SharedStorageConnector should be skipped (returns None) + assert "SharedStorageConnector" not in stats.data + + def test_build_kv_connector_stats_handles_malformed_data(self): + """Test that malformed data raises appropriate errors.""" + serialized_data = { + "NixlConnector": {"wrong_field": {"transfer_duration": [1.5]}} + } + + with pytest.raises(AssertionError, match="Expected a dict with a 'data' field"): + MultiConnector.build_kv_connector_stats(data=serialized_data) + + def test_aggregate_same_connector(self): + """Test aggregating stats from the same connector type.""" + stats1 = MultiKVConnectorStats( + data={ + "NixlConnector": NixlKVConnectorStats( + data={ + "transfer_duration": [1.0], + "post_duration": [0.1], + "bytes_transferred": [1024], + "num_descriptors": [10], + "num_failed_transfers": [], + "num_failed_notifications": [], + } + ) + } + ) + + stats2 = MultiKVConnectorStats( + data={ + "NixlConnector": NixlKVConnectorStats( + data={ + "transfer_duration": [2.0], + "post_duration": [0.2], + "bytes_transferred": [2048], + "num_descriptors": [20], + "num_failed_transfers": [], + "num_failed_notifications": [], + } + ) + } + ) + + result = stats1.aggregate(stats2) + + assert result is stats1 # Should return self + assert "NixlConnector" in result.data + nixl_stats = result.data["NixlConnector"] + assert nixl_stats.data["transfer_duration"] == [1.0, 2.0] + assert nixl_stats.data["post_duration"] == [0.1, 0.2] + assert nixl_stats.data["bytes_transferred"] == [1024, 2048] + assert nixl_stats.data["num_descriptors"] == [10, 20] + + def test_aggregate_new_connector(self): + """Test aggregating stats when a new connector type appears.""" + from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( + KVConnectorStats, + ) + + stats1 = MultiKVConnectorStats( + data={ + "NixlConnector": NixlKVConnectorStats( + data={ + "transfer_duration": [1.0], + "post_duration": [0.1], + "bytes_transferred": [1024], + "num_descriptors": [10], + "num_failed_transfers": [], + "num_failed_notifications": [], + } + ) + } + ) + + stats2 = MultiKVConnectorStats( + data={"SharedStorageConnector": KVConnectorStats(data={"field": [1, 2]})} + ) + + result = stats1.aggregate(stats2) + + assert "NixlConnector" in result.data + assert "SharedStorageConnector" in result.data + + def test_reduce(self): + """Test that reduce() correctly reduces all nested connector stats.""" + stats = MultiKVConnectorStats( + data={ + "NixlConnector": NixlKVConnectorStats( + data={ + "transfer_duration": [1.0, 2.0], + "post_duration": [0.1, 0.2], + "bytes_transferred": [1024, 2048], + "num_descriptors": [10, 20], + "num_failed_transfers": [], + "num_failed_notifications": [], + } + ) + } + ) + + reduced = stats.reduce() + + assert "NixlConnector" in reduced + assert isinstance(reduced["NixlConnector"], dict) + # Check that the stats were reduced (should have aggregated values) + assert "Num successful transfers" in reduced["NixlConnector"] + assert reduced["NixlConnector"]["Num successful transfers"] == 2 + + def test_reset(self): + """Test that reset() resets all nested connector stats.""" + stats = MultiKVConnectorStats( + data={ + "NixlConnector": NixlKVConnectorStats( + data={ + "transfer_duration": [1.0, 2.0], + "post_duration": [0.1, 0.2], + "bytes_transferred": [1024, 2048], + "num_descriptors": [10, 20], + "num_failed_transfers": [], + "num_failed_notifications": [], + } + ) + } + ) + + assert not stats.is_empty() + + stats.reset() + + # After reset, stats should be empty + assert stats.is_empty() + nixl_stats = stats.data["NixlConnector"] + assert len(nixl_stats.data["transfer_duration"]) == 0 + + def test_is_empty_with_multiple_connectors(self): + """Test is_empty() returns correct value with multiple connectors.""" + # All empty + stats = MultiKVConnectorStats( + data={ + "NixlConnector": NixlKVConnectorStats(data={}), + } + ) + # Initialize empty stats + stats.data["NixlConnector"].reset() + assert stats.is_empty() + + # One non-empty + stats.data["NixlConnector"].data["transfer_duration"].append(1.0) + assert not stats.is_empty() diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 5ef56f6c381f4..46a9ce77f8c4c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -66,6 +66,24 @@ class KVConnectorFactory: # We build separately to enforce strict separation return connector_cls(config, role) + @classmethod + def get_connector_class_by_name( + cls, connector_name: str + ) -> type[KVConnectorBaseType]: + """Get a registered connector class by name. + + Raises ValueError if the connector is not registered. + + Args: + connector_name: Name of the registered connector. + + Returns: + The connector class. + """ + if connector_name not in cls._registry: + raise ValueError(f"Connector '{connector_name}' is not registered.") + return cls._registry[connector_name]() + @classmethod def get_connector_class( cls, kv_transfer_config: "KVTransferConfig" diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 845ce320837d7..c1a2ac012415a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -324,11 +324,41 @@ class MultiConnector(KVConnectorBase_V1): def build_kv_connector_stats( cls, data: dict[str, Any] | None = None ) -> KVConnectorStats | None: - return ( - MultiKVConnectorStats(data=data) - if data is not None - else MultiKVConnectorStats() - ) + if data is None: + return MultiKVConnectorStats() + + # data is a dict mapping connector name to their stats data. + # The stats data can be either: + # 1. Already-instantiated KVConnectorStats objects (same process) + # 2. Serialized dicts (cross-process after serialization) + # We need to reconstruct proper KVConnectorStats objects from dicts + reconstructed_data = {} + for connector_name, stats_value in data.items(): + # If already a KVConnectorStats object, use it directly + if isinstance(stats_value, KVConnectorStats): + reconstructed_data[connector_name] = stats_value + continue + + # Otherwise, reconstruct from serialized dict + # Get the connector class to reconstruct its stats + connector_cls = KVConnectorFactory.get_connector_class_by_name( + connector_name + ) + + # stats_value is the serialized dataclass which contains {'data': {...}} + # We need to extract the inner 'data' field to avoid double-nesting + assert isinstance(stats_value, dict) and "data" in stats_value, ( + f"Expected a dict with a 'data' field, got {stats_value}" + ) + inner_data = stats_value["data"] + + # Use the connector's build_kv_connector_stats to reconstruct + if reconstructed_stats := connector_cls.build_kv_connector_stats( + data=inner_data + ): + reconstructed_data[connector_name] = reconstructed_stats + + return MultiKVConnectorStats(data=reconstructed_data) def get_kv_connector_stats(self) -> MultiKVConnectorStats | None: # Group connector stats by connector type.