[KVConnector] Add KV events to KV Connectors (#28309)

Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
This commit is contained in:
Martin Hickey 2025-12-11 14:30:29 +00:00 committed by GitHub
parent a11f4a81e0
commit f4417f8449
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 1036 additions and 15 deletions

View File

@ -0,0 +1,756 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock
import pytest
from vllm.distributed.kv_events import BlockStored
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import (
LMCacheConnectorV1,
LMCacheKVEvents,
)
from vllm.v1.outputs import KVConnectorOutput
@pytest.fixture
def mock_lmcache_engine_event():
    """Build one stand-in for an event object returned by the lmcache engine.

    The object exposes the six attributes (``block_hashes``,
    ``parent_block_hash``, ``token_ids``, ``lora_id``, ``block_size``,
    ``medium``) that the tests in this module read back after conversion.
    """

    class MockEvent:
        """Plain attribute holder shaped like an lmcache engine event."""

        def __init__(
            self,
            block_hashes,
            parent_block_hash,
            token_ids,
            lora_id,
            block_size,
            medium,
        ):
            self.block_hashes = block_hashes
            self.parent_block_hash = parent_block_hash
            self.token_ids = token_ids
            self.lora_id = lora_id
            self.block_size = block_size
            self.medium = medium

    # Field values are kept in one mapping so the fixture's payload is easy
    # to compare against the assertions in test_converts_single_event.
    payload = {
        "block_hashes": ["hash1", "hash2"],
        "parent_block_hash": "parent_hash",
        "token_ids": [1, 2, 3, 4],
        "lora_id": None,
        "block_size": 16,
        "medium": "GPU",
    }
    return MockEvent(**payload)
@pytest.fixture
def mock_connector():
    """Return a spec'd LMCacheConnectorV1 mock with real event methods bound.

    The connector itself and its ``_lmcache_engine`` are mocks, while
    ``get_kv_connector_kv_cache_events``, ``update_connector_output`` and
    ``take_events`` are the genuine class implementations bound to the mock.
    """
    fake_connector = MagicMock(spec=LMCacheConnectorV1)
    fake_connector._kv_cache_events = None
    fake_connector._lmcache_engine = MagicMock()

    # Bind the real implementations through the descriptor protocol so that
    # ``self`` inside each method is the mock created above.
    for method_name in (
        "get_kv_connector_kv_cache_events",
        "update_connector_output",
        "take_events",
    ):
        real_function = getattr(LMCacheConnectorV1, method_name)
        setattr(
            fake_connector,
            method_name,
            real_function.__get__(fake_connector, LMCacheConnectorV1),
        )

    return fake_connector
class TestGetKVConnectorKVCacheEvents:
    """Exercise get_kv_connector_kv_cache_events on the mocked connector."""

    def test_returns_none_when_no_events(self, mock_connector):
        """``None`` from the engine yields ``None`` and one engine call."""
        engine = mock_connector._lmcache_engine
        engine.get_kv_events.return_value = None

        outcome = mock_connector.get_kv_connector_kv_cache_events()

        assert outcome is None
        engine.get_kv_events.assert_called_once()

    def test_returns_none_when_empty_list(self, mock_connector):
        """An empty list from the engine also yields ``None``."""
        mock_connector._lmcache_engine.get_kv_events.return_value = []

        outcome = mock_connector.get_kv_connector_kv_cache_events()

        assert outcome is None

    def test_converts_single_event(self, mock_connector, mock_lmcache_engine_event):
        """One engine event becomes one BlockStored with matching fields."""
        engine = mock_connector._lmcache_engine
        engine.get_kv_events.return_value = [mock_lmcache_engine_event]

        converted = mock_connector.get_kv_connector_kv_cache_events()

        assert converted is not None
        assert isinstance(converted, LMCacheKVEvents)
        assert converted.get_number_of_workers() == 1

        stored_events = converted.get_all_events()
        assert len(stored_events) == 1
        stored = stored_events[0]
        assert isinstance(stored, BlockStored)
        assert stored.block_hashes == ["hash1", "hash2"]
        assert stored.parent_block_hash == "parent_hash"
        assert stored.token_ids == [1, 2, 3, 4]
        assert stored.lora_id is None
        assert stored.block_size == 16
        assert stored.medium == "GPU"

    def test_converts_multiple_events(self, mock_connector):
        """Several engine events are all converted, in the order supplied."""

        class MockEvent:
            def __init__(self, i):
                self.block_hashes = [f"hash{i}"]
                self.parent_block_hash = f"parent{i}"
                self.token_ids = [i]
                self.lora_id = None
                self.block_size = 16
                self.medium = "GPU"

        engine_events = [MockEvent(i) for i in range(5)]
        mock_connector._lmcache_engine.get_kv_events.return_value = engine_events

        converted = mock_connector.get_kv_connector_kv_cache_events()

        assert converted is not None
        assert isinstance(converted, LMCacheKVEvents)

        stored_events = converted.get_all_events()
        assert len(stored_events) == 5
        for index, stored in enumerate(stored_events):
            assert isinstance(stored, BlockStored)
            assert stored.block_hashes == [f"hash{index}"]
            assert stored.parent_block_hash == f"parent{index}"
            assert stored.token_ids == [index]

    def test_preserves_event_attributes(self, mock_connector):
        """Non-default values (LoRA id, block size, medium) survive conversion."""

        class MockEventWithLora:
            def __init__(self):
                self.block_hashes = ["hash_a", "hash_b", "hash_c"]
                self.parent_block_hash = "parent_xyz"
                self.token_ids = [100, 200, 300]
                self.lora_id = 42
                self.block_size = 32
                self.medium = "DISK"

        engine = mock_connector._lmcache_engine
        engine.get_kv_events.return_value = [MockEventWithLora()]

        converted = mock_connector.get_kv_connector_kv_cache_events()
        stored = converted.get_all_events()[0]

        assert stored.block_hashes == ["hash_a", "hash_b", "hash_c"]
        assert stored.parent_block_hash == "parent_xyz"
        assert stored.token_ids == [100, 200, 300]
        assert stored.lora_id == 42
        assert stored.block_size == 32
        assert stored.medium == "DISK"

    def test_handles_none_parent_block_hash(self, mock_connector):
        """An event without a parent hash keeps ``parent_block_hash`` as None."""

        class MockEventNoParent:
            def __init__(self):
                self.block_hashes = ["hash1"]
                self.parent_block_hash = None
                self.token_ids = [1, 2]
                self.lora_id = None
                self.block_size = 16
                self.medium = "GPU"

        engine = mock_connector._lmcache_engine
        engine.get_kv_events.return_value = [MockEventNoParent()]

        converted = mock_connector.get_kv_connector_kv_cache_events()
        stored_events = converted.get_all_events()

        assert stored_events[0].parent_block_hash is None
class TestUpdateConnectorOutput:
    """Exercise update_connector_output on the mocked connector."""

    @staticmethod
    def _stored(block_hash, token_ids):
        """Build a parentless BlockStored (block_size 16, no LoRA, GPU)."""
        return BlockStored(
            block_hashes=[block_hash],
            parent_block_hash=None,
            token_ids=token_ids,
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

    @staticmethod
    def _push(connector, kv_cache_events):
        """Wrap the events in a KVConnectorOutput and hand it to the connector."""
        connector_output = KVConnectorOutput(kv_cache_events=kv_cache_events)
        connector.update_connector_output(connector_output)

    def test_does_nothing_when_kv_cache_events_is_none(self, mock_connector):
        """An output carrying no events leaves the connector state untouched."""
        self._push(mock_connector, None)

        assert mock_connector._kv_cache_events is None

    def test_does_nothing_when_kv_cache_events_is_not_lmcache_kv_events(
        self, mock_connector
    ):
        """An events object that is not LMCacheKVEvents is ignored."""
        # A bare MagicMock is truthy but fails the isinstance check.
        foreign_events = MagicMock()
        self._push(mock_connector, foreign_events)

        assert mock_connector._kv_cache_events is None

    def test_sets_kv_cache_events_when_none(self, mock_connector):
        """The first LMCacheKVEvents received is stored as-is (same object)."""
        incoming = LMCacheKVEvents(num_workers=1)
        incoming.add_events([self._stored("hash1", [1, 2])])

        self._push(mock_connector, incoming)

        assert mock_connector._kv_cache_events is incoming

    def test_adds_events_when_kv_cache_events_already_exists(self, mock_connector):
        """New events are merged into an already-present container."""
        first = self._stored("hash1", [1])
        already_held = LMCacheKVEvents(num_workers=2)
        # Two separate add_events calls simulate two workers reporting.
        already_held.add_events([first])
        already_held.add_events([first])
        mock_connector._kv_cache_events = already_held

        second = self._stored("hash2", [2])
        incoming = LMCacheKVEvents(num_workers=1)
        incoming.add_events([second])
        self._push(mock_connector, incoming)

        merged = mock_connector._kv_cache_events.get_all_events()
        assert len(merged) == 3  # 2 from existing + 1 from new
        assert first in merged
        assert second in merged

    def test_increments_workers_when_kv_cache_events_already_exists(
        self, mock_connector
    ):
        """The incoming worker count is added to the held worker count."""
        mock_connector._kv_cache_events = LMCacheKVEvents(num_workers=2)

        incoming = LMCacheKVEvents(num_workers=3)
        incoming.add_events([self._stored("hash1", [1])])
        self._push(mock_connector, incoming)

        # Worker count should be 2 + 3 = 5
        assert mock_connector._kv_cache_events.get_number_of_workers() == 5

    def test_multiple_updates(self, mock_connector):
        """Three consecutive updates accumulate events and worker counts."""
        updates = (
            (1, "hash1", [1]),
            (2, "hash2", [2]),
            (1, "hash3", [3]),
        )
        for worker_count, block_hash, token_ids in updates:
            incoming = LMCacheKVEvents(num_workers=worker_count)
            incoming.add_events([self._stored(block_hash, token_ids)])
            self._push(mock_connector, incoming)

        held = mock_connector._kv_cache_events
        assert len(held.get_all_events()) == 3
        assert held.get_number_of_workers() == 4  # 1+2+1

    def test_updates_with_empty_events(self, mock_connector):
        """An update without events still contributes its worker count."""
        populated = LMCacheKVEvents(num_workers=1)
        populated.add_events([self._stored("hash1", [1])])
        self._push(mock_connector, populated)

        # Second container deliberately receives no events.
        empty = LMCacheKVEvents(num_workers=2)
        self._push(mock_connector, empty)

        held = mock_connector._kv_cache_events
        assert len(held.get_all_events()) == 1
        assert held.get_number_of_workers() == 3
class TestTakeEvents:
    """Exercise take_events on the mocked connector."""

    @staticmethod
    def _stored(block_hash, token_ids):
        """Build a parentless BlockStored (block_size 16, no LoRA, GPU)."""
        return BlockStored(
            block_hashes=[block_hash],
            parent_block_hash=None,
            token_ids=token_ids,
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

    def test_yields_nothing_when_kv_cache_events_is_none(self, mock_connector):
        """With no stored container the iterator is empty."""
        mock_connector._kv_cache_events = None

        drained = list(mock_connector.take_events())

        assert drained == []

    def test_yields_events_and_clears(self, mock_connector):
        """Stored events are yielded and the connector state is reset."""
        first = self._stored("hash1", [1])
        second = self._stored("hash2", [2])
        held = LMCacheKVEvents(num_workers=1)
        held.add_events([first, second])
        mock_connector._kv_cache_events = held

        drained = list(mock_connector.take_events())

        assert len(drained) == 2
        assert first in drained
        assert second in drained
        # Draining must leave nothing behind on the connector.
        assert mock_connector._kv_cache_events is None

    def test_aggregates_before_yielding(self, mock_connector):
        """Only events reported once per worker survive to be yielded."""
        common_event = self._stored("hash_common", [1])
        uncommon_event = self._stored("hash_uncommon", [2])

        held = LMCacheKVEvents(num_workers=3)
        # All 3 workers report common_event.
        for _ in range(3):
            held.add_events([common_event])
        # Only 1 worker reports uncommon_event.
        held.add_events([uncommon_event])
        mock_connector._kv_cache_events = held

        drained = list(mock_connector.take_events())

        assert len(drained) == 1
        assert drained[0] == common_event

    def test_multiple_take_events_calls(self, mock_connector):
        """take_events can be called repeatedly across refills."""
        # Round 1: one event available.
        first = self._stored("hash1", [1])
        round_one = LMCacheKVEvents(num_workers=1)
        round_one.add_events([first])
        mock_connector._kv_cache_events = round_one

        drained_one = list(mock_connector.take_events())
        assert len(drained_one) == 1
        assert drained_one[0] == first
        assert mock_connector._kv_cache_events is None

        # Round 2: nothing was added in between.
        drained_two = list(mock_connector.take_events())
        assert drained_two == []

        # Round 3: a fresh container is installed.
        second = self._stored("hash2", [2])
        round_three = LMCacheKVEvents(num_workers=1)
        round_three.add_events([second])
        mock_connector._kv_cache_events = round_three

        drained_three = list(mock_connector.take_events())
        assert len(drained_three) == 1
        assert drained_three[0] == second

    def test_yields_empty_after_aggregation_removes_all(self, mock_connector):
        """Nothing is yielded when the workers share no event."""
        first = self._stored("hash1", [1])
        second = self._stored("hash2", [2])

        held = LMCacheKVEvents(num_workers=2)
        held.add_events([first])  # worker 1
        held.add_events([second])  # worker 2
        mock_connector._kv_cache_events = held

        drained = list(mock_connector.take_events())

        assert drained == []
        assert mock_connector._kv_cache_events is None
class TestIntegrationScenarios:
    """Scenarios chaining event collection, merging and draining."""

    def test_full_workflow(self, mock_connector, mock_lmcache_engine_event):
        """Test a complete workflow from getting events to taking them."""
        # Step 1: Get events from lmcache engine
        mock_connector._lmcache_engine.get_kv_events.return_value = [
            mock_lmcache_engine_event
        ]
        kv_events = mock_connector.get_kv_connector_kv_cache_events()

        assert kv_events is not None
        assert len(kv_events.get_all_events()) == 1

        # Step 2: Update connector output (simulate receiving from worker)
        output1 = KVConnectorOutput(kv_cache_events=kv_events)
        mock_connector.update_connector_output(output1)

        assert mock_connector._kv_cache_events is not None

        # Step 3: Take events
        taken_events = list(mock_connector.take_events())

        assert len(taken_events) == 1
        assert mock_connector._kv_cache_events is None

    def test_multiple_workers_workflow(self, mock_connector):
        """Test workflow with multiple workers."""

        class MockEvent:
            def __init__(self, hash_val):
                self.block_hashes = [hash_val]
                self.parent_block_hash = None
                self.token_ids = [1]
                self.lora_id = None
                self.block_size = 16
                self.medium = "GPU"

        # Each worker reports the shared block plus one block of its own.
        for own_hash in ("hash_worker1", "hash_worker2"):
            mock_connector._lmcache_engine.get_kv_events.return_value = [
                MockEvent("hash_common"),
                MockEvent(own_hash),
            ]
            kv_events = mock_connector.get_kv_connector_kv_cache_events()
            output = KVConnectorOutput(kv_cache_events=kv_events)
            mock_connector.update_connector_output(output)

        # Take events (should only get common events)
        taken_events = list(mock_connector.take_events())

        # With aggregation, only events reported by both workers may remain.
        # Assert the exact result rather than mere membership, so that a
        # regression leaking hash_worker1 / hash_worker2 fails this test.
        event_hashes = [e.block_hashes[0] for e in taken_events]
        assert event_hashes == ["hash_common"]

    def test_empty_workflow(self, mock_connector):
        """Test workflow when there are no events at any stage."""
        # Get events returns None
        mock_connector._lmcache_engine.get_kv_events.return_value = None
        kv_events = mock_connector.get_kv_connector_kv_cache_events()

        assert kv_events is None

        # Update with None
        output = KVConnectorOutput(kv_cache_events=None)
        mock_connector.update_connector_output(output)

        # Take events
        taken_events = list(mock_connector.take_events())

        assert taken_events == []
        assert mock_connector._kv_cache_events is None

    def test_repeated_cycles(self, mock_connector):
        """Test multiple cycles of the complete workflow."""

        class MockEvent:
            def __init__(self, cycle_num):
                self.block_hashes = [f"hash_cycle_{cycle_num}"]
                self.parent_block_hash = None
                self.token_ids = [cycle_num]
                self.lora_id = None
                self.block_size = 16
                self.medium = "GPU"

        for cycle in range(3):
            # Get events
            mock_connector._lmcache_engine.get_kv_events.return_value = [
                MockEvent(cycle)
            ]
            kv_events = mock_connector.get_kv_connector_kv_cache_events()

            # Update
            output = KVConnectorOutput(kv_cache_events=kv_events)
            mock_connector.update_connector_output(output)

            # Take
            taken_events = list(mock_connector.take_events())

            # Verify
            assert len(taken_events) == 1
            assert taken_events[0].block_hashes[0] == f"hash_cycle_{cycle}"
            assert mock_connector._kv_cache_events is None

    def test_lmcache_kv_events_aggregation(self):
        """
        Test LMCacheKVEvents aggregation across TP ranks using
        KVOutputAggregator (used by MultiprocExecutor).
        """
        from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
        from vllm.v1.outputs import ModelRunnerOutput

        # Create KVOutputAggregator for 3 workers (simulating TP=3)
        aggregator = KVOutputAggregator(expected_finished_count=3)

        # The event every worker reports.
        common_event = BlockStored(
            block_hashes=["hash_common"],
            parent_block_hash="parent_common",
            token_ids=[1, 2, 3],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        # One extra event per worker that no other worker reports. The list
        # index is the worker rank; the hash suffix is only a label.
        unique_events = [
            BlockStored(
                block_hashes=[f"hash_worker{label}"],
                parent_block_hash=f"parent_w{label}",
                token_ids=token_ids,
                block_size=16,
                lora_id=None,
                medium="GPU",
            )
            for label, token_ids in ((1, [4, 5]), (2, [6, 7]), (3, [8, 9]))
        ]

        # Every worker reports the common event plus its own unique event.
        per_worker_events = []
        for unique_event in unique_events:
            worker_events = LMCacheKVEvents(num_workers=1)
            worker_events.add_events([common_event, unique_event])
            per_worker_events.append(worker_events)

        # Create ModelRunnerOutput instances for each worker
        worker_outputs = []
        for i, worker_events in enumerate(per_worker_events):
            # Workers 0,1 finished sending; workers 1,2 finished receiving.
            finished_sending = {f"req_{i}_send"} if i < 2 else None
            finished_recving = {f"req_{i}_recv"} if i > 0 else None
            output = ModelRunnerOutput(
                req_ids=[f"req_{i}"],
                req_id_to_index={f"req_{i}": 0},
                sampled_token_ids=[[123]],  # dummy token
                logprobs=None,
                prompt_logprobs_dict={},
                pooler_output=[None],
                kv_connector_output=KVConnectorOutput(
                    finished_sending=finished_sending,
                    finished_recving=finished_recving,
                    kv_cache_events=worker_events,
                ),
            )
            worker_outputs.append(output)

        # Use the real aggregation mechanism (like MultiprocExecutor.execute_model)
        aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
        kv_cache_events = aggregated_output.kv_connector_output.kv_cache_events

        assert isinstance(kv_cache_events, LMCacheKVEvents)

        # The aggregator only combines the per-worker event lists; reducing
        # them to the common events requires an explicit aggregate() call.
        kv_cache_events.aggregate()
        aggregated_events = kv_cache_events.get_all_events()

        # Only the common event should remain after aggregation
        # because it's the only event reported by all 3 workers
        assert len(aggregated_events) == 1
        assert aggregated_events[0] == common_event

        # Verify the common event properties
        assert aggregated_events[0].block_hashes == ["hash_common"]
        assert aggregated_events[0].parent_block_hash == "parent_common"
        assert aggregated_events[0].token_ids == [1, 2, 3]

View File

@ -5,7 +5,7 @@ import queue
import threading import threading
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import deque from collections import Counter, deque
from collections.abc import Callable from collections.abc import Callable
from dataclasses import asdict from dataclasses import asdict
from itertools import count from itertools import count
@ -54,11 +54,26 @@ class BlockStored(KVCacheEvent):
lora_id: int | None lora_id: int | None
medium: str | None medium: str | None
def __hash__(self) -> int:
    """Hash the event by the values of all six of its fields."""
    # block_hashes and token_ids are converted to tuples before hashing
    # because a list cannot be an element of a hashed tuple.
    identity = (
        tuple(self.block_hashes),
        self.parent_block_hash,
        tuple(self.token_ids),
        self.block_size,
        self.lora_id,
        self.medium,
    )
    return hash(identity)
class BlockRemoved(KVCacheEvent): class BlockRemoved(KVCacheEvent):
block_hashes: list[ExternalBlockHash] block_hashes: list[ExternalBlockHash]
medium: str | None medium: str | None
def __hash__(self) -> int:
    """Hash the event by its block hashes and storage medium."""
    # block_hashes is converted to a tuple so it can be hashed.
    identity = (tuple(self.block_hashes), self.medium)
    return hash(identity)
class AllBlocksCleared(KVCacheEvent): class AllBlocksCleared(KVCacheEvent):
pass pass
@ -68,6 +83,119 @@ class KVEventBatch(EventBatch):
events: list[BlockStored | BlockRemoved | AllBlocksCleared] events: list[BlockStored | BlockRemoved | AllBlocksCleared]
class KVEventAggregator:
    """
    Count KV events reported by several workers.

    Every added event increases a per-event tally. An event is considered
    common when its tally equals the configured number of workers.
    """

    __slots__ = ("_event_counter", "_num_workers")

    def __init__(self, num_workers: int) -> None:
        """
        Create an empty aggregator.

        :param num_workers: Number of workers expected to report events.
        :raises ValueError: If ``num_workers`` is zero or negative.
        """
        if num_workers <= 0:
            raise ValueError("num_workers must be greater than zero.")
        self._num_workers: int = num_workers
        self._event_counter: Counter[KVCacheEvent] = Counter()

    def add_events(self, events: list[KVCacheEvent]) -> None:
        """
        Record one batch of events, bumping the tally of each by one
        per occurrence in the batch.

        :param events: List of KVCacheEvent objects.
        :raises TypeError: If ``events`` is not a ``list``.
        """
        if isinstance(events, list):
            self._event_counter.update(events)
            return
        raise TypeError("events must be a list of KVCacheEvent.")

    def get_common_events(self) -> list[KVCacheEvent]:
        """
        Return the events whose tally equals the number of workers.

        :return: List of events present in all workers.
        """
        required = self._num_workers
        return [
            event
            for event, times_seen in self._event_counter.items()
            if times_seen == required
        ]

    def get_all_events(self) -> list[KVCacheEvent]:
        """
        Return every recorded event, repeated once per tally.

        :return: List of events for all workers.
        """
        return [*self._event_counter.elements()]

    def clear_events(self) -> None:
        """
        Drop all recorded events; the worker count is left unchanged.
        """
        self._event_counter.clear()

    def increment_workers(self, count: int = 1) -> None:
        """
        Raise the number of workers contributing events.

        :param count: Number to increment the workers by.
        :raises ValueError: If ``count`` is zero or negative.
        """
        if count <= 0:
            raise ValueError("count must be positive.")
        self._num_workers = self._num_workers + count

    def reset_workers(self) -> None:
        """
        Set the number of workers back to 1.
        """
        self._num_workers = 1

    def get_number_of_workers(self) -> int:
        """
        Return the current number of workers.

        :return: int number of workers.
        """
        return self._num_workers

    def __repr__(self) -> str:
        worker_part = f"<KVEventAggregator workers={self._num_workers}, "
        event_part = f"events={len(self._event_counter)}>"
        return worker_part + event_part
class KVConnectorKVEvents(ABC):
    """
    Interface for a container of KV events coming from a connector.

    Every method is abstract; a concrete subclass must provide all six.
    """

    @abstractmethod
    def add_events(self, events: list[KVCacheEvent]) -> None:
        """Add a list of events to the container."""
        raise NotImplementedError

    @abstractmethod
    def aggregate(self) -> "KVConnectorKVEvents":
        """Aggregate the held events and return a container instance."""
        raise NotImplementedError

    @abstractmethod
    def increment_workers(self, count: int = 1) -> None:
        """Increase the tracked number of workers by ``count``."""
        raise NotImplementedError

    @abstractmethod
    def get_all_events(self) -> list[KVCacheEvent]:
        """Return the events currently held."""
        raise NotImplementedError

    @abstractmethod
    def get_number_of_workers(self) -> int:
        """Return the tracked number of workers."""
        raise NotImplementedError

    @abstractmethod
    def clear_events(self) -> None:
        """Remove the events currently held."""
        raise NotImplementedError
class EventPublisher(ABC): class EventPublisher(ABC):
"""Lightweight publisher for EventBatch batches with data parallelism """Lightweight publisher for EventBatch batches with data parallelism
support. support.

View File

@ -78,6 +78,7 @@ class KVOutputAggregator:
finished_sending = set[str]() finished_sending = set[str]()
finished_recving = set[str]() finished_recving = set[str]()
aggregated_kv_connector_stats = None aggregated_kv_connector_stats = None
combined_kv_cache_events = None
invalid_block_ids = set[int]() invalid_block_ids = set[int]()
for model_runner_output in outputs: for model_runner_output in outputs:
assert model_runner_output is not None assert model_runner_output is not None
@ -119,6 +120,19 @@ class KVOutputAggregator:
aggregated_kv_connector_stats.aggregate(kv_connector_stats) aggregated_kv_connector_stats.aggregate(kv_connector_stats)
) )
# Combine kv_cache_events from all workers.
if combined_kv_cache_events is None:
# Use the first worker's kv_cache events as start event list.
combined_kv_cache_events = kv_output.kv_cache_events
elif kv_cache_events := kv_output.kv_cache_events:
assert isinstance(
combined_kv_cache_events,
type(kv_cache_events),
)
worker_kv_cache_events = kv_cache_events.get_all_events()
combined_kv_cache_events.add_events(worker_kv_cache_events)
combined_kv_cache_events.increment_workers(1)
invalid_block_ids |= kv_output.invalid_block_ids invalid_block_ids |= kv_output.invalid_block_ids
# select output of the worker specified by output_rank # select output of the worker specified by output_rank
@ -129,6 +143,7 @@ class KVOutputAggregator:
finished_sending=finished_sending or None, finished_sending=finished_sending or None,
finished_recving=finished_recving or None, finished_recving=finished_recving or None,
kv_connector_stats=aggregated_kv_connector_stats or None, kv_connector_stats=aggregated_kv_connector_stats or None,
kv_cache_events=combined_kv_cache_events or None,
invalid_block_ids=invalid_block_ids, invalid_block_ids=invalid_block_ids,
expected_finished_count=self._expected_finished_count, expected_finished_count=self._expected_finished_count,
) )

View File

@ -49,7 +49,7 @@ from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_events import KVCacheEvent, KVConnectorKVEvents
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics, KVConnectorPromMetrics,
KVConnectorStats, KVConnectorStats,
@ -379,6 +379,14 @@ class KVConnectorBase_V1(ABC):
""" """
return None return None
def get_kv_connector_kv_cache_events(self) -> Optional["KVConnectorKVEvents"]:
    """
    Return the KV cache events the connector collected during the last
    interval.

    The model runner is expected to call this after every model execution
    and before cleanup. This default implementation reports no events.

    Returns:
        Always ``None`` here; an overriding connector may return a
        ``KVConnectorKVEvents`` container instead.
    """
    return None
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None: def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
""" """
Get the KVConnector handshake metadata for this connector. Get the KVConnector handshake metadata for this connector.

View File

@ -1,14 +1,18 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
import torch import torch
from lmcache.integration.vllm.vllm_v1_adapter import (
LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
)
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_events import (
BlockStored,
KVCacheEvent,
KVConnectorKVEvents,
KVEventAggregator,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorBase_V1,
KVConnectorMetadata, KVConnectorMetadata,
@ -16,6 +20,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
@ -26,6 +31,44 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
class LMCacheKVEvents(KVConnectorKVEvents):
    """
    KVConnectorKVEvents implementation that delegates all bookkeeping to a
    KVEventAggregator.
    """

    def __init__(self, num_workers: int) -> None:
        """Create the container backed by an aggregator for ``num_workers``."""
        self._aggregator = KVEventAggregator(num_workers)

    def add_events(self, events: list[KVCacheEvent]) -> None:
        """Forward ``events`` to the underlying aggregator."""
        self._aggregator.add_events(events)

    def aggregate(self) -> "LMCacheKVEvents":
        """
        Reduce the held events to the common ones, in place.

        The common events are re-added to an emptied aggregator and the
        worker count is reset, then this same instance is returned.
        """
        aggregator = self._aggregator
        # Read the common events first: clearing discards the counts they
        # are derived from.
        shared_events = aggregator.get_common_events()
        aggregator.clear_events()
        aggregator.add_events(shared_events)
        aggregator.reset_workers()
        return self

    def increment_workers(self, count: int = 1) -> None:
        """Increase the aggregator's worker count by ``count``."""
        self._aggregator.increment_workers(count)

    def get_all_events(self) -> list[KVCacheEvent]:
        """Return all events held by the aggregator."""
        return self._aggregator.get_all_events()

    def get_number_of_workers(self) -> int:
        """Return the aggregator's worker count."""
        return self._aggregator.get_number_of_workers()

    def clear_events(self) -> None:
        """Drop all held events and reset the worker count."""
        aggregator = self._aggregator
        aggregator.clear_events()
        aggregator.reset_workers()

    def __repr__(self) -> str:
        held_events = self.get_all_events()
        return f"<LMCacheKVEvents events={held_events}>"
class LMCacheConnectorV1(KVConnectorBase_V1): class LMCacheConnectorV1(KVConnectorBase_V1):
def __init__( def __init__(
self, self,
@ -50,10 +93,17 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
cls = _adapter.LMCacheConnectorV1Impl cls = _adapter.LMCacheConnectorV1Impl
else: else:
logger.info("Initializing latest dev LMCache connector") logger.info("Initializing latest dev LMCache connector")
# lazy import
from lmcache.integration.vllm.vllm_v1_adapter import (
LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
)
cls = LMCacheConnectorLatestImpl cls = LMCacheConnectorLatestImpl
self._lmcache_engine = cls(vllm_config, role, self) self._lmcache_engine = cls(vllm_config, role, self)
self._kv_cache_events: LMCacheKVEvents | None = None
# ============================== # ==============================
# Worker-side methods # Worker-side methods
# ============================== # ==============================
@ -151,6 +201,31 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
# Fallback for older versions that don't support this method # Fallback for older versions that don't support this method
return set() return set()
def get_kv_connector_kv_cache_events(self) -> LMCacheKVEvents | None:
    """
    Get the KV connector kv cache events collected during the last interval.

    Returns:
        An ``LMCacheKVEvents`` created with ``num_workers=1`` that holds one
        ``BlockStored`` per event reported by the LMCache engine, or ``None``
        when the engine reports no events or does not provide
        ``get_kv_events`` at all.
    """
    # Fallback for older engine implementations that don't support this
    # method: treat a missing ``get_kv_events`` as "no events" rather than
    # raising AttributeError.
    get_kv_events = getattr(self._lmcache_engine, "get_kv_events", None)
    if get_kv_events is None:
        return None

    events = get_kv_events()
    if not events:
        return None

    # Convert each engine event into a BlockStored, copying its fields as-is.
    blocks: list[BlockStored] = [
        BlockStored(
            block_hashes=e.block_hashes,
            parent_block_hash=e.parent_block_hash,
            token_ids=e.token_ids,
            lora_id=e.lora_id,
            block_size=e.block_size,
            medium=e.medium,
        )
        for e in events
    ]

    lmcache_kv_events = LMCacheKVEvents(num_workers=1)
    lmcache_kv_events.add_events(blocks)
    return lmcache_kv_events
# ============================== # ==============================
# Scheduler-side methods # Scheduler-side methods
# ============================== # ==============================
@ -198,6 +273,28 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
""" """
return self._lmcache_engine.build_connector_meta(scheduler_output) return self._lmcache_engine.build_connector_meta(scheduler_output)
def update_connector_output(self, connector_output: KVConnectorOutput):
    """
    Fold the KV cache events carried by a worker-side output into this
    connector's pending events.

    Args:
        connector_output (KVConnectorOutput): the worker-side
            connectors output.
    """
    incoming = connector_output.kv_cache_events

    # Nothing to do unless the output carries a truthy LMCacheKVEvents.
    if not (incoming and isinstance(incoming, LMCacheKVEvents)):
        return

    pending = self._kv_cache_events
    if pending is None:
        # First batch since the last reset: adopt the incoming object as-is.
        self._kv_cache_events = incoming
        return

    # Merge the incoming events and account for the workers they came from.
    pending.add_events(incoming.get_all_events())
    pending.increment_workers(incoming.get_number_of_workers())
def request_finished( def request_finished(
self, self,
request: "Request", request: "Request",
@ -214,3 +311,17 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
returned by the engine. returned by the engine.
""" """
return self._lmcache_engine.request_finished(request, block_ids) return self._lmcache_engine.request_finished(request, block_ids)
def take_events(self) -> Iterable["KVCacheEvent"]:
    """
    Take the KV cache events from the connector.

    The pending events are aggregated, snapshotted, and the connector's
    pending state is reset *before* the first event is yielded. A consumer
    that stops iterating early therefore cannot leave stale events behind
    to be merged into, and re-emitted by, a later call.

    Yields:
        New KV cache events since the last call.
    """
    pending = self._kv_cache_events
    if pending is None:
        return

    pending.aggregate()
    # Copy so the snapshot is independent of the container being cleared.
    events = list(pending.get_all_events())

    # Reset state up front; code placed after ``yield from`` would be
    # skipped if the generator were closed before exhaustion.
    pending.clear_events()
    self._kv_cache_events = None

    yield from events

View File

@ -259,6 +259,12 @@ class MultiConnector(KVConnectorBase_V1):
agg_block_ids |= c.get_block_ids_with_load_errors() agg_block_ids |= c.get_block_ids_with_load_errors()
return agg_block_ids return agg_block_ids
# TODO: Add a generic implementation of the 'get_kv_connector_kv_cache_events'
# method for the MultiConnector. It should be able to get events from multiple
# connectors and handle the case where only a subset of the requested
# connectors implement 'get_kv_connector_kv_cache_events'.
# Follow-on PR from https://github.com/vllm-project/vllm/pull/28309#pullrequestreview-3566351082
# ============================== # ==============================
# Scheduler-side methods # Scheduler-side methods
# ============================== # ==============================

View File

@ -12,9 +12,11 @@ from vllm.compilation.cuda_graph import CUDAGraphStat
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.distributed.kv_events import KVConnectorKVEvents
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
else: else:
KVConnectorStats = object KVConnectorStats = object
KVConnectorKVEvents = object
class LogprobsLists(NamedTuple): class LogprobsLists(NamedTuple):
@ -108,6 +110,7 @@ class KVConnectorOutput:
finished_sending: set[str] | None = None finished_sending: set[str] | None = None
finished_recving: set[str] | None = None finished_recving: set[str] | None = None
kv_connector_stats: KVConnectorStats | None = None kv_connector_stats: KVConnectorStats | None = None
kv_cache_events: KVConnectorKVEvents | None = None
# IDs of externally computed KV blocks that failed to load. # IDs of externally computed KV blocks that failed to load.
# Requests referencing these blocks should be rescheduled to recompute them # Requests referencing these blocks should be rescheduled to recompute them
invalid_block_ids: set[int] = field(default_factory=set) invalid_block_ids: set[int] = field(default_factory=set)
@ -123,6 +126,7 @@ class KVConnectorOutput:
not self.finished_sending not self.finished_sending
and not self.finished_recving and not self.finished_recving
and not self.kv_connector_stats and not self.kv_connector_stats
and not self.kv_cache_events
and not self.invalid_block_ids and not self.invalid_block_ids
) )

View File

@ -22,7 +22,6 @@ from vllm.distributed.kv_transfer import (
has_kv_transfer_group, has_kv_transfer_group,
) )
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.forward_context import get_forward_context, set_forward_context from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig
@ -138,16 +137,10 @@ class KVConnectorModelRunnerMixin:
) )
output.invalid_block_ids = kv_connector.get_block_ids_with_load_errors() output.invalid_block_ids = kv_connector.get_block_ids_with_load_errors()
output.kv_connector_stats = ( output.kv_connector_stats = kv_connector.get_kv_connector_stats()
KVConnectorModelRunnerMixin.get_kv_connector_stats() output.kv_cache_events = kv_connector.get_kv_connector_kv_cache_events()
)
kv_connector.clear_connector_metadata()
@staticmethod kv_connector.clear_connector_metadata()
def get_kv_connector_stats() -> KVConnectorStats | None:
if has_kv_transfer_group():
return get_kv_transfer_group().get_kv_connector_stats()
return None
@staticmethod @staticmethod
def use_uniform_kv_cache( def use_uniform_kv_cache(