mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-21 06:05:01 +08:00
[KVConnector] Add KV events to KV Connectors (#28309)
Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
This commit is contained in:
parent
a11f4a81e0
commit
f4417f8449
756
tests/v1/kv_connector/unit/test_lmcache_connector.py
Normal file
756
tests/v1/kv_connector/unit/test_lmcache_connector.py
Normal file
@ -0,0 +1,756 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.distributed.kv_events import BlockStored
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import (
|
||||||
|
LMCacheConnectorV1,
|
||||||
|
LMCacheKVEvents,
|
||||||
|
)
|
||||||
|
from vllm.v1.outputs import KVConnectorOutput
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_lmcache_engine_event():
    """Create a mock event object that mimics what the lmcache engine returns."""

    class MockEvent:
        # Generic attribute bag: accepts any keyword arguments and exposes
        # them as instance attributes, just like the engine's event objects.
        def __init__(self, **attrs):
            for name, value in attrs.items():
                setattr(self, name, value)

    return MockEvent(
        block_hashes=["hash1", "hash2"],
        parent_block_hash="parent_hash",
        token_ids=[1, 2, 3, 4],
        lora_id=None,
        block_size=16,
        medium="GPU",
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_connector():
    """Create a mock LMCacheConnectorV1 instance with mocked dependencies."""
    connector = MagicMock(spec=LMCacheConnectorV1)
    connector._kv_cache_events = None
    connector._lmcache_engine = MagicMock()

    # Bind the real (unmocked) implementations of the methods under test
    # onto the mock, so their actual logic runs against the mocked internals.
    for method_name in (
        "get_kv_connector_kv_cache_events",
        "update_connector_output",
        "take_events",
    ):
        unbound = getattr(LMCacheConnectorV1, method_name)
        setattr(
            connector, method_name, unbound.__get__(connector, LMCacheConnectorV1)
        )

    return connector
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetKVConnectorKVCacheEvents:
    """Test get_kv_connector_kv_cache_events method."""

    def test_returns_none_when_no_events(self, mock_connector):
        """Test that None is returned when lmcache engine has no events."""
        mock_connector._lmcache_engine.get_kv_events.return_value = None

        result = mock_connector.get_kv_connector_kv_cache_events()

        assert result is None
        mock_connector._lmcache_engine.get_kv_events.assert_called_once()

    def test_returns_none_when_empty_list(self, mock_connector):
        """Test that None is returned when lmcache engine returns empty list."""
        mock_connector._lmcache_engine.get_kv_events.return_value = []

        result = mock_connector.get_kv_connector_kv_cache_events()

        assert result is None

    def test_converts_single_event(self, mock_connector, mock_lmcache_engine_event):
        """Test conversion of a single event from lmcache engine format."""
        mock_connector._lmcache_engine.get_kv_events.return_value = [
            mock_lmcache_engine_event
        ]

        result = mock_connector.get_kv_connector_kv_cache_events()

        assert result is not None
        assert isinstance(result, LMCacheKVEvents)
        assert result.get_number_of_workers() == 1

        # The engine-format event should be converted to a vLLM BlockStored
        # event with every attribute carried over unchanged.
        events = result.get_all_events()
        assert len(events) == 1
        assert isinstance(events[0], BlockStored)
        assert events[0].block_hashes == ["hash1", "hash2"]
        assert events[0].parent_block_hash == "parent_hash"
        assert events[0].token_ids == [1, 2, 3, 4]
        assert events[0].lora_id is None
        assert events[0].block_size == 16
        assert events[0].medium == "GPU"

    def test_converts_multiple_events(self, mock_connector):
        """Test conversion of multiple events from lmcache engine format."""

        class MockEvent:
            def __init__(self, i):
                self.block_hashes = [f"hash{i}"]
                self.parent_block_hash = f"parent{i}"
                self.token_ids = [i]
                self.lora_id = None
                self.block_size = 16
                self.medium = "GPU"

        events = [MockEvent(i) for i in range(5)]
        mock_connector._lmcache_engine.get_kv_events.return_value = events

        result = mock_connector.get_kv_connector_kv_cache_events()

        assert result is not None
        assert isinstance(result, LMCacheKVEvents)

        converted_events = result.get_all_events()
        assert len(converted_events) == 5

        # Conversion should preserve per-event attributes and ordering.
        for i, event in enumerate(converted_events):
            assert isinstance(event, BlockStored)
            assert event.block_hashes == [f"hash{i}"]
            assert event.parent_block_hash == f"parent{i}"
            assert event.token_ids == [i]

    def test_preserves_event_attributes(self, mock_connector):
        """Test that all event attributes are correctly preserved."""

        class MockEventWithLora:
            def __init__(self):
                self.block_hashes = ["hash_a", "hash_b", "hash_c"]
                self.parent_block_hash = "parent_xyz"
                self.token_ids = [100, 200, 300]
                self.lora_id = 42
                self.block_size = 32
                self.medium = "DISK"

        mock_connector._lmcache_engine.get_kv_events.return_value = [
            MockEventWithLora()
        ]

        result = mock_connector.get_kv_connector_kv_cache_events()

        events = result.get_all_events()
        event = events[0]

        assert event.block_hashes == ["hash_a", "hash_b", "hash_c"]
        assert event.parent_block_hash == "parent_xyz"
        assert event.token_ids == [100, 200, 300]
        assert event.lora_id == 42
        assert event.block_size == 32
        assert event.medium == "DISK"

    def test_handles_none_parent_block_hash(self, mock_connector):
        """Test handling of events with None parent_block_hash."""

        class MockEventNoParent:
            def __init__(self):
                self.block_hashes = ["hash1"]
                self.parent_block_hash = None
                self.token_ids = [1, 2]
                self.lora_id = None
                self.block_size = 16
                self.medium = "GPU"

        mock_connector._lmcache_engine.get_kv_events.return_value = [
            MockEventNoParent()
        ]

        result = mock_connector.get_kv_connector_kv_cache_events()

        events = result.get_all_events()
        assert events[0].parent_block_hash is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpdateConnectorOutput:
    """Test update_connector_output method."""

    def test_does_nothing_when_kv_cache_events_is_none(self, mock_connector):
        """Test that method returns early when kv_cache_events is None."""
        connector_output = KVConnectorOutput(kv_cache_events=None)

        mock_connector.update_connector_output(connector_output)

        assert mock_connector._kv_cache_events is None

    def test_does_nothing_when_kv_cache_events_is_not_lmcache_kv_events(
        self, mock_connector
    ):
        """Test that method returns early when kv_cache_events is not
        LMCacheKVEvents."""
        # Create a mock object that is not LMCacheKVEvents
        fake_events = MagicMock()
        connector_output = KVConnectorOutput(kv_cache_events=fake_events)

        mock_connector.update_connector_output(connector_output)

        assert mock_connector._kv_cache_events is None

    def test_sets_kv_cache_events_when_none(self, mock_connector):
        """Test that _kv_cache_events is set when it was None."""
        kv_events = LMCacheKVEvents(num_workers=1)
        event = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1, 2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        kv_events.add_events([event])

        connector_output = KVConnectorOutput(kv_cache_events=kv_events)

        mock_connector.update_connector_output(connector_output)

        # First update: the incoming container is adopted as-is (identity).
        assert mock_connector._kv_cache_events is kv_events

    def test_adds_events_when_kv_cache_events_already_exists(self, mock_connector):
        """Test that events are added when _kv_cache_events already exists."""
        # Set up existing events
        existing_events = LMCacheKVEvents(num_workers=2)
        event1 = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        existing_events.add_events([event1])
        existing_events.add_events([event1])  # Simulate 2 workers reporting

        mock_connector._kv_cache_events = existing_events

        # Create new events to add
        new_events = LMCacheKVEvents(num_workers=1)
        event2 = BlockStored(
            block_hashes=["hash2"],
            parent_block_hash=None,
            token_ids=[2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        new_events.add_events([event2])

        connector_output = KVConnectorOutput(kv_cache_events=new_events)

        mock_connector.update_connector_output(connector_output)

        # Check that events were added
        all_events = mock_connector._kv_cache_events.get_all_events()
        assert len(all_events) == 3  # 2 from existing + 1 from new
        assert event1 in all_events
        assert event2 in all_events

    def test_increments_workers_when_kv_cache_events_already_exists(
        self, mock_connector
    ):
        """Test that worker count is incremented correctly."""
        # Set up existing events with 2 workers
        existing_events = LMCacheKVEvents(num_workers=2)
        mock_connector._kv_cache_events = existing_events

        # Create new events from 3 workers
        new_events = LMCacheKVEvents(num_workers=3)
        event = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        new_events.add_events([event])

        connector_output = KVConnectorOutput(kv_cache_events=new_events)

        mock_connector.update_connector_output(connector_output)

        # Worker count should be 2 + 3 = 5
        assert mock_connector._kv_cache_events.get_number_of_workers() == 5

    def test_multiple_updates(self, mock_connector):
        """Test multiple consecutive updates."""
        # First update
        events1 = LMCacheKVEvents(num_workers=1)
        event1 = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        events1.add_events([event1])
        output1 = KVConnectorOutput(kv_cache_events=events1)
        mock_connector.update_connector_output(output1)

        # Second update
        events2 = LMCacheKVEvents(num_workers=2)
        event2 = BlockStored(
            block_hashes=["hash2"],
            parent_block_hash=None,
            token_ids=[2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        events2.add_events([event2])
        output2 = KVConnectorOutput(kv_cache_events=events2)
        mock_connector.update_connector_output(output2)

        # Third update
        events3 = LMCacheKVEvents(num_workers=1)
        event3 = BlockStored(
            block_hashes=["hash3"],
            parent_block_hash=None,
            token_ids=[3],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        events3.add_events([event3])
        output3 = KVConnectorOutput(kv_cache_events=events3)
        mock_connector.update_connector_output(output3)

        # Check final state
        all_events = mock_connector._kv_cache_events.get_all_events()
        assert len(all_events) == 3
        assert mock_connector._kv_cache_events.get_number_of_workers() == 4  # 1+2+1

    def test_updates_with_empty_events(self, mock_connector):
        """Test updating with empty event lists."""
        # First update with actual events
        events1 = LMCacheKVEvents(num_workers=1)
        event1 = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        events1.add_events([event1])
        output1 = KVConnectorOutput(kv_cache_events=events1)
        mock_connector.update_connector_output(output1)

        # Second update with empty events
        events2 = LMCacheKVEvents(num_workers=2)
        # No events added
        output2 = KVConnectorOutput(kv_cache_events=events2)
        mock_connector.update_connector_output(output2)

        # Should still have the original event
        all_events = mock_connector._kv_cache_events.get_all_events()
        assert len(all_events) == 1
        assert mock_connector._kv_cache_events.get_number_of_workers() == 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestTakeEvents:
    """Test take_events method."""

    def test_yields_nothing_when_kv_cache_events_is_none(self, mock_connector):
        """Test that nothing is yielded when _kv_cache_events is None."""
        mock_connector._kv_cache_events = None

        events = list(mock_connector.take_events())

        assert events == []

    def test_yields_events_and_clears(self, mock_connector):
        """Test that events are yielded and then cleared."""
        # Set up events
        kv_events = LMCacheKVEvents(num_workers=1)
        event1 = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        event2 = BlockStored(
            block_hashes=["hash2"],
            parent_block_hash=None,
            token_ids=[2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        kv_events.add_events([event1, event2])
        mock_connector._kv_cache_events = kv_events

        # Take events
        events = list(mock_connector.take_events())

        # Check that events were yielded
        assert len(events) == 2
        assert event1 in events
        assert event2 in events

        # Check that _kv_cache_events was cleared
        assert mock_connector._kv_cache_events is None

    def test_aggregates_before_yielding(self, mock_connector):
        """Test that events are aggregated before yielding."""
        # Set up events from multiple workers
        kv_events = LMCacheKVEvents(num_workers=3)
        common_event = BlockStored(
            block_hashes=["hash_common"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        uncommon_event = BlockStored(
            block_hashes=["hash_uncommon"],
            parent_block_hash=None,
            token_ids=[2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        # All 3 workers report common_event
        kv_events.add_events([common_event])
        kv_events.add_events([common_event])
        kv_events.add_events([common_event])

        # Only 1 worker reports uncommon_event
        kv_events.add_events([uncommon_event])

        mock_connector._kv_cache_events = kv_events

        # Take events
        events = list(mock_connector.take_events())

        # Only the common event should be yielded
        assert len(events) == 1
        assert events[0] == common_event

    def test_multiple_take_events_calls(self, mock_connector):
        """Test calling take_events multiple times."""
        # First call with events
        kv_events1 = LMCacheKVEvents(num_workers=1)
        event1 = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        kv_events1.add_events([event1])
        mock_connector._kv_cache_events = kv_events1

        events1 = list(mock_connector.take_events())
        assert len(events1) == 1
        assert events1[0] == event1
        assert mock_connector._kv_cache_events is None

        # Second call with no events
        events2 = list(mock_connector.take_events())
        assert events2 == []

        # Third call after adding new events
        kv_events2 = LMCacheKVEvents(num_workers=1)
        event2 = BlockStored(
            block_hashes=["hash2"],
            parent_block_hash=None,
            token_ids=[2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        kv_events2.add_events([event2])
        mock_connector._kv_cache_events = kv_events2

        events3 = list(mock_connector.take_events())
        assert len(events3) == 1
        assert events3[0] == event2

    def test_yields_empty_after_aggregation_removes_all(self, mock_connector):
        """Test that nothing is yielded if aggregation removes all events."""
        # Set up events from 2 workers with no common events
        kv_events = LMCacheKVEvents(num_workers=2)
        event1 = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        event2 = BlockStored(
            block_hashes=["hash2"],
            parent_block_hash=None,
            token_ids=[2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        # Worker 1 reports event1
        kv_events.add_events([event1])
        # Worker 2 reports event2
        kv_events.add_events([event2])

        mock_connector._kv_cache_events = kv_events

        # Take events
        events = list(mock_connector.take_events())

        # No common events, so nothing should be yielded
        assert events == []
        assert mock_connector._kv_cache_events is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestIntegrationScenarios:
    """Test integration scenarios."""

    def test_full_workflow(self, mock_connector, mock_lmcache_engine_event):
        """Test a complete workflow from getting events to taking them."""
        # Step 1: Get events from lmcache engine
        mock_connector._lmcache_engine.get_kv_events.return_value = [
            mock_lmcache_engine_event
        ]
        kv_events = mock_connector.get_kv_connector_kv_cache_events()

        assert kv_events is not None
        assert len(kv_events.get_all_events()) == 1

        # Step 2: Update connector output (simulate receiving from worker)
        output1 = KVConnectorOutput(kv_cache_events=kv_events)
        mock_connector.update_connector_output(output1)

        assert mock_connector._kv_cache_events is not None

        # Step 3: Take events
        taken_events = list(mock_connector.take_events())

        assert len(taken_events) == 1
        assert mock_connector._kv_cache_events is None

    def test_multiple_workers_workflow(self, mock_connector):
        """Test workflow with multiple workers."""

        class MockEvent:
            def __init__(self, hash_val):
                self.block_hashes = [hash_val]
                self.parent_block_hash = None
                self.token_ids = [1]
                self.lora_id = None
                self.block_size = 16
                self.medium = "GPU"

        # Worker 1
        mock_connector._lmcache_engine.get_kv_events.return_value = [
            MockEvent("hash_common"),
            MockEvent("hash_worker1"),
        ]
        kv_events1 = mock_connector.get_kv_connector_kv_cache_events()
        output1 = KVConnectorOutput(kv_cache_events=kv_events1)
        mock_connector.update_connector_output(output1)

        # Worker 2
        mock_connector._lmcache_engine.get_kv_events.return_value = [
            MockEvent("hash_common"),
            MockEvent("hash_worker2"),
        ]
        kv_events2 = mock_connector.get_kv_connector_kv_cache_events()
        output2 = KVConnectorOutput(kv_cache_events=kv_events2)
        mock_connector.update_connector_output(output2)

        # Take events (should only get common events)
        taken_events = list(mock_connector.take_events())

        # With aggregation, only events reported by both workers should be present
        # In this case, hash_common was reported by both
        event_hashes = [e.block_hashes[0] for e in taken_events]
        assert "hash_common" in event_hashes

    def test_empty_workflow(self, mock_connector):
        """Test workflow when there are no events at any stage."""
        # Get events returns None
        mock_connector._lmcache_engine.get_kv_events.return_value = None
        kv_events = mock_connector.get_kv_connector_kv_cache_events()

        assert kv_events is None

        # Update with None
        output = KVConnectorOutput(kv_cache_events=None)
        mock_connector.update_connector_output(output)

        # Take events
        taken_events = list(mock_connector.take_events())

        assert taken_events == []
        assert mock_connector._kv_cache_events is None

    def test_repeated_cycles(self, mock_connector):
        """Test multiple cycles of the complete workflow."""

        class MockEvent:
            def __init__(self, cycle_num):
                self.block_hashes = [f"hash_cycle_{cycle_num}"]
                self.parent_block_hash = None
                self.token_ids = [cycle_num]
                self.lora_id = None
                self.block_size = 16
                self.medium = "GPU"

        for cycle in range(3):
            # Get events
            mock_connector._lmcache_engine.get_kv_events.return_value = [
                MockEvent(cycle)
            ]
            kv_events = mock_connector.get_kv_connector_kv_cache_events()

            # Update
            output = KVConnectorOutput(kv_cache_events=kv_events)
            mock_connector.update_connector_output(output)

            # Take
            taken_events = list(mock_connector.take_events())

            # Verify
            assert len(taken_events) == 1
            assert taken_events[0].block_hashes[0] == f"hash_cycle_{cycle}"
            assert mock_connector._kv_cache_events is None

    def test_lmcache_kv_events_aggregation(self):
        """
        Test LMCacheKVEvents aggregation across TP ranks using
        KVOutputAggregator (used by MultiprocExecutor).
        """
        from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
        from vllm.v1.outputs import ModelRunnerOutput

        # Create KVOutputAggregator for 3 workers (simulating TP=3)
        aggregator = KVOutputAggregator(expected_finished_count=3)

        # Define common and unique events
        common_event = BlockStored(
            block_hashes=["hash_common"],
            parent_block_hash="parent_common",
            token_ids=[1, 2, 3],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        worker1_unique_event = BlockStored(
            block_hashes=["hash_worker1"],
            parent_block_hash="parent_w1",
            token_ids=[4, 5],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        worker2_unique_event = BlockStored(
            block_hashes=["hash_worker2"],
            parent_block_hash="parent_w2",
            token_ids=[6, 7],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        worker3_unique_event = BlockStored(
            block_hashes=["hash_worker3"],
            parent_block_hash="parent_w3",
            token_ids=[8, 9],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        # Create events for each worker
        # Worker 0: reports common event and its unique event
        worker0_events = LMCacheKVEvents(num_workers=1)
        worker0_events.add_events([common_event, worker1_unique_event])

        # Worker 1: reports common event and its unique event
        worker1_events = LMCacheKVEvents(num_workers=1)
        worker1_events.add_events([common_event, worker2_unique_event])

        # Worker 2: reports common event and its unique event
        worker2_events = LMCacheKVEvents(num_workers=1)
        worker2_events.add_events([common_event, worker3_unique_event])

        # Create ModelRunnerOutput instances for each worker
        worker_outputs = []
        for i, worker_events in enumerate(
            [worker0_events, worker1_events, worker2_events]
        ):
            output = ModelRunnerOutput(
                req_ids=[f"req_{i}"],
                req_id_to_index={f"req_{i}": 0},
                sampled_token_ids=[[123]],  # dummy token
                logprobs=None,
                prompt_logprobs_dict={},
                pooler_output=[None],
                kv_connector_output=KVConnectorOutput(
                    finished_sending=set([f"req_{i}_send"])
                    if i < 2
                    else None,  # Workers 0,1 finished sending
                    finished_recving=set([f"req_{i}_recv"])
                    if i > 0
                    else None,  # Workers 1,2 finished receiving
                    kv_cache_events=worker_events,
                ),
            )
            worker_outputs.append(output)

        # Use the real aggregation mechanism (like MultiprocExecutor.execute_model)
        aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
        kv_cache_events = aggregated_output.kv_connector_output.kv_cache_events

        assert isinstance(kv_cache_events, LMCacheKVEvents)

        # After aggregation, events should be combined from all workers
        # The aggregator doesn't automatically aggregate events, so we need to call
        # aggregate() to get only common events
        kv_cache_events.aggregate()
        aggregated_events = kv_cache_events.get_all_events()

        # Only the common event should remain after aggregation
        # because it's the only event reported by all 3 workers
        assert len(aggregated_events) == 1
        assert aggregated_events[0] == common_event

        # Verify the common event properties
        assert aggregated_events[0].block_hashes == ["hash_common"]
        assert aggregated_events[0].parent_block_hash == "parent_common"
        assert aggregated_events[0].token_ids == [1, 2, 3]
|
||||||
@ -5,7 +5,7 @@ import queue
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import deque
|
from collections import Counter, deque
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from itertools import count
|
from itertools import count
|
||||||
@ -54,11 +54,26 @@ class BlockStored(KVCacheEvent):
|
|||||||
lora_id: int | None
|
lora_id: int | None
|
||||||
medium: str | None
|
medium: str | None
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
return hash(
|
||||||
|
(
|
||||||
|
tuple(self.block_hashes),
|
||||||
|
self.parent_block_hash,
|
||||||
|
tuple(self.token_ids),
|
||||||
|
self.block_size,
|
||||||
|
self.lora_id,
|
||||||
|
self.medium,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BlockRemoved(KVCacheEvent):
    """KV cache event recording that a set of blocks was removed."""

    # External hashes identifying the removed blocks.
    block_hashes: list[ExternalBlockHash]
    # Storage medium identifier for the removed blocks; may be None.
    medium: str | None

    def __hash__(self) -> int:
        # block_hashes is a list (unhashable), so hash an immutable
        # tuple snapshot together with the medium.
        return hash((tuple(self.block_hashes), self.medium))
|
||||||
|
|
||||||
|
|
||||||
class AllBlocksCleared(KVCacheEvent):
    # Marker event with no payload: signals that all blocks were cleared.
    pass
|
||||||
@ -68,6 +83,119 @@ class KVEventBatch(EventBatch):
|
|||||||
events: list[BlockStored | BlockRemoved | AllBlocksCleared]
|
events: list[BlockStored | BlockRemoved | AllBlocksCleared]
|
||||||
|
|
||||||
|
|
||||||
|
class KVEventAggregator:
    """
    Aggregates KV events across multiple workers.

    Tracks how many times each event appears and returns only those
    that were emitted by all workers.
    """

    __slots__ = ("_event_counter", "_num_workers")

    def __init__(self, num_workers: int) -> None:
        """
        :param num_workers: Number of workers contributing events.
        :raises ValueError: If num_workers is not positive.
        """
        if num_workers <= 0:
            raise ValueError("num_workers must be greater than zero.")
        # Maps each (hashable) event to the number of times it was reported.
        self._event_counter: Counter[KVCacheEvent] = Counter()
        self._num_workers: int = num_workers

    def add_events(self, events: list[KVCacheEvent]) -> None:
        """
        Add events from a worker batch.

        :param events: List of KVCacheEvent objects.
        :raises TypeError: If events is not a list.
        """
        if not isinstance(events, list):
            raise TypeError("events must be a list of KVCacheEvent.")
        self._event_counter.update(events)

    def get_common_events(self) -> list[KVCacheEvent]:
        """
        Return events that appeared in all workers.

        :return: List of events present in all workers.
        """
        # Use >= rather than == so a worker that reports the same event
        # more than once cannot push the count past num_workers and hide
        # an event that every worker did in fact report.
        return [
            event
            for event, count in self._event_counter.items()
            if count >= self._num_workers
        ]

    def get_all_events(self) -> list[KVCacheEvent]:
        """
        Return all events for all workers.

        :return: List of events for all workers (duplicates preserved).
        """
        return list(self._event_counter.elements())

    def clear_events(self) -> None:
        """
        Clear all tracked events.
        """
        self._event_counter.clear()

    def increment_workers(self, count: int = 1) -> None:
        """
        Increment the number of workers contributing events.

        :param count: Number to increment the workers by.
        :raises ValueError: If count is not positive.
        """
        if count <= 0:
            raise ValueError("count must be positive.")
        self._num_workers += count

    def reset_workers(self) -> None:
        """
        Reset the number of workers to 1.
        """
        self._num_workers = 1

    def get_number_of_workers(self) -> int:
        """
        Return the number of workers.

        :return: int number of workers.
        """
        return self._num_workers

    def __repr__(self) -> str:
        return (
            f"<KVEventAggregator workers={self._num_workers}, "
            f"events={len(self._event_counter)}>"
        )
|
||||||
|
|
||||||
|
|
||||||
|
class KVConnectorKVEvents(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for KV events.
|
||||||
|
Acts as a container for KV events from the connector.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def add_events(self, events: list[KVCacheEvent]) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def aggregate(self) -> "KVConnectorKVEvents":
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def increment_workers(self, count: int = 1) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_all_events(self) -> list[KVCacheEvent]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_number_of_workers(self) -> int:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def clear_events(self) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class EventPublisher(ABC):
|
class EventPublisher(ABC):
|
||||||
"""Lightweight publisher for EventBatch batches with data parallelism
|
"""Lightweight publisher for EventBatch batches with data parallelism
|
||||||
support.
|
support.
|
||||||
|
|||||||
@ -78,6 +78,7 @@ class KVOutputAggregator:
|
|||||||
finished_sending = set[str]()
|
finished_sending = set[str]()
|
||||||
finished_recving = set[str]()
|
finished_recving = set[str]()
|
||||||
aggregated_kv_connector_stats = None
|
aggregated_kv_connector_stats = None
|
||||||
|
combined_kv_cache_events = None
|
||||||
invalid_block_ids = set[int]()
|
invalid_block_ids = set[int]()
|
||||||
for model_runner_output in outputs:
|
for model_runner_output in outputs:
|
||||||
assert model_runner_output is not None
|
assert model_runner_output is not None
|
||||||
@ -119,6 +120,19 @@ class KVOutputAggregator:
|
|||||||
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
|
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Combine kv_cache_events from all workers.
|
||||||
|
if combined_kv_cache_events is None:
|
||||||
|
# Use the first worker's kv_cache events as start event list.
|
||||||
|
combined_kv_cache_events = kv_output.kv_cache_events
|
||||||
|
elif kv_cache_events := kv_output.kv_cache_events:
|
||||||
|
assert isinstance(
|
||||||
|
combined_kv_cache_events,
|
||||||
|
type(kv_cache_events),
|
||||||
|
)
|
||||||
|
worker_kv_cache_events = kv_cache_events.get_all_events()
|
||||||
|
combined_kv_cache_events.add_events(worker_kv_cache_events)
|
||||||
|
combined_kv_cache_events.increment_workers(1)
|
||||||
|
|
||||||
invalid_block_ids |= kv_output.invalid_block_ids
|
invalid_block_ids |= kv_output.invalid_block_ids
|
||||||
|
|
||||||
# select output of the worker specified by output_rank
|
# select output of the worker specified by output_rank
|
||||||
@ -129,6 +143,7 @@ class KVOutputAggregator:
|
|||||||
finished_sending=finished_sending or None,
|
finished_sending=finished_sending or None,
|
||||||
finished_recving=finished_recving or None,
|
finished_recving=finished_recving or None,
|
||||||
kv_connector_stats=aggregated_kv_connector_stats or None,
|
kv_connector_stats=aggregated_kv_connector_stats or None,
|
||||||
|
kv_cache_events=combined_kv_cache_events or None,
|
||||||
invalid_block_ids=invalid_block_ids,
|
invalid_block_ids=invalid_block_ids,
|
||||||
expected_finished_count=self._expected_finished_count,
|
expected_finished_count=self._expected_finished_count,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -49,7 +49,7 @@ from vllm.v1.outputs import KVConnectorOutput
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_events import KVCacheEvent
|
from vllm.distributed.kv_events import KVCacheEvent, KVConnectorKVEvents
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||||
KVConnectorPromMetrics,
|
KVConnectorPromMetrics,
|
||||||
KVConnectorStats,
|
KVConnectorStats,
|
||||||
@ -379,6 +379,14 @@ class KVConnectorBase_V1(ABC):
|
|||||||
"""
|
"""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_kv_connector_kv_cache_events(self) -> Optional["KVConnectorKVEvents"]:
|
||||||
|
"""
|
||||||
|
Get the KV connector kv cache events collected during the last interval.
|
||||||
|
This function should be called by the model runner every time after the
|
||||||
|
model execution and before cleanup.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
|
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
|
||||||
"""
|
"""
|
||||||
Get the KVConnector handshake metadata for this connector.
|
Get the KVConnector handshake metadata for this connector.
|
||||||
|
|||||||
@ -1,14 +1,18 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from collections.abc import Iterable
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lmcache.integration.vllm.vllm_v1_adapter import (
|
|
||||||
LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
|
|
||||||
)
|
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.distributed.kv_events import (
|
||||||
|
BlockStored,
|
||||||
|
KVCacheEvent,
|
||||||
|
KVConnectorKVEvents,
|
||||||
|
KVEventAggregator,
|
||||||
|
)
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
KVConnectorBase_V1,
|
KVConnectorBase_V1,
|
||||||
KVConnectorMetadata,
|
KVConnectorMetadata,
|
||||||
@ -16,6 +20,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
|||||||
)
|
)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
from vllm.v1.outputs import KVConnectorOutput
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
@ -26,6 +31,44 @@ if TYPE_CHECKING:
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LMCacheKVEvents(KVConnectorKVEvents):
|
||||||
|
"""
|
||||||
|
Concrete implementation of KVConnectorKVEvents using KVEventAggregator.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_workers: int) -> None:
|
||||||
|
self._aggregator = KVEventAggregator(num_workers)
|
||||||
|
|
||||||
|
def add_events(self, events: list[KVCacheEvent]) -> None:
|
||||||
|
self._aggregator.add_events(events)
|
||||||
|
|
||||||
|
def aggregate(self) -> "LMCacheKVEvents":
|
||||||
|
"""
|
||||||
|
Aggregate KV events and retain only common events.
|
||||||
|
"""
|
||||||
|
common_events = self._aggregator.get_common_events()
|
||||||
|
self._aggregator.clear_events()
|
||||||
|
self._aggregator.add_events(common_events)
|
||||||
|
self._aggregator.reset_workers()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def increment_workers(self, count: int = 1) -> None:
|
||||||
|
self._aggregator.increment_workers(count)
|
||||||
|
|
||||||
|
def get_all_events(self) -> list[KVCacheEvent]:
|
||||||
|
return self._aggregator.get_all_events()
|
||||||
|
|
||||||
|
def get_number_of_workers(self) -> int:
|
||||||
|
return self._aggregator.get_number_of_workers()
|
||||||
|
|
||||||
|
def clear_events(self) -> None:
|
||||||
|
self._aggregator.clear_events()
|
||||||
|
self._aggregator.reset_workers()
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<LMCacheKVEvents events={self.get_all_events()}>"
|
||||||
|
|
||||||
|
|
||||||
class LMCacheConnectorV1(KVConnectorBase_V1):
|
class LMCacheConnectorV1(KVConnectorBase_V1):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -50,10 +93,17 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
|
|||||||
cls = _adapter.LMCacheConnectorV1Impl
|
cls = _adapter.LMCacheConnectorV1Impl
|
||||||
else:
|
else:
|
||||||
logger.info("Initializing latest dev LMCache connector")
|
logger.info("Initializing latest dev LMCache connector")
|
||||||
|
# lazy import
|
||||||
|
from lmcache.integration.vllm.vllm_v1_adapter import (
|
||||||
|
LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
|
||||||
|
)
|
||||||
|
|
||||||
cls = LMCacheConnectorLatestImpl
|
cls = LMCacheConnectorLatestImpl
|
||||||
|
|
||||||
self._lmcache_engine = cls(vllm_config, role, self)
|
self._lmcache_engine = cls(vllm_config, role, self)
|
||||||
|
|
||||||
|
self._kv_cache_events: LMCacheKVEvents | None = None
|
||||||
|
|
||||||
# ==============================
|
# ==============================
|
||||||
# Worker-side methods
|
# Worker-side methods
|
||||||
# ==============================
|
# ==============================
|
||||||
@ -151,6 +201,31 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
|
|||||||
# Fallback for older versions that don't support this method
|
# Fallback for older versions that don't support this method
|
||||||
return set()
|
return set()
|
||||||
|
|
||||||
|
def get_kv_connector_kv_cache_events(self) -> LMCacheKVEvents | None:
|
||||||
|
"""
|
||||||
|
Get the KV connector kv cache events collected during the last interval.
|
||||||
|
"""
|
||||||
|
|
||||||
|
events = self._lmcache_engine.get_kv_events() # type: ignore [attr-defined]
|
||||||
|
if not events:
|
||||||
|
return None
|
||||||
|
|
||||||
|
blocks: list[BlockStored] = [
|
||||||
|
BlockStored(
|
||||||
|
block_hashes=e.block_hashes,
|
||||||
|
parent_block_hash=e.parent_block_hash,
|
||||||
|
token_ids=e.token_ids,
|
||||||
|
lora_id=e.lora_id,
|
||||||
|
block_size=e.block_size,
|
||||||
|
medium=e.medium,
|
||||||
|
)
|
||||||
|
for e in events
|
||||||
|
]
|
||||||
|
|
||||||
|
lmcache_kv_events = LMCacheKVEvents(num_workers=1)
|
||||||
|
lmcache_kv_events.add_events(blocks)
|
||||||
|
return lmcache_kv_events
|
||||||
|
|
||||||
# ==============================
|
# ==============================
|
||||||
# Scheduler-side methods
|
# Scheduler-side methods
|
||||||
# ==============================
|
# ==============================
|
||||||
@ -198,6 +273,28 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
|
|||||||
"""
|
"""
|
||||||
return self._lmcache_engine.build_connector_meta(scheduler_output)
|
return self._lmcache_engine.build_connector_meta(scheduler_output)
|
||||||
|
|
||||||
|
def update_connector_output(self, connector_output: KVConnectorOutput):
|
||||||
|
"""
|
||||||
|
Update KVConnector state from worker-side connectors output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connector_output (KVConnectorOutput): the worker-side
|
||||||
|
connectors output.
|
||||||
|
"""
|
||||||
|
# Get the KV events
|
||||||
|
kv_cache_events = connector_output.kv_cache_events
|
||||||
|
if not kv_cache_events or not isinstance(kv_cache_events, LMCacheKVEvents):
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._kv_cache_events is None:
|
||||||
|
self._kv_cache_events = kv_cache_events
|
||||||
|
else:
|
||||||
|
self._kv_cache_events.add_events(kv_cache_events.get_all_events())
|
||||||
|
self._kv_cache_events.increment_workers(
|
||||||
|
kv_cache_events.get_number_of_workers()
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
def request_finished(
|
def request_finished(
|
||||||
self,
|
self,
|
||||||
request: "Request",
|
request: "Request",
|
||||||
@ -214,3 +311,17 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
|
|||||||
returned by the engine.
|
returned by the engine.
|
||||||
"""
|
"""
|
||||||
return self._lmcache_engine.request_finished(request, block_ids)
|
return self._lmcache_engine.request_finished(request, block_ids)
|
||||||
|
|
||||||
|
def take_events(self) -> Iterable["KVCacheEvent"]:
|
||||||
|
"""
|
||||||
|
Take the KV cache events from the connector.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
New KV cache events since the last call.
|
||||||
|
"""
|
||||||
|
if self._kv_cache_events is not None:
|
||||||
|
self._kv_cache_events.aggregate()
|
||||||
|
kv_cache_events = self._kv_cache_events.get_all_events()
|
||||||
|
yield from kv_cache_events
|
||||||
|
self._kv_cache_events.clear_events()
|
||||||
|
self._kv_cache_events = None
|
||||||
|
|||||||
@ -259,6 +259,12 @@ class MultiConnector(KVConnectorBase_V1):
|
|||||||
agg_block_ids |= c.get_block_ids_with_load_errors()
|
agg_block_ids |= c.get_block_ids_with_load_errors()
|
||||||
return agg_block_ids
|
return agg_block_ids
|
||||||
|
|
||||||
|
# TODO: Add a generic implementation of 'get_kv_connector_kv_cache_events' method
|
||||||
|
# for the MultiConnector. It should be able to get events from multiple
|
||||||
|
# connectors, handling the case where only a subset of the requested connectors
|
||||||
|
# implements the 'get_kv_connector_kv_cache_events'
|
||||||
|
# Follow on PR from https://github.com/vllm-project/vllm/pull/28309#pullrequestreview-3566351082
|
||||||
|
|
||||||
# ==============================
|
# ==============================
|
||||||
# Scheduler-side methods
|
# Scheduler-side methods
|
||||||
# ==============================
|
# ==============================
|
||||||
|
|||||||
@ -12,9 +12,11 @@ from vllm.compilation.cuda_graph import CUDAGraphStat
|
|||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from vllm.distributed.kv_events import KVConnectorKVEvents
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
|
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
|
||||||
else:
|
else:
|
||||||
KVConnectorStats = object
|
KVConnectorStats = object
|
||||||
|
KVConnectorKVEvents = object
|
||||||
|
|
||||||
|
|
||||||
class LogprobsLists(NamedTuple):
|
class LogprobsLists(NamedTuple):
|
||||||
@ -108,6 +110,7 @@ class KVConnectorOutput:
|
|||||||
finished_sending: set[str] | None = None
|
finished_sending: set[str] | None = None
|
||||||
finished_recving: set[str] | None = None
|
finished_recving: set[str] | None = None
|
||||||
kv_connector_stats: KVConnectorStats | None = None
|
kv_connector_stats: KVConnectorStats | None = None
|
||||||
|
kv_cache_events: KVConnectorKVEvents | None = None
|
||||||
# IDs of externally computed KV blocks that failed to load.
|
# IDs of externally computed KV blocks that failed to load.
|
||||||
# Requests referencing these blocks should be rescheduled to recompute them
|
# Requests referencing these blocks should be rescheduled to recompute them
|
||||||
invalid_block_ids: set[int] = field(default_factory=set)
|
invalid_block_ids: set[int] = field(default_factory=set)
|
||||||
@ -123,6 +126,7 @@ class KVConnectorOutput:
|
|||||||
not self.finished_sending
|
not self.finished_sending
|
||||||
and not self.finished_recving
|
and not self.finished_recving
|
||||||
and not self.kv_connector_stats
|
and not self.kv_connector_stats
|
||||||
|
and not self.kv_cache_events
|
||||||
and not self.invalid_block_ids
|
and not self.invalid_block_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -22,7 +22,6 @@ from vllm.distributed.kv_transfer import (
|
|||||||
has_kv_transfer_group,
|
has_kv_transfer_group,
|
||||||
)
|
)
|
||||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
|
|
||||||
from vllm.forward_context import get_forward_context, set_forward_context
|
from vllm.forward_context import get_forward_context, set_forward_context
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig
|
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig
|
||||||
@ -138,16 +137,10 @@ class KVConnectorModelRunnerMixin:
|
|||||||
)
|
)
|
||||||
output.invalid_block_ids = kv_connector.get_block_ids_with_load_errors()
|
output.invalid_block_ids = kv_connector.get_block_ids_with_load_errors()
|
||||||
|
|
||||||
output.kv_connector_stats = (
|
output.kv_connector_stats = kv_connector.get_kv_connector_stats()
|
||||||
KVConnectorModelRunnerMixin.get_kv_connector_stats()
|
output.kv_cache_events = kv_connector.get_kv_connector_kv_cache_events()
|
||||||
)
|
|
||||||
kv_connector.clear_connector_metadata()
|
|
||||||
|
|
||||||
@staticmethod
|
kv_connector.clear_connector_metadata()
|
||||||
def get_kv_connector_stats() -> KVConnectorStats | None:
|
|
||||||
if has_kv_transfer_group():
|
|
||||||
return get_kv_transfer_group().get_kv_connector_stats()
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def use_uniform_kv_cache(
|
def use_uniform_kv_cache(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user