[V1] Support multiple kv connectors (#17564)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent 78aa341d12
commit 2142035b51
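For context, the configuration shape exercised by the new test below gives a minimal sketch of how this feature is enabled: the outer KVTransferConfig selects "MultiConnector" and lists the nested connector configs under kv_connector_extra_config["connectors"]. The sketch below is modeled on that test and assumes two SharedStorageConnector instances; the storage paths are hypothetical placeholders.

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# Minimal sketch based on the test added in this commit; "storage_1/" and
# "storage_2/" are placeholder directories that must exist before running.
kv_transfer_config = KVTransferConfig(
    kv_connector="MultiConnector",
    kv_role="kv_both",
    kv_connector_extra_config={
        "connectors": [{
            "kv_connector": "SharedStorageConnector",
            "kv_role": "kv_both",
            "kv_connector_extra_config": {"shared_storage_path": "storage_1/"},
        }, {
            "kv_connector": "SharedStorageConnector",
            "kv_role": "kv_both",
            "kv_connector_extra_config": {"shared_storage_path": "storage_2/"},
        }]
    },
)

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
          kv_transfer_config=kv_transfer_config)
print(llm.generate(["Hello, my name is"], SamplingParams(max_tokens=20)))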
tests/v1/kv_connector/unit/test_multi_connector.py (new file, 241 lines)
@@ -0,0 +1,241 @@
# SPDX-License-Identifier: Apache-2.0
import filecmp
import shutil
import tempfile
from collections import defaultdict
from pathlib import Path

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.factory import (
    KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
    SharedStorageConnector)

MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"

PROMPT_CONTEXT = "Hi " * 100
PROMPTS = [
    PROMPT_CONTEXT + "Hello, my name is",
    PROMPT_CONTEXT + "The capital of France is",
]

SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20)


class TestSharedStorageConnector(SharedStorageConnector):

    def __init__(self, config: VllmConfig, role):
        self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
        self._connector = SharedStorageConnector(config, role)
        self.call_record: dict[str, int] = defaultdict(int)
        # Use a unique temp file per connector
        self._event_file = tempfile.gettempdir(
        ) + f"/connector_{self.name}_events.log"
        # Start with an empty file
        with open(self._event_file, "w") as _:
            pass

    def __getattribute__(self, name):
        if name in ("_connector", "call_record", "name", "_event_file",
                    "__class__", "__dict__", "__getattribute__",
                    "__init__"): # avoid recursion
            return object.__getattribute__(self, name)
        if not hasattr(self._connector, name):
            return object.__getattribute__(self, name)
        attr = getattr(self._connector, name)

        # Intercept calls to the connector interface and write an event
        # for each one to a file, which can be read back in the main test proc.
        if callable(attr):

            def wrapper(*args, **kwargs):
                self.call_record[name] += 1
                # Log the event as a line to the file
                try:
                    with open(self._event_file, "a") as f:
                        f.write(name + "\n")
                except Exception as e:
                    print(f"[ERROR] Could not log event {name} "
                          f"for {self.name}: {e}")
                return attr(*args, **kwargs)

            return wrapper
        return attr


KVConnectorFactory.register_connector("TestSharedStorageConnector",
                                      TestSharedStorageConnector.__module__,
                                      TestSharedStorageConnector.__name__)


# Helper function to compare directories recursively
def _compare_directories(dir1: Path, dir2: Path) -> bool:
    """Compares two directories recursively for identical content."""
    dcmp = filecmp.dircmp(dir1, dir2)
    if dcmp.left_only or dcmp.right_only or dcmp.diff_files:
        print(f"Differences found between {dir1} and {dir2}:")
        print(f" Left only: {dcmp.left_only}")
        print(f" Right only: {dcmp.right_only}")
        print(f" Different files: {dcmp.diff_files}")
        return False
    for sub_dir in dcmp.common_dirs:
        if not _compare_directories(dir1 / sub_dir, dir2 / sub_dir):
            return False
    return True


def test_multi_shared_storage_connector_consistency():
    """
    Tests that MultiConnector with two SharedStorageConnectors saves
    identical KV cache data to separate storage locations.
    """
    storage_1_path = Path("storage_1/")
    storage_2_path = Path("storage_2/")
    shutil.rmtree(storage_1_path, ignore_errors=True)
    shutil.rmtree(storage_2_path, ignore_errors=True)
    storage_1_path.mkdir()
    storage_2_path.mkdir()

    # Configure MultiConnector with two SharedStorageConnectors
    kv_transfer_config = KVTransferConfig(
        kv_connector="MultiConnector",
        kv_role="kv_both",
        kv_connector_extra_config={
            "connectors": [{
                "kv_connector": "TestSharedStorageConnector",
                "kv_role": "kv_both",
                "kv_connector_extra_config": {
                    "shared_storage_path": str(storage_1_path),
                    "name": "storage1",
                }
            }, {
                "kv_connector": "TestSharedStorageConnector",
                "kv_role": "kv_both",
                "kv_connector_extra_config": {
                    "shared_storage_path": str(storage_2_path),
                    "name": "storage2",
                }
            }]
        },
    )

    llm = LLM(
        model=MODEL_NAME,
        enforce_eager=True,
        gpu_memory_utilization=0.5,
        kv_transfer_config=kv_transfer_config,
    )
    # Run generation - this should trigger saving KV cache
    _ = llm.generate(PROMPTS, SAMPLING_PARAMS)

    # --- Verification ---

    # Check that both storage directories were populated
    local_subdirs = list(storage_1_path.iterdir())
    external_subdirs = list(storage_2_path.iterdir())

    assert len(
        local_subdirs
    ) > 0, f"Local storage path {storage_1_path} is empty after generation."
    assert len(external_subdirs) > 0, (
        f"External storage path {storage_2_path} is empty after generation.")
    assert len(local_subdirs) == len(external_subdirs), (
        f"Mismatch in number of cache entries: "
        f"Local={len(local_subdirs)}, External={len(external_subdirs)}")

    # The subdirectories should correspond to the prompt hashes
    # Since prompts are the same, the hash directories should be the same name
    local_subdir_names = sorted([d.name for d in local_subdirs])
    external_subdir_names = sorted([d.name for d in external_subdirs])
    assert local_subdir_names == external_subdir_names, (
        "Cache directory names do not match between local and external storage"
    )

    # Compare the contents of each corresponding cache directory
    for subdir_name in local_subdir_names:
        print(f"Comparing contents of cache directory: {subdir_name}")
        assert _compare_directories(storage_1_path / subdir_name,
                                    storage_2_path / subdir_name), \
            (f"Contents differ for cache directory '{subdir_name}' between "
             f"{storage_1_path} and {storage_2_path}")

    events = get_connector_events()
    # get_num_new_matched_tokens will be called on each connector in turn.
    # Neither of them has hits, so update_state_after_alloc won't be called.
assert events["storage1"][:3] == [
|
||||
'get_num_new_matched_tokens', 'build_connector_meta',
|
||||
'bind_connector_metadata'
|
||||
]
|
||||
assert events["storage2"][:3] == [
|
||||
'get_num_new_matched_tokens', 'build_connector_meta',
|
||||
'bind_connector_metadata'
|
||||
]
|
||||
|
||||
# Reset prefix cache or else we'll just get the tokens back from there.
|
||||
llm.reset_prefix_cache()
|
||||
|
||||
# Run generation again - this should trigger loading from the first
|
||||
# connector.
|
||||
_ = llm.generate(PROMPTS, SAMPLING_PARAMS)
|
||||
|
||||
events = get_connector_events()
|
||||
# get_num_new_matched_tokens will return new tokens from the first
|
||||
# connector so update_state_after_alloc will be called once blocks
|
||||
# are allocated for the first connector.
|
||||
# get_num_new_matched_tokens *won't* be called on the second connector
|
||||
# in this case.
|
||||
assert events["storage1"][:4] == [
|
||||
'get_num_new_matched_tokens', 'update_state_after_alloc',
|
||||
'build_connector_meta', 'bind_connector_metadata'
|
||||
]
|
||||
assert events["storage2"][:2] == [
|
||||
'build_connector_meta', 'bind_connector_metadata'
|
||||
]
|
||||
|
||||
# Delete storage1 connector state
|
||||
shutil.rmtree(storage_1_path)
|
||||
|
||||
# Reset prefix cache or else we'll just get the tokens back from there.
|
||||
llm.reset_prefix_cache()
|
||||
|
||||
# Run generation again - this should trigger loading from the first
|
||||
# connector.
|
||||
    _ = llm.generate(PROMPTS, SAMPLING_PARAMS)

    events = get_connector_events()
    # get_num_new_matched_tokens will be called for the first connector but it
    # won't have a hit so update_state_after_alloc won't be called.
    # get_num_new_matched_tokens will also be called on the second connector,
    # but it should have a hit so update_state_after_alloc will be called.
    assert events["storage1"][:3] == [
        'get_num_new_matched_tokens', 'build_connector_meta',
        'bind_connector_metadata'
    ]
    assert events["storage2"][:4] == [
        'get_num_new_matched_tokens', 'update_state_after_alloc',
        'build_connector_meta', 'bind_connector_metadata'
    ]

    # Clean up
    shutil.rmtree(storage_1_path)
    shutil.rmtree(storage_2_path)


def get_connector_events() -> dict[str, list[str]]:
    # Read in connector events and reset the files.
    import glob
    event_files = glob.glob(tempfile.gettempdir() + "/connector_*_events.log")
    connector_events = {}
    for fname in event_files:
        name = fname.split("connector_")[1].split("_events.log")[0]
        try:
            with open(fname, "r+") as f:
                connector_events[name] = [
                    line.strip() for line in f if line.strip()
                ]
                f.truncate(0)
        except Exception as e:
            print(f"[ERROR] Could not read connector events for {name}: {e}")

    return connector_events
@@ -110,3 +110,8 @@ KVConnectorFactory.register_connector(
    "NixlConnector",
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector",
    "NixlConnector")

KVConnectorFactory.register_connector(
    "MultiConnector",
    "vllm.distributed.kv_transfer.kv_connector.v1.multi_connector",
    "MultiConnector")
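The hunk above registers the new connector with the factory by name, module path, and class name. As a hedged sketch, an out-of-tree connector can be registered the same way before it is referenced in a config, mirroring what the new test does with TestSharedStorageConnector; "MyConnector" and "my_pkg.my_connector" below are hypothetical placeholders.

from vllm.distributed.kv_transfer.kv_connector.factory import (
    KVConnectorFactory)

# Hypothetical out-of-tree connector registration (placeholder names):
KVConnectorFactory.register_connector(
    "MyConnector",            # name used as kv_connector=... in configs
    "my_pkg.my_connector",    # module that defines the class
    "MyConnector")            # class name inside that module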
@@ -22,7 +22,6 @@ The class provides the following primitives:

import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional

import torch
@@ -48,7 +47,6 @@ class KVConnectorRole(enum.Enum):
    WORKER = 1


@dataclass
class KVConnectorMetadata:
    """
    Abstract Metadata used to communicate between the
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py (new file, 178 lines)
@@ -0,0 +1,178 @@
# SPDX-License-Identifier: Apache-2.0
import copy
from typing import TYPE_CHECKING, Any, Optional

import torch

from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.factory import (
    KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.request import Request

logger = init_logger(__name__)


class MultiKVConnectorMetadata(tuple[KVConnectorMetadata, ...],
                               KVConnectorMetadata):
    pass


class MultiConnector(KVConnectorBase_V1):
    """
    A wrapper for using multiple KVConnectors at the same time.

    The current logic is:
    - Load KV from the first connector that advertises available tokens from
      get_num_new_matched_tokens(), based on the order in the config.
    - Save to all connectors.
    """

    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)
        self._connectors = []
        ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
            "connectors")
        assert ktcs is not None
        for ktc in ktcs:
            temp_config = copy.copy(vllm_config)
            temp_config.kv_transfer_config = KVTransferConfig(**ktc)
            self._connectors.append(
                KVConnectorFactory.create_connector_v1(temp_config, role))

        # A mapping from request id to the connector that is assigned to it.
        self._requests_to_connector: dict[str, KVConnectorBase_V1] = {}

        # Keeps track of *additional* remaining async saves (beyond 1) to be
        # finished per request. Not needed for async loads since we only allow
        # a single connector to load.
        self._extra_async_saves: dict[str, int] = {}

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        for c in self._connectors:
            c.register_kv_caches(kv_caches)

    # We must override the base class method here because we need to bind
    # the metadata to each connector in the order of the connectors in the
    # MultiKVConnectorMetadata.
    def bind_connector_metadata(
            self, connector_metadata: KVConnectorMetadata) -> None:
        assert isinstance(connector_metadata, MultiKVConnectorMetadata)
        for c, cm in zip(self._connectors, connector_metadata):
            c.bind_connector_metadata(cm)

    def clear_connector_metadata(self) -> None:
        for c in self._connectors:
            c.clear_connector_metadata()

    # ==============================
    # Worker-side methods
    # ==============================
    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
        for c in self._connectors:
            c.start_load_kv(forward_context, **kwargs)

    def wait_for_layer_load(self, layer_name: str) -> None:
        for c in self._connectors:
            c.wait_for_layer_load(layer_name)

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        for c in self._connectors:
            c.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs)

    def wait_for_save(self):
        for c in self._connectors:
            c.wait_for_save()

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        finished_recving: set[str] = set()
        finished_sending: set[str] = set()
        for c in self._connectors:
            recving, sending = c.get_finished(finished_req_ids)
            if not recving and not sending:
                continue
            # Aggregate finished recving request ids.
            finished_recving.update(recving or ())
            # Aggregate finished sending request ids - only include
            # once we've drained the "extra" count (for cases where
            # more than one connector is async-saving the same request).
            for req_id in sending or ():
                extra_pending = self._extra_async_saves.get(req_id)
                if extra_pending is None:
                    finished_sending.add(req_id)
                    continue
                assert extra_pending > 0
                if extra_pending == 1:
                    del self._extra_async_saves[req_id]
                else:
                    self._extra_async_saves[req_id] = extra_pending - 1

        return finished_recving or None, finished_sending or None

    # ==============================
    # Scheduler-side methods
    # ==============================
    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int, bool]:
        for c in self._connectors:
            toks, load_async = c.get_num_new_matched_tokens(
                request, num_computed_tokens)
            # The first connector that has new matched tokens will be assigned
            # to this request.
            if toks > 0:
                self._requests_to_connector[request.request_id] = c
                return toks, load_async
        return 0, False

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        # If the request is not assigned to any connector, we do nothing.
        if request.request_id not in self._requests_to_connector:
            return
        # We assume that the request is assigned to only one connector.
        c = self._requests_to_connector.pop(request.request_id)
        c.update_state_after_alloc(request, blocks, num_external_tokens)

    def build_connector_meta(
            self,
            scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata:
        return MultiKVConnectorMetadata(
            c.build_connector_meta(scheduler_output) for c in self._connectors)

    def request_finished(
        self,
        request: "Request",
        blocks: "KVCacheBlocks",
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        async_saves = 0
        kv_txfer_params = None
        for c in self._connectors:
            async_save, txfer_params = c.request_finished(request, blocks)
            if async_save:
                async_saves += 1
            if txfer_params is not None:
                if kv_txfer_params is not None:
                    #TODO we can probably change this to merge the dicts here,
                    # checking for key clashes.
                    raise RuntimeError(
                        "Only one connector can produce KV transfer params")
                kv_txfer_params = txfer_params
        if async_saves > 1:
            self._extra_async_saves[request.request_id] = async_saves - 1
        return async_saves > 0, kv_txfer_params
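To illustrate the bookkeeping above with a standalone sketch (not part of the commit): when more than one connector asynchronously saves the same request, request_finished records the extra count in _extra_async_saves, and get_finished only reports the request as finished sending once every connector has reported it. The request id "req-0" below is a made-up value and the plain dict stands in for the MultiConnector state.

# Standalone illustration of the _extra_async_saves drain performed in
# get_finished(); two connectors async-save the same request, so it is only
# reported as finished after both report completion.
extra_async_saves = {"req-0": 1}  # request_finished saw 2 async saves -> 2 - 1


def sending_done(req_id: str) -> bool:
    """Return True once all async saves for req_id have completed."""
    extra = extra_async_saves.get(req_id)
    if extra is None:
        return True  # no extra saves outstanding, report immediately
    if extra == 1:
        del extra_async_saves[req_id]  # last outstanding extra save drained
    else:
        extra_async_saves[req_id] = extra - 1
    return False


assert sending_done("req-0") is False  # first connector finishes: still waiting
assert sending_done("req-0") is True   # second connector finishes: now reported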