[KVConnector] Add KV events to KV Connectors (#28309)
Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
parent a11f4a81e0
commit f4417f8449
756 tests/v1/kv_connector/unit/test_lmcache_connector.py Normal file
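For orientation, the flow this change wires up is: the worker-side connector collects events from the LMCache engine in get_kv_connector_kv_cache_events(), they travel back to the scheduler in KVConnectorOutput.kv_cache_events, KVOutputAggregator merges the per-worker containers, and the scheduler-side connector publishes them through take_events(). Below is a minimal sketch of the container handling only; it is not part of the diff, uses only names introduced below, and assumes a vLLM build that already includes this commit.

from vllm.distributed.kv_events import BlockStored
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import (
    LMCacheKVEvents,
)
from vllm.v1.outputs import KVConnectorOutput

# One worker's events, as produced by get_kv_connector_kv_cache_events().
events = LMCacheKVEvents(num_workers=1)
events.add_events(
    [
        BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1, 2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
    ]
)

# The container rides back to the scheduler inside KVConnectorOutput.
output = KVConnectorOutput(kv_cache_events=events)

# The scheduler-side connector keeps only events reported by every worker
# (here a single worker), then publishes them.
output.kv_cache_events.aggregate()
assert output.kv_cache_events.get_all_events()  # the one event survives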
@@ -0,0 +1,756 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock

import pytest

from vllm.distributed.kv_events import BlockStored
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import (
    LMCacheConnectorV1,
    LMCacheKVEvents,
)
from vllm.v1.outputs import KVConnectorOutput


@pytest.fixture
def mock_lmcache_engine_event():
    """Create a mock event object that mimics what the lmcache engine returns."""

    class MockEvent:
        def __init__(
            self,
            block_hashes,
            parent_block_hash,
            token_ids,
            lora_id,
            block_size,
            medium,
        ):
            self.block_hashes = block_hashes
            self.parent_block_hash = parent_block_hash
            self.token_ids = token_ids
            self.lora_id = lora_id
            self.block_size = block_size
            self.medium = medium

    return MockEvent(
        block_hashes=["hash1", "hash2"],
        parent_block_hash="parent_hash",
        token_ids=[1, 2, 3, 4],
        lora_id=None,
        block_size=16,
        medium="GPU",
    )


@pytest.fixture
def mock_connector():
    """Create a mock LMCacheConnectorV1 instance with mocked dependencies."""
    connector = MagicMock(spec=LMCacheConnectorV1)
    connector._kv_cache_events = None
    connector._lmcache_engine = MagicMock()

    # Make the methods use the real implementation
    connector.get_kv_connector_kv_cache_events = (
        LMCacheConnectorV1.get_kv_connector_kv_cache_events.__get__(
            connector, LMCacheConnectorV1
        )
    )
    connector.update_connector_output = (
        LMCacheConnectorV1.update_connector_output.__get__(
            connector, LMCacheConnectorV1
        )
    )
    connector.take_events = LMCacheConnectorV1.take_events.__get__(
        connector, LMCacheConnectorV1
    )

    return connector


class TestGetKVConnectorKVCacheEvents:
    """Test get_kv_connector_kv_cache_events method."""

    def test_returns_none_when_no_events(self, mock_connector):
        """Test that None is returned when lmcache engine has no events."""
        mock_connector._lmcache_engine.get_kv_events.return_value = None

        result = mock_connector.get_kv_connector_kv_cache_events()

        assert result is None
        mock_connector._lmcache_engine.get_kv_events.assert_called_once()

    def test_returns_none_when_empty_list(self, mock_connector):
        """Test that None is returned when lmcache engine returns empty list."""
        mock_connector._lmcache_engine.get_kv_events.return_value = []

        result = mock_connector.get_kv_connector_kv_cache_events()

        assert result is None

    def test_converts_single_event(self, mock_connector, mock_lmcache_engine_event):
        """Test conversion of a single event from lmcache engine format."""
        mock_connector._lmcache_engine.get_kv_events.return_value = [
            mock_lmcache_engine_event
        ]

        result = mock_connector.get_kv_connector_kv_cache_events()

        assert result is not None
        assert isinstance(result, LMCacheKVEvents)
        assert result.get_number_of_workers() == 1

        events = result.get_all_events()
        assert len(events) == 1
        assert isinstance(events[0], BlockStored)
        assert events[0].block_hashes == ["hash1", "hash2"]
        assert events[0].parent_block_hash == "parent_hash"
        assert events[0].token_ids == [1, 2, 3, 4]
        assert events[0].lora_id is None
        assert events[0].block_size == 16
        assert events[0].medium == "GPU"

    def test_converts_multiple_events(self, mock_connector):
        """Test conversion of multiple events from lmcache engine format."""

        class MockEvent:
            def __init__(self, i):
                self.block_hashes = [f"hash{i}"]
                self.parent_block_hash = f"parent{i}"
                self.token_ids = [i]
                self.lora_id = None
                self.block_size = 16
                self.medium = "GPU"

        events = [MockEvent(i) for i in range(5)]
        mock_connector._lmcache_engine.get_kv_events.return_value = events

        result = mock_connector.get_kv_connector_kv_cache_events()

        assert result is not None
        assert isinstance(result, LMCacheKVEvents)

        converted_events = result.get_all_events()
        assert len(converted_events) == 5

        for i, event in enumerate(converted_events):
            assert isinstance(event, BlockStored)
            assert event.block_hashes == [f"hash{i}"]
            assert event.parent_block_hash == f"parent{i}"
            assert event.token_ids == [i]

    def test_preserves_event_attributes(self, mock_connector):
        """Test that all event attributes are correctly preserved."""

        class MockEventWithLora:
            def __init__(self):
                self.block_hashes = ["hash_a", "hash_b", "hash_c"]
                self.parent_block_hash = "parent_xyz"
                self.token_ids = [100, 200, 300]
                self.lora_id = 42
                self.block_size = 32
                self.medium = "DISK"

        mock_connector._lmcache_engine.get_kv_events.return_value = [
            MockEventWithLora()
        ]

        result = mock_connector.get_kv_connector_kv_cache_events()

        events = result.get_all_events()
        event = events[0]

        assert event.block_hashes == ["hash_a", "hash_b", "hash_c"]
        assert event.parent_block_hash == "parent_xyz"
        assert event.token_ids == [100, 200, 300]
        assert event.lora_id == 42
        assert event.block_size == 32
        assert event.medium == "DISK"

    def test_handles_none_parent_block_hash(self, mock_connector):
        """Test handling of events with None parent_block_hash."""

        class MockEventNoParent:
            def __init__(self):
                self.block_hashes = ["hash1"]
                self.parent_block_hash = None
                self.token_ids = [1, 2]
                self.lora_id = None
                self.block_size = 16
                self.medium = "GPU"

        mock_connector._lmcache_engine.get_kv_events.return_value = [
            MockEventNoParent()
        ]

        result = mock_connector.get_kv_connector_kv_cache_events()

        events = result.get_all_events()
        assert events[0].parent_block_hash is None


class TestUpdateConnectorOutput:
    """Test update_connector_output method."""

    def test_does_nothing_when_kv_cache_events_is_none(self, mock_connector):
        """Test that method returns early when kv_cache_events is None."""
        connector_output = KVConnectorOutput(kv_cache_events=None)

        mock_connector.update_connector_output(connector_output)

        assert mock_connector._kv_cache_events is None

    def test_does_nothing_when_kv_cache_events_is_not_lmcache_kv_events(
        self, mock_connector
    ):
        """Test that method returns early when kv_cache_events is not
        LMCacheKVEvents."""
        # Create a mock object that is not LMCacheKVEvents
        fake_events = MagicMock()
        connector_output = KVConnectorOutput(kv_cache_events=fake_events)

        mock_connector.update_connector_output(connector_output)

        assert mock_connector._kv_cache_events is None

    def test_sets_kv_cache_events_when_none(self, mock_connector):
        """Test that _kv_cache_events is set when it was None."""
        kv_events = LMCacheKVEvents(num_workers=1)
        event = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1, 2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        kv_events.add_events([event])

        connector_output = KVConnectorOutput(kv_cache_events=kv_events)

        mock_connector.update_connector_output(connector_output)

        assert mock_connector._kv_cache_events is kv_events

    def test_adds_events_when_kv_cache_events_already_exists(self, mock_connector):
        """Test that events are added when _kv_cache_events already exists."""
        # Set up existing events
        existing_events = LMCacheKVEvents(num_workers=2)
        event1 = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        existing_events.add_events([event1])
        existing_events.add_events([event1])  # Simulate 2 workers reporting

        mock_connector._kv_cache_events = existing_events

        # Create new events to add
        new_events = LMCacheKVEvents(num_workers=1)
        event2 = BlockStored(
            block_hashes=["hash2"],
            parent_block_hash=None,
            token_ids=[2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        new_events.add_events([event2])

        connector_output = KVConnectorOutput(kv_cache_events=new_events)

        mock_connector.update_connector_output(connector_output)

        # Check that events were added
        all_events = mock_connector._kv_cache_events.get_all_events()
        assert len(all_events) == 3  # 2 from existing + 1 from new
        assert event1 in all_events
        assert event2 in all_events

    def test_increments_workers_when_kv_cache_events_already_exists(
        self, mock_connector
    ):
        """Test that worker count is incremented correctly."""
        # Set up existing events with 2 workers
        existing_events = LMCacheKVEvents(num_workers=2)
        mock_connector._kv_cache_events = existing_events

        # Create new events from 3 workers
        new_events = LMCacheKVEvents(num_workers=3)
        event = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        new_events.add_events([event])

        connector_output = KVConnectorOutput(kv_cache_events=new_events)

        mock_connector.update_connector_output(connector_output)

        # Worker count should be 2 + 3 = 5
        assert mock_connector._kv_cache_events.get_number_of_workers() == 5

    def test_multiple_updates(self, mock_connector):
        """Test multiple consecutive updates."""
        # First update
        events1 = LMCacheKVEvents(num_workers=1)
        event1 = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        events1.add_events([event1])
        output1 = KVConnectorOutput(kv_cache_events=events1)
        mock_connector.update_connector_output(output1)

        # Second update
        events2 = LMCacheKVEvents(num_workers=2)
        event2 = BlockStored(
            block_hashes=["hash2"],
            parent_block_hash=None,
            token_ids=[2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        events2.add_events([event2])
        output2 = KVConnectorOutput(kv_cache_events=events2)
        mock_connector.update_connector_output(output2)

        # Third update
        events3 = LMCacheKVEvents(num_workers=1)
        event3 = BlockStored(
            block_hashes=["hash3"],
            parent_block_hash=None,
            token_ids=[3],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        events3.add_events([event3])
        output3 = KVConnectorOutput(kv_cache_events=events3)
        mock_connector.update_connector_output(output3)

        # Check final state
        all_events = mock_connector._kv_cache_events.get_all_events()
        assert len(all_events) == 3
        assert mock_connector._kv_cache_events.get_number_of_workers() == 4  # 1+2+1

    def test_updates_with_empty_events(self, mock_connector):
        """Test updating with empty event lists."""
        # First update with actual events
        events1 = LMCacheKVEvents(num_workers=1)
        event1 = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        events1.add_events([event1])
        output1 = KVConnectorOutput(kv_cache_events=events1)
        mock_connector.update_connector_output(output1)

        # Second update with empty events
        events2 = LMCacheKVEvents(num_workers=2)
        # No events added
        output2 = KVConnectorOutput(kv_cache_events=events2)
        mock_connector.update_connector_output(output2)

        # Should still have the original event
        all_events = mock_connector._kv_cache_events.get_all_events()
        assert len(all_events) == 1
        assert mock_connector._kv_cache_events.get_number_of_workers() == 3


class TestTakeEvents:
    """Test take_events method."""

    def test_yields_nothing_when_kv_cache_events_is_none(self, mock_connector):
        """Test that nothing is yielded when _kv_cache_events is None."""
        mock_connector._kv_cache_events = None

        events = list(mock_connector.take_events())

        assert events == []

    def test_yields_events_and_clears(self, mock_connector):
        """Test that events are yielded and then cleared."""
        # Set up events
        kv_events = LMCacheKVEvents(num_workers=1)
        event1 = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        event2 = BlockStored(
            block_hashes=["hash2"],
            parent_block_hash=None,
            token_ids=[2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        kv_events.add_events([event1, event2])
        mock_connector._kv_cache_events = kv_events

        # Take events
        events = list(mock_connector.take_events())

        # Check that events were yielded
        assert len(events) == 2
        assert event1 in events
        assert event2 in events

        # Check that _kv_cache_events was cleared
        assert mock_connector._kv_cache_events is None

    def test_aggregates_before_yielding(self, mock_connector):
        """Test that events are aggregated before yielding."""
        # Set up events from multiple workers
        kv_events = LMCacheKVEvents(num_workers=3)
        common_event = BlockStored(
            block_hashes=["hash_common"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        uncommon_event = BlockStored(
            block_hashes=["hash_uncommon"],
            parent_block_hash=None,
            token_ids=[2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        # All 3 workers report common_event
        kv_events.add_events([common_event])
        kv_events.add_events([common_event])
        kv_events.add_events([common_event])

        # Only 1 worker reports uncommon_event
        kv_events.add_events([uncommon_event])

        mock_connector._kv_cache_events = kv_events

        # Take events
        events = list(mock_connector.take_events())

        # Only the common event should be yielded
        assert len(events) == 1
        assert events[0] == common_event

    def test_multiple_take_events_calls(self, mock_connector):
        """Test calling take_events multiple times."""
        # First call with events
        kv_events1 = LMCacheKVEvents(num_workers=1)
        event1 = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        kv_events1.add_events([event1])
        mock_connector._kv_cache_events = kv_events1

        events1 = list(mock_connector.take_events())
        assert len(events1) == 1
        assert events1[0] == event1
        assert mock_connector._kv_cache_events is None

        # Second call with no events
        events2 = list(mock_connector.take_events())
        assert events2 == []

        # Third call after adding new events
        kv_events2 = LMCacheKVEvents(num_workers=1)
        event2 = BlockStored(
            block_hashes=["hash2"],
            parent_block_hash=None,
            token_ids=[2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        kv_events2.add_events([event2])
        mock_connector._kv_cache_events = kv_events2

        events3 = list(mock_connector.take_events())
        assert len(events3) == 1
        assert events3[0] == event2

    def test_yields_empty_after_aggregation_removes_all(self, mock_connector):
        """Test that nothing is yielded if aggregation removes all events."""
        # Set up events from 2 workers with no common events
        kv_events = LMCacheKVEvents(num_workers=2)
        event1 = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        event2 = BlockStored(
            block_hashes=["hash2"],
            parent_block_hash=None,
            token_ids=[2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        # Worker 1 reports event1
        kv_events.add_events([event1])
        # Worker 2 reports event2
        kv_events.add_events([event2])

        mock_connector._kv_cache_events = kv_events

        # Take events
        events = list(mock_connector.take_events())

        # No common events, so nothing should be yielded
        assert events == []
        assert mock_connector._kv_cache_events is None


class TestIntegrationScenarios:
    """Test integration scenarios."""

    def test_full_workflow(self, mock_connector, mock_lmcache_engine_event):
        """Test a complete workflow from getting events to taking them."""
        # Step 1: Get events from lmcache engine
        mock_connector._lmcache_engine.get_kv_events.return_value = [
            mock_lmcache_engine_event
        ]
        kv_events = mock_connector.get_kv_connector_kv_cache_events()

        assert kv_events is not None
        assert len(kv_events.get_all_events()) == 1

        # Step 2: Update connector output (simulate receiving from worker)
        output1 = KVConnectorOutput(kv_cache_events=kv_events)
        mock_connector.update_connector_output(output1)

        assert mock_connector._kv_cache_events is not None

        # Step 3: Take events
        taken_events = list(mock_connector.take_events())

        assert len(taken_events) == 1
        assert mock_connector._kv_cache_events is None

    def test_multiple_workers_workflow(self, mock_connector):
        """Test workflow with multiple workers."""

        class MockEvent:
            def __init__(self, hash_val):
                self.block_hashes = [hash_val]
                self.parent_block_hash = None
                self.token_ids = [1]
                self.lora_id = None
                self.block_size = 16
                self.medium = "GPU"

        # Worker 1
        mock_connector._lmcache_engine.get_kv_events.return_value = [
            MockEvent("hash_common"),
            MockEvent("hash_worker1"),
        ]
        kv_events1 = mock_connector.get_kv_connector_kv_cache_events()
        output1 = KVConnectorOutput(kv_cache_events=kv_events1)
        mock_connector.update_connector_output(output1)

        # Worker 2
        mock_connector._lmcache_engine.get_kv_events.return_value = [
            MockEvent("hash_common"),
            MockEvent("hash_worker2"),
        ]
        kv_events2 = mock_connector.get_kv_connector_kv_cache_events()
        output2 = KVConnectorOutput(kv_cache_events=kv_events2)
        mock_connector.update_connector_output(output2)

        # Take events (should only get common events)
        taken_events = list(mock_connector.take_events())

        # With aggregation, only events reported by both workers should be present
        # In this case, hash_common was reported by both
        event_hashes = [e.block_hashes[0] for e in taken_events]
        assert "hash_common" in event_hashes

    def test_empty_workflow(self, mock_connector):
        """Test workflow when there are no events at any stage."""
        # Get events returns None
        mock_connector._lmcache_engine.get_kv_events.return_value = None
        kv_events = mock_connector.get_kv_connector_kv_cache_events()

        assert kv_events is None

        # Update with None
        output = KVConnectorOutput(kv_cache_events=None)
        mock_connector.update_connector_output(output)

        # Take events
        taken_events = list(mock_connector.take_events())

        assert taken_events == []
        assert mock_connector._kv_cache_events is None

    def test_repeated_cycles(self, mock_connector):
        """Test multiple cycles of the complete workflow."""

        class MockEvent:
            def __init__(self, cycle_num):
                self.block_hashes = [f"hash_cycle_{cycle_num}"]
                self.parent_block_hash = None
                self.token_ids = [cycle_num]
                self.lora_id = None
                self.block_size = 16
                self.medium = "GPU"

        for cycle in range(3):
            # Get events
            mock_connector._lmcache_engine.get_kv_events.return_value = [
                MockEvent(cycle)
            ]
            kv_events = mock_connector.get_kv_connector_kv_cache_events()

            # Update
            output = KVConnectorOutput(kv_cache_events=kv_events)
            mock_connector.update_connector_output(output)

            # Take
            taken_events = list(mock_connector.take_events())

            # Verify
            assert len(taken_events) == 1
            assert taken_events[0].block_hashes[0] == f"hash_cycle_{cycle}"
            assert mock_connector._kv_cache_events is None

    def test_lmcache_kv_events_aggregation(self):
        """
        Test LMCacheKVEvents aggregation across TP ranks using
        KVOutputAggregator (used by MultiprocExecutor).
        """
        from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
        from vllm.v1.outputs import ModelRunnerOutput

        # Create KVOutputAggregator for 3 workers (simulating TP=3)
        aggregator = KVOutputAggregator(expected_finished_count=3)

        # Define common and unique events
        common_event = BlockStored(
            block_hashes=["hash_common"],
            parent_block_hash="parent_common",
            token_ids=[1, 2, 3],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        worker1_unique_event = BlockStored(
            block_hashes=["hash_worker1"],
            parent_block_hash="parent_w1",
            token_ids=[4, 5],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        worker2_unique_event = BlockStored(
            block_hashes=["hash_worker2"],
            parent_block_hash="parent_w2",
            token_ids=[6, 7],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        worker3_unique_event = BlockStored(
            block_hashes=["hash_worker3"],
            parent_block_hash="parent_w3",
            token_ids=[8, 9],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        # Create events for each worker
        # Worker 0: reports common event and its unique event
        worker0_events = LMCacheKVEvents(num_workers=1)
        worker0_events.add_events([common_event, worker1_unique_event])

        # Worker 1: reports common event and its unique event
        worker1_events = LMCacheKVEvents(num_workers=1)
        worker1_events.add_events([common_event, worker2_unique_event])

        # Worker 2: reports common event and its unique event
        worker2_events = LMCacheKVEvents(num_workers=1)
        worker2_events.add_events([common_event, worker3_unique_event])

        # Create ModelRunnerOutput instances for each worker
        worker_outputs = []
        for i, worker_events in enumerate(
            [worker0_events, worker1_events, worker2_events]
        ):
            output = ModelRunnerOutput(
                req_ids=[f"req_{i}"],
                req_id_to_index={f"req_{i}": 0},
                sampled_token_ids=[[123]],  # dummy token
                logprobs=None,
                prompt_logprobs_dict={},
                pooler_output=[None],
                kv_connector_output=KVConnectorOutput(
                    finished_sending=set([f"req_{i}_send"])
                    if i < 2
                    else None,  # Workers 0,1 finished sending
                    finished_recving=set([f"req_{i}_recv"])
                    if i > 0
                    else None,  # Workers 1,2 finished receiving
                    kv_cache_events=worker_events,
                ),
            )
            worker_outputs.append(output)

        # Use the real aggregation mechanism (like MultiprocExecutor.execute_model)
        aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
        kv_cache_events = aggregated_output.kv_connector_output.kv_cache_events

        assert isinstance(kv_cache_events, LMCacheKVEvents)

        # After aggregation, events should be combined from all workers
        # The aggregator doesn't automatically aggregate events, so we need to call
        # aggregate() to get only common events
        kv_cache_events.aggregate()
        aggregated_events = kv_cache_events.get_all_events()

        # Only the common event should remain after aggregation
        # because it's the only event reported by all 3 workers
        assert len(aggregated_events) == 1
        assert aggregated_events[0] == common_event

        # Verify the common event properties
        assert aggregated_events[0].block_hashes == ["hash_common"]
        assert aggregated_events[0].parent_block_hash == "parent_common"
        assert aggregated_events[0].token_ids == [1, 2, 3]
@@ -5,7 +5,7 @@ import queue
import threading
import time
from abc import ABC, abstractmethod
from collections import deque
from collections import Counter, deque
from collections.abc import Callable
from dataclasses import asdict
from itertools import count
@@ -54,11 +54,26 @@ class BlockStored(KVCacheEvent):
    lora_id: int | None
    medium: str | None

    def __hash__(self) -> int:
        return hash(
            (
                tuple(self.block_hashes),
                self.parent_block_hash,
                tuple(self.token_ids),
                self.block_size,
                self.lora_id,
                self.medium,
            )
        )


class BlockRemoved(KVCacheEvent):
    block_hashes: list[ExternalBlockHash]
    medium: str | None

    def __hash__(self) -> int:
        return hash((tuple(self.block_hashes), self.medium))


class AllBlocksCleared(KVCacheEvent):
    pass
@@ -68,6 +83,119 @@ class KVEventBatch(EventBatch):
    events: list[BlockStored | BlockRemoved | AllBlocksCleared]


class KVEventAggregator:
    """
    Aggregates KV events across multiple workers.
    Tracks how many times each event appears and returns only those
    that were emitted by all workers.
    """

    __slots__ = ("_event_counter", "_num_workers")

    def __init__(self, num_workers: int) -> None:
        if num_workers <= 0:
            raise ValueError("num_workers must be greater than zero.")
        self._event_counter: Counter[KVCacheEvent] = Counter()
        self._num_workers: int = num_workers

    def add_events(self, events: list[KVCacheEvent]) -> None:
        """
        Add events from a worker batch.

        :param events: List of KVCacheEvent objects.
        """
        if not isinstance(events, list):
            raise TypeError("events must be a list of KVCacheEvent.")
        self._event_counter.update(events)

    def get_common_events(self) -> list[KVCacheEvent]:
        """
        Return events that appeared in all workers.

        :return: List of events present in all workers.
        """
        return [
            event
            for event, count in self._event_counter.items()
            if count == self._num_workers
        ]

    def get_all_events(self) -> list[KVCacheEvent]:
        """
        Return all events for all workers.

        :return: List of events for all workers.
        """
        return list(self._event_counter.elements())

    def clear_events(self) -> None:
        """
        Clear all tracked events.
        """
        self._event_counter.clear()

    def increment_workers(self, count: int = 1) -> None:
        """
        Increment the number of workers contributing events.

        :param count: Number to increment the workers by.
        """
        if count <= 0:
            raise ValueError("count must be positive.")
        self._num_workers += count

    def reset_workers(self) -> None:
        """
        Reset the number of workers to 1.
        """
        self._num_workers = 1

    def get_number_of_workers(self) -> int:
        """
        Return the number of workers.

        :return: int number of workers.
        """
        return self._num_workers

    def __repr__(self) -> str:
        return (
            f"<KVEventAggregator workers={self._num_workers}, "
            f"events={len(self._event_counter)}>"
        )


class KVConnectorKVEvents(ABC):
    """
    Abstract base class for KV events.
    Acts as a container for KV events from the connector.
    """

    @abstractmethod
    def add_events(self, events: list[KVCacheEvent]) -> None:
        raise NotImplementedError

    @abstractmethod
    def aggregate(self) -> "KVConnectorKVEvents":
        raise NotImplementedError

    @abstractmethod
    def increment_workers(self, count: int = 1) -> None:
        raise NotImplementedError

    @abstractmethod
    def get_all_events(self) -> list[KVCacheEvent]:
        raise NotImplementedError

    @abstractmethod
    def get_number_of_workers(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def clear_events(self) -> None:
        raise NotImplementedError


class EventPublisher(ABC):
    """Lightweight publisher for EventBatch batches with data parallelism
    support.
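To illustrate the counting semantics of KVEventAggregator above (a sketch, not part of the diff; it assumes a vLLM build containing this commit): an event counts as "common" only if every worker reported it, while get_all_events() preserves the raw multiplicity.

from vllm.distributed.kv_events import BlockStored, KVEventAggregator

shared = BlockStored(
    block_hashes=["hash_common"],
    parent_block_hash=None,
    token_ids=[1],
    block_size=16,
    lora_id=None,
    medium="GPU",
)
local = BlockStored(
    block_hashes=["hash_worker0"],
    parent_block_hash=None,
    token_ids=[2],
    block_size=16,
    lora_id=None,
    medium="GPU",
)

agg = KVEventAggregator(num_workers=2)
agg.add_events([shared, local])  # worker 0's batch
agg.add_events([shared])  # worker 1's batch

assert agg.get_common_events() == [shared]  # reported by both workers
assert len(agg.get_all_events()) == 3  # raw counts are preserved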
@@ -78,6 +78,7 @@ class KVOutputAggregator:
        finished_sending = set[str]()
        finished_recving = set[str]()
        aggregated_kv_connector_stats = None
        combined_kv_cache_events = None
        invalid_block_ids = set[int]()
        for model_runner_output in outputs:
            assert model_runner_output is not None
@@ -119,6 +120,19 @@ class KVOutputAggregator:
                        aggregated_kv_connector_stats.aggregate(kv_connector_stats)
                    )

            # Combine kv_cache_events from all workers.
            if combined_kv_cache_events is None:
                # Use the first worker's kv_cache events as start event list.
                combined_kv_cache_events = kv_output.kv_cache_events
            elif kv_cache_events := kv_output.kv_cache_events:
                assert isinstance(
                    combined_kv_cache_events,
                    type(kv_cache_events),
                )
                worker_kv_cache_events = kv_cache_events.get_all_events()
                combined_kv_cache_events.add_events(worker_kv_cache_events)
                combined_kv_cache_events.increment_workers(1)

            invalid_block_ids |= kv_output.invalid_block_ids

        # select output of the worker specified by output_rank
@@ -129,6 +143,7 @@ class KVOutputAggregator:
            finished_sending=finished_sending or None,
            finished_recving=finished_recving or None,
            kv_connector_stats=aggregated_kv_connector_stats or None,
            kv_cache_events=combined_kv_cache_events or None,
            invalid_block_ids=invalid_block_ids,
            expected_finished_count=self._expected_finished_count,
        )
@@ -49,7 +49,7 @@ from vllm.v1.outputs import KVConnectorOutput

if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.distributed.kv_events import KVCacheEvent
    from vllm.distributed.kv_events import KVCacheEvent, KVConnectorKVEvents
    from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
        KVConnectorPromMetrics,
        KVConnectorStats,
@@ -379,6 +379,14 @@ class KVConnectorBase_V1(ABC):
        """
        return None

    def get_kv_connector_kv_cache_events(self) -> Optional["KVConnectorKVEvents"]:
        """
        Get the KV connector kv cache events collected during the last interval.
        This function should be called by the model runner every time after the
        model execution and before cleanup.
        """
        return None

    def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
        """
        Get the KVConnector handshake metadata for this connector.
@@ -1,14 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any

import torch
from lmcache.integration.vllm.vllm_v1_adapter import (
    LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
)

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import (
    BlockStored,
    KVCacheEvent,
    KVConnectorKVEvents,
    KVEventAggregator,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1,
    KVConnectorMetadata,
@@ -16,6 +20,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput

if TYPE_CHECKING:
    from vllm.forward_context import ForwardContext
@@ -26,6 +31,44 @@ if TYPE_CHECKING:
logger = init_logger(__name__)


class LMCacheKVEvents(KVConnectorKVEvents):
    """
    Concrete implementation of KVConnectorKVEvents using KVEventAggregator.
    """

    def __init__(self, num_workers: int) -> None:
        self._aggregator = KVEventAggregator(num_workers)

    def add_events(self, events: list[KVCacheEvent]) -> None:
        self._aggregator.add_events(events)

    def aggregate(self) -> "LMCacheKVEvents":
        """
        Aggregate KV events and retain only common events.
        """
        common_events = self._aggregator.get_common_events()
        self._aggregator.clear_events()
        self._aggregator.add_events(common_events)
        self._aggregator.reset_workers()
        return self

    def increment_workers(self, count: int = 1) -> None:
        self._aggregator.increment_workers(count)

    def get_all_events(self) -> list[KVCacheEvent]:
        return self._aggregator.get_all_events()

    def get_number_of_workers(self) -> int:
        return self._aggregator.get_number_of_workers()

    def clear_events(self) -> None:
        self._aggregator.clear_events()
        self._aggregator.reset_workers()

    def __repr__(self) -> str:
        return f"<LMCacheKVEvents events={self.get_all_events()}>"


class LMCacheConnectorV1(KVConnectorBase_V1):
    def __init__(
        self,
@@ -50,10 +93,17 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
            cls = _adapter.LMCacheConnectorV1Impl
        else:
            logger.info("Initializing latest dev LMCache connector")
            # lazy import
            from lmcache.integration.vllm.vllm_v1_adapter import (
                LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
            )

            cls = LMCacheConnectorLatestImpl

        self._lmcache_engine = cls(vllm_config, role, self)

        self._kv_cache_events: LMCacheKVEvents | None = None

    # ==============================
    # Worker-side methods
    # ==============================
@@ -151,6 +201,31 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
            # Fallback for older versions that don't support this method
            return set()

    def get_kv_connector_kv_cache_events(self) -> LMCacheKVEvents | None:
        """
        Get the KV connector kv cache events collected during the last interval.
        """

        events = self._lmcache_engine.get_kv_events()  # type: ignore [attr-defined]
        if not events:
            return None

        blocks: list[BlockStored] = [
            BlockStored(
                block_hashes=e.block_hashes,
                parent_block_hash=e.parent_block_hash,
                token_ids=e.token_ids,
                lora_id=e.lora_id,
                block_size=e.block_size,
                medium=e.medium,
            )
            for e in events
        ]

        lmcache_kv_events = LMCacheKVEvents(num_workers=1)
        lmcache_kv_events.add_events(blocks)
        return lmcache_kv_events

    # ==============================
    # Scheduler-side methods
    # ==============================
@@ -198,6 +273,28 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
        """
        return self._lmcache_engine.build_connector_meta(scheduler_output)

    def update_connector_output(self, connector_output: KVConnectorOutput):
        """
        Update KVConnector state from worker-side connectors output.

        Args:
            connector_output (KVConnectorOutput): the worker-side
                connectors output.
        """
        # Get the KV events
        kv_cache_events = connector_output.kv_cache_events
        if not kv_cache_events or not isinstance(kv_cache_events, LMCacheKVEvents):
            return

        if self._kv_cache_events is None:
            self._kv_cache_events = kv_cache_events
        else:
            self._kv_cache_events.add_events(kv_cache_events.get_all_events())
            self._kv_cache_events.increment_workers(
                kv_cache_events.get_number_of_workers()
            )
        return

    def request_finished(
        self,
        request: "Request",
@@ -214,3 +311,17 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
        returned by the engine.
        """
        return self._lmcache_engine.request_finished(request, block_ids)

    def take_events(self) -> Iterable["KVCacheEvent"]:
        """
        Take the KV cache events from the connector.

        Yields:
            New KV cache events since the last call.
        """
        if self._kv_cache_events is not None:
            self._kv_cache_events.aggregate()
            kv_cache_events = self._kv_cache_events.get_all_events()
            yield from kv_cache_events
            self._kv_cache_events.clear_events()
            self._kv_cache_events = None
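A compact sketch of the scheduler-side cycle implemented by update_connector_output() and take_events() above, using the same mock-binding approach as the unit tests (a sketch, not part of the diff; it assumes the test file's environment):

from unittest.mock import MagicMock

from vllm.distributed.kv_events import BlockStored
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import (
    LMCacheConnectorV1,
    LMCacheKVEvents,
)
from vllm.v1.outputs import KVConnectorOutput

# Bind the real scheduler-side methods onto a mock connector.
connector = MagicMock(spec=LMCacheConnectorV1)
connector._kv_cache_events = None
connector.update_connector_output = LMCacheConnectorV1.update_connector_output.__get__(
    connector, LMCacheConnectorV1
)
connector.take_events = LMCacheConnectorV1.take_events.__get__(
    connector, LMCacheConnectorV1
)

event = BlockStored(
    block_hashes=["hash1"],
    parent_block_hash=None,
    token_ids=[1],
    block_size=16,
    lora_id=None,
    medium="GPU",
)
worker_events = LMCacheKVEvents(num_workers=1)
worker_events.add_events([event])

# Accumulate a worker's events, then aggregate, yield, and clear them.
connector.update_connector_output(KVConnectorOutput(kv_cache_events=worker_events))
assert list(connector.take_events()) == [event]
assert connector._kv_cache_events is None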
@@ -259,6 +259,12 @@ class MultiConnector(KVConnectorBase_V1):
            agg_block_ids |= c.get_block_ids_with_load_errors()
        return agg_block_ids

    # TODO: Add a generic implementation of 'get_kv_connector_kv_cache_events' method
    # for the MultiConnector. It should be able to get events from multiple
    # connectors, handling the case where only a subset of the requested connectors
    # implements the 'get_kv_connector_kv_cache_events'
    # Follow on PR from https://github.com/vllm-project/vllm/pull/28309#pullrequestreview-3566351082

    # ==============================
    # Scheduler-side methods
    # ==============================
@@ -12,9 +12,11 @@ from vllm.compilation.cuda_graph import CUDAGraphStat
from vllm.v1.core.sched.output import SchedulerOutput

if TYPE_CHECKING:
    from vllm.distributed.kv_events import KVConnectorKVEvents
    from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
else:
    KVConnectorStats = object
    KVConnectorKVEvents = object


class LogprobsLists(NamedTuple):
@@ -108,6 +110,7 @@ class KVConnectorOutput:
    finished_sending: set[str] | None = None
    finished_recving: set[str] | None = None
    kv_connector_stats: KVConnectorStats | None = None
    kv_cache_events: KVConnectorKVEvents | None = None
    # IDs of externally computed KV blocks that failed to load.
    # Requests referencing these blocks should be rescheduled to recompute them
    invalid_block_ids: set[int] = field(default_factory=set)
@@ -123,6 +126,7 @@ class KVConnectorOutput:
            not self.finished_sending
            and not self.finished_recving
            and not self.kv_connector_stats
            and not self.kv_cache_events
            and not self.invalid_block_ids
        )
@@ -22,7 +22,6 @@ from vllm.distributed.kv_transfer import (
    has_kv_transfer_group,
)
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig
@@ -138,16 +137,10 @@ class KVConnectorModelRunnerMixin:
            )
            output.invalid_block_ids = kv_connector.get_block_ids_with_load_errors()

            output.kv_connector_stats = (
                KVConnectorModelRunnerMixin.get_kv_connector_stats()
            )
            kv_connector.clear_connector_metadata()
            output.kv_connector_stats = kv_connector.get_kv_connector_stats()
            output.kv_cache_events = kv_connector.get_kv_connector_kv_cache_events()

    @staticmethod
    def get_kv_connector_stats() -> KVConnectorStats | None:
        if has_kv_transfer_group():
            return get_kv_transfer_group().get_kv_connector_stats()
        return None
            kv_connector.clear_connector_metadata()

    @staticmethod
    def use_uniform_kv_cache(