[V1] Support multiple kv connectors (#17564)

Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Michael Goin 2025-05-14 19:28:02 -04:00 committed by GitHub
parent 78aa341d12
commit 2142035b51
4 changed files with 424 additions and 2 deletions
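
With this change, multiple KV connectors can be stacked by listing them under a
MultiConnector's extra config. A minimal offline sketch (model and storage
paths are illustrative), using the same APIs exercised by the test below:

    from vllm import LLM
    from vllm.config import KVTransferConfig

    llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",
        kv_transfer_config=KVTransferConfig(
            kv_connector="MultiConnector",
            kv_role="kv_both",
            kv_connector_extra_config={
                "connectors": [
                    {"kv_connector": "SharedStorageConnector",
                     "kv_role": "kv_both",
                     "kv_connector_extra_config": {
                         "shared_storage_path": "/tmp/kv1"}},
                    {"kv_connector": "SharedStorageConnector",
                     "kv_role": "kv_both",
                     "kv_connector_extra_config": {
                         "shared_storage_path": "/tmp/kv2"}},
                ]
            },
        ),
    )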

View File (new test module; filename not shown in this view)

@@ -0,0 +1,241 @@
# SPDX-License-Identifier: Apache-2.0
import filecmp
import shutil
import tempfile
from collections import defaultdict
from pathlib import Path
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
SharedStorageConnector)

MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
PROMPT_CONTEXT = "Hi " * 100
PROMPTS = [
PROMPT_CONTEXT + "Hello, my name is",
PROMPT_CONTEXT + "The capital of France is",
]
SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20)


class TestSharedStorageConnector(SharedStorageConnector):
    """SharedStorageConnector wrapper that records each interface call."""

    def __init__(self, config: VllmConfig, role):
self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
self._connector = SharedStorageConnector(config, role)
self.call_record: dict[str, int] = defaultdict(int)
# Use a unique temp file per connector
self._event_file = tempfile.gettempdir(
) + f"/connector_{self.name}_events.log"
# Start with an empty file
with open(self._event_file, "w") as _:
pass

    def __getattribute__(self, name):
if name in ("_connector", "call_record", "name", "_event_file",
"__class__", "__dict__", "__getattribute__",
"__init__"): # avoid recursion
return object.__getattribute__(self, name)
if not hasattr(self._connector, name):
return object.__getattribute__(self, name)
attr = getattr(self._connector, name)
# Intercept calls to the connector interface and write an event
        # for each one to a file, which can be read back in the main test proc.
if callable(attr):
def wrapper(*args, **kwargs):
self.call_record[name] += 1
# Log the event as a line to the file
try:
with open(self._event_file, "a") as f:
f.write(name + "\n")
except Exception as e:
print(f"[ERROR] Could not log event {name} "
f"for {self.name}: {e}")
return attr(*args, **kwargs)
return wrapper
return attr
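

# Register the wrapper under its own name so the test config below can refer
# to it by string, the same way built-in connectors are referenced.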
KVConnectorFactory.register_connector("TestSharedStorageConnector",
TestSharedStorageConnector.__module__,
TestSharedStorageConnector.__name__)


# Helper function to compare directories recursively
def _compare_directories(dir1: Path, dir2: Path) -> bool:
"""Compares two directories recursively for identical content."""
dcmp = filecmp.dircmp(dir1, dir2)
if dcmp.left_only or dcmp.right_only or dcmp.diff_files:
print(f"Differences found between {dir1} and {dir2}:")
print(f" Left only: {dcmp.left_only}")
print(f" Right only: {dcmp.right_only}")
print(f" Different files: {dcmp.diff_files}")
return False
for sub_dir in dcmp.common_dirs:
if not _compare_directories(dir1 / sub_dir, dir2 / sub_dir):
return False
return True


def test_multi_shared_storage_connector_consistency():
"""
Tests that MultiConnector with two SharedStorageConnectors saves
identical KV cache data to separate storage locations.
"""
storage_1_path = Path("storage_1/")
storage_2_path = Path("storage_2/")
shutil.rmtree(storage_1_path, ignore_errors=True)
shutil.rmtree(storage_2_path, ignore_errors=True)
storage_1_path.mkdir()
storage_2_path.mkdir()
# Configure MultiConnector with two SharedStorageConnectors
kv_transfer_config = KVTransferConfig(
kv_connector="MultiConnector",
kv_role="kv_both",
kv_connector_extra_config={
"connectors": [{
"kv_connector": "TestSharedStorageConnector",
"kv_role": "kv_both",
"kv_connector_extra_config": {
"shared_storage_path": str(storage_1_path),
"name": "storage1",
}
}, {
"kv_connector": "TestSharedStorageConnector",
"kv_role": "kv_both",
"kv_connector_extra_config": {
"shared_storage_path": str(storage_2_path),
"name": "storage2",
}
}]
},
)
llm = LLM(
model=MODEL_NAME,
enforce_eager=True,
gpu_memory_utilization=0.5,
kv_transfer_config=kv_transfer_config,
)
# Run generation - this should trigger saving KV cache
_ = llm.generate(PROMPTS, SAMPLING_PARAMS)
# --- Verification ---
# Check that both storage directories were populated
local_subdirs = list(storage_1_path.iterdir())
external_subdirs = list(storage_2_path.iterdir())
assert len(
local_subdirs
) > 0, f"Local storage path {storage_1_path} is empty after generation."
assert len(external_subdirs) > 0, (
f"External storage path {storage_2_path} is empty after generation.")
assert len(local_subdirs) == len(external_subdirs), (
f"Mismatch in number of cache entries: "
f"Local={len(local_subdirs)}, External={len(external_subdirs)}")
# The subdirectories should correspond to the prompt hashes
# Since prompts are the same, the hash directories should be the same name
local_subdir_names = sorted([d.name for d in local_subdirs])
external_subdir_names = sorted([d.name for d in external_subdirs])
assert local_subdir_names == external_subdir_names, (
"Cache directory names do not match between local and external storage"
)
# Compare the contents of each corresponding cache directory
for subdir_name in local_subdir_names:
print(f"Comparing contents of cache directory: {subdir_name}")
assert _compare_directories(storage_1_path / subdir_name,
storage_2_path / subdir_name), \
(f"Contents differ for cache directory '{subdir_name}' between "
f"{storage_1_path} and {storage_2_path}")
events = get_connector_events()
    # get_num_new_matched_tokens will be called on each connector in turn.
    # Neither of them has a hit, so update_state_after_alloc won't be called.
assert events["storage1"][:3] == [
'get_num_new_matched_tokens', 'build_connector_meta',
'bind_connector_metadata'
]
assert events["storage2"][:3] == [
'get_num_new_matched_tokens', 'build_connector_meta',
'bind_connector_metadata'
]
# Reset prefix cache or else we'll just get the tokens back from there.
llm.reset_prefix_cache()
# Run generation again - this should trigger loading from the first
# connector.
_ = llm.generate(PROMPTS, SAMPLING_PARAMS)
events = get_connector_events()
# get_num_new_matched_tokens will return new tokens from the first
# connector so update_state_after_alloc will be called once blocks
# are allocated for the first connector.
# get_num_new_matched_tokens *won't* be called on the second connector
# in this case.
assert events["storage1"][:4] == [
'get_num_new_matched_tokens', 'update_state_after_alloc',
'build_connector_meta', 'bind_connector_metadata'
]
assert events["storage2"][:2] == [
'build_connector_meta', 'bind_connector_metadata'
]
# Delete storage1 connector state
shutil.rmtree(storage_1_path)
# Reset prefix cache or else we'll just get the tokens back from there.
llm.reset_prefix_cache()
    # Run generation again - this should trigger loading from the second
    # connector, since storage1's files were just deleted.
_ = llm.generate(PROMPTS, SAMPLING_PARAMS)
events = get_connector_events()
# get_num_new_matched_tokens will be called for the first connector but it
# won't have a hit so update_state_after_alloc won't be called.
# get_num_new_matched_tokens will also be called on the second connector,
# but it should have a hit so update_state_after_alloc will be called.
assert events["storage1"][:3] == [
'get_num_new_matched_tokens', 'build_connector_meta',
'bind_connector_metadata'
]
assert events["storage2"][:4] == [
'get_num_new_matched_tokens', 'update_state_after_alloc',
'build_connector_meta', 'bind_connector_metadata'
]
# Clean up
shutil.rmtree(storage_1_path)
shutil.rmtree(storage_2_path)


def get_connector_events() -> dict[str, list[str]]:
# Read in connector events and reset the files.
import glob
event_files = glob.glob(tempfile.gettempdir() + "/connector_*_events.log")
connector_events = {}
for fname in event_files:
name = fname.split("connector_")[1].split("_events.log")[0]
try:
with open(fname, "r+") as f:
connector_events[name] = [
line.strip() for line in f if line.strip()
]
f.truncate(0)
except Exception as e:
print(f"[ERROR] Could not read connector events for {name}: {e}")
return connector_events
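
A usage note: the new test is self-contained and can be invoked on its own via
pytest -k multi_shared_storage_connector_consistency (the file's path is not
shown in this view). It loads meta-llama/Llama-3.2-1B-Instruct with
gpu_memory_utilization=0.5, so a GPU is required.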

View File (vllm/distributed/kv_transfer/kv_connector/factory.py)

@@ -110,3 +110,8 @@ KVConnectorFactory.register_connector(
"NixlConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector",
"NixlConnector")
KVConnectorFactory.register_connector(
"MultiConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.multi_connector",
"MultiConnector")

View File (vllm/distributed/kv_transfer/kv_connector/v1/base.py)

@@ -22,7 +22,6 @@ The class provides the following primitives:
import enum
from abc import ABC, abstractmethod
-from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import torch
@@ -48,7 +47,6 @@ class KVConnectorRole(enum.Enum):
WORKER = 1
-@dataclass
class KVConnectorMetadata:
"""
Abstract Metadata used to communicate between the

View File (vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py)

@@ -0,0 +1,178 @@
# SPDX-License-Identifier: Apache-2.0
import copy
from typing import TYPE_CHECKING, Any, Optional
import torch
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request

logger = init_logger(__name__)
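

# One KVConnectorMetadata per wrapped connector, in the same order as the
# connectors appear in the config; bind_connector_metadata() below relies on
# that ordering when zipping the tuple back onto the connectors.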
class MultiKVConnectorMetadata(tuple[KVConnectorMetadata, ...],
KVConnectorMetadata):
pass


class MultiConnector(KVConnectorBase_V1):
"""
A wrapper for using multiple KVConnectors at the same time.
The current logic is:
- Load KV from the first connector that advertises available tokens from
get_num_new_matched_tokens(), based on the order in the config.
- Save to all connectors.
"""
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
        self._connectors: list[KVConnectorBase_V1] = []
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"connectors")
assert ktcs is not None
for ktc in ktcs:
temp_config = copy.copy(vllm_config)
temp_config.kv_transfer_config = KVTransferConfig(**ktc)
self._connectors.append(
KVConnectorFactory.create_connector_v1(temp_config, role))
# A mapping from request id to the connector that is assigned to it.
self._requests_to_connector: dict[str, KVConnectorBase_V1] = {}
# Keeps track of *additional* remaining async saves (beyond 1) to be
        # finished per request. Not needed for async loads since we only allow
        # a single connector to load.
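        # For example, if three connectors async-save request R,
        # request_finished() returns async_saves == 3 and records
        # _extra_async_saves["R"] = 2; get_finished() then reports R as
        # finished only once the last of the three connectors reports it.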
self._extra_async_saves: dict[str, int] = {}

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
for c in self._connectors:
c.register_kv_caches(kv_caches)

    # We must override the base class method here because we need to bind
    # the metadata to each connector in the order of the connectors in the
    # MultiKVConnectorMetadata.
def bind_connector_metadata(
self, connector_metadata: KVConnectorMetadata) -> None:
assert isinstance(connector_metadata, MultiKVConnectorMetadata)
for c, cm in zip(self._connectors, connector_metadata):
c.bind_connector_metadata(cm)

    def clear_connector_metadata(self) -> None:
for c in self._connectors:
c.clear_connector_metadata()

    # ==============================
# Worker-side methods
# ==============================

    def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
for c in self._connectors:
c.start_load_kv(forward_context, **kwargs)

    def wait_for_layer_load(self, layer_name: str) -> None:
for c in self._connectors:
c.wait_for_layer_load(layer_name)

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
for c in self._connectors:
c.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs)

    def wait_for_save(self):
for c in self._connectors:
c.wait_for_save()

    def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
finished_recving: set[str] = set()
finished_sending: set[str] = set()
for c in self._connectors:
recving, sending = c.get_finished(finished_req_ids)
if not recving and not sending:
continue
# Aggregate finished recving request ids.
finished_recving.update(recving or ())
# Aggregate finished sending request ids - only include
# once we've drained the "extra" count (for cases where
# more than one connector is async-saving the same request).
for req_id in sending or ():
extra_pending = self._extra_async_saves.get(req_id)
if extra_pending is None:
finished_sending.add(req_id)
continue
assert extra_pending > 0
if extra_pending == 1:
del self._extra_async_saves[req_id]
else:
self._extra_async_saves[req_id] = extra_pending - 1
return finished_recving or None, finished_sending or None

    # ==============================
# Scheduler-side methods
# ==============================

    def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int, bool]:
for c in self._connectors:
toks, load_async = c.get_num_new_matched_tokens(
request, num_computed_tokens)
# The first connector that has new matched tokens will be assigned
# to this request.
if toks > 0:
self._requests_to_connector[request.request_id] = c
return toks, load_async
return 0, False

    def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
# If the request is not assigned to any connector, we do nothing.
if request.request_id not in self._requests_to_connector:
return
# We assume that the request is assigned to only one connector.
c = self._requests_to_connector.pop(request.request_id)
c.update_state_after_alloc(request, blocks, num_external_tokens)

    def build_connector_meta(
self,
scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata:
return MultiKVConnectorMetadata(
c.build_connector_meta(scheduler_output) for c in self._connectors)

    def request_finished(
self,
request: "Request",
blocks: "KVCacheBlocks",
) -> tuple[bool, Optional[dict[str, Any]]]:
async_saves = 0
kv_txfer_params = None
for c in self._connectors:
async_save, txfer_params = c.request_finished(request, blocks)
if async_save:
async_saves += 1
if txfer_params is not None:
if kv_txfer_params is not None:
                    # TODO: we can probably change this to merge the dicts
                    # here, checking for key clashes.
raise RuntimeError(
"Only one connector can produce KV transfer params")
kv_txfer_params = txfer_params
if async_saves > 1:
self._extra_async_saves[request.request_id] = async_saves - 1
return async_saves > 0, kv_txfer_params