[V1] Support multiple kv connectors (#17564)

Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>

parent 78aa341d12
commit 2142035b51
tests/v1/kv_connector/unit/test_multi_connector.py (new file, 241 lines)

@@ -0,0 +1,241 @@
# SPDX-License-Identifier: Apache-2.0
import filecmp
import shutil
import tempfile
from collections import defaultdict
from pathlib import Path

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.factory import (
    KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (  # noqa
    SharedStorageConnector)

MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"

PROMPT_CONTEXT = "Hi " * 100
PROMPTS = [
    PROMPT_CONTEXT + "Hello, my name is",
    PROMPT_CONTEXT + "The capital of France is",
]

SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20)


class TestSharedStorageConnector(SharedStorageConnector):

    def __init__(self, config: VllmConfig, role):
        self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
        self._connector = SharedStorageConnector(config, role)
        self.call_record: dict[str, int] = defaultdict(int)
        # Use a unique temp file per connector
        self._event_file = tempfile.gettempdir(
        ) + f"/connector_{self.name}_events.log"
        # Start with an empty file
        with open(self._event_file, "w") as _:
            pass

    def __getattribute__(self, name):
        if name in ("_connector", "call_record", "name", "_event_file",
                    "__class__", "__dict__", "__getattribute__",
                    "__init__"):  # avoid recursion
            return object.__getattribute__(self, name)
        if not hasattr(self._connector, name):
            return object.__getattribute__(self, name)
        attr = getattr(self._connector, name)

        # Intercept calls to the connector interface and write an event
        # for each one to a file, which can be read back in the main test proc.
        if callable(attr):

            def wrapper(*args, **kwargs):
                self.call_record[name] += 1
                # Log the event as a line to the file
                try:
                    with open(self._event_file, "a") as f:
                        f.write(name + "\n")
                except Exception as e:
                    print(f"[ERROR] Could not log event {name} "
                          f"for {self.name}: {e}")
                return attr(*args, **kwargs)

            return wrapper
        return attr


KVConnectorFactory.register_connector("TestSharedStorageConnector",
                                      TestSharedStorageConnector.__module__,
                                      TestSharedStorageConnector.__name__)


# Helper function to compare directories recursively
def _compare_directories(dir1: Path, dir2: Path) -> bool:
    """Compares two directories recursively for identical content."""
    dcmp = filecmp.dircmp(dir1, dir2)
    if dcmp.left_only or dcmp.right_only or dcmp.diff_files:
        print(f"Differences found between {dir1} and {dir2}:")
        print(f"  Left only: {dcmp.left_only}")
        print(f"  Right only: {dcmp.right_only}")
        print(f"  Different files: {dcmp.diff_files}")
        return False
    for sub_dir in dcmp.common_dirs:
        if not _compare_directories(dir1 / sub_dir, dir2 / sub_dir):
            return False
    return True


def test_multi_shared_storage_connector_consistency():
    """
    Tests that MultiConnector with two SharedStorageConnectors saves
    identical KV cache data to separate storage locations.
    """
    storage_1_path = Path("storage_1/")
    storage_2_path = Path("storage_2/")
    shutil.rmtree(storage_1_path, ignore_errors=True)
    shutil.rmtree(storage_2_path, ignore_errors=True)
    storage_1_path.mkdir()
    storage_2_path.mkdir()

    # Configure MultiConnector with two SharedStorageConnectors
    kv_transfer_config = KVTransferConfig(
        kv_connector="MultiConnector",
        kv_role="kv_both",
        kv_connector_extra_config={
            "connectors": [{
                "kv_connector": "TestSharedStorageConnector",
                "kv_role": "kv_both",
                "kv_connector_extra_config": {
                    "shared_storage_path": str(storage_1_path),
                    "name": "storage1",
                }
            }, {
                "kv_connector": "TestSharedStorageConnector",
                "kv_role": "kv_both",
                "kv_connector_extra_config": {
                    "shared_storage_path": str(storage_2_path),
                    "name": "storage2",
                }
            }]
        },
    )

    llm = LLM(
        model=MODEL_NAME,
        enforce_eager=True,
        gpu_memory_utilization=0.5,
        kv_transfer_config=kv_transfer_config,
    )
    # Run generation - this should trigger saving KV cache
    _ = llm.generate(PROMPTS, SAMPLING_PARAMS)

    # --- Verification ---

    # Check that both storage directories were populated
    local_subdirs = list(storage_1_path.iterdir())
    external_subdirs = list(storage_2_path.iterdir())

    assert len(
        local_subdirs
    ) > 0, f"Local storage path {storage_1_path} is empty after generation."
    assert len(external_subdirs) > 0, (
        f"External storage path {storage_2_path} is empty after generation.")
    assert len(local_subdirs) == len(external_subdirs), (
        f"Mismatch in number of cache entries: "
        f"Local={len(local_subdirs)}, External={len(external_subdirs)}")

    # The subdirectories should correspond to the prompt hashes.
    # Since the prompts are the same, the hash directory names should match.
    local_subdir_names = sorted([d.name for d in local_subdirs])
    external_subdir_names = sorted([d.name for d in external_subdirs])
    assert local_subdir_names == external_subdir_names, (
        "Cache directory names do not match between local and external storage"
    )

    # Compare the contents of each corresponding cache directory
    for subdir_name in local_subdir_names:
        print(f"Comparing contents of cache directory: {subdir_name}")
        assert _compare_directories(storage_1_path / subdir_name,
                                    storage_2_path / subdir_name), \
            (f"Contents differ for cache directory '{subdir_name}' between "
             f"{storage_1_path} and {storage_2_path}")

    events = get_connector_events()
    # get_num_new_matched_tokens will be called on each connector in turn.
    # Neither of them has hits, so update_state_after_alloc won't be called.
    assert events["storage1"][:3] == [
        'get_num_new_matched_tokens', 'build_connector_meta',
        'bind_connector_metadata'
    ]
    assert events["storage2"][:3] == [
        'get_num_new_matched_tokens', 'build_connector_meta',
        'bind_connector_metadata'
    ]

    # Reset prefix cache or else we'll just get the tokens back from there.
    llm.reset_prefix_cache()

    # Run generation again - this should trigger loading from the first
    # connector.
    _ = llm.generate(PROMPTS, SAMPLING_PARAMS)

    events = get_connector_events()
    # get_num_new_matched_tokens will return new tokens from the first
    # connector, so update_state_after_alloc will be called once blocks
    # are allocated for the first connector.
    # get_num_new_matched_tokens *won't* be called on the second connector
    # in this case.
    assert events["storage1"][:4] == [
        'get_num_new_matched_tokens', 'update_state_after_alloc',
        'build_connector_meta', 'bind_connector_metadata'
    ]
    assert events["storage2"][:2] == [
        'build_connector_meta', 'bind_connector_metadata'
    ]

    # Delete storage1 connector state
    shutil.rmtree(storage_1_path)

    # Reset prefix cache or else we'll just get the tokens back from there.
    llm.reset_prefix_cache()

    # Run generation again - this should trigger loading from the second
    # connector.
    _ = llm.generate(PROMPTS, SAMPLING_PARAMS)

    events = get_connector_events()
    # get_num_new_matched_tokens will be called for the first connector, but it
    # won't have a hit, so update_state_after_alloc won't be called.
    # get_num_new_matched_tokens will also be called on the second connector,
    # which should have a hit, so update_state_after_alloc will be called.
    assert events["storage1"][:3] == [
        'get_num_new_matched_tokens', 'build_connector_meta',
        'bind_connector_metadata'
    ]
    assert events["storage2"][:4] == [
        'get_num_new_matched_tokens', 'update_state_after_alloc',
        'build_connector_meta', 'bind_connector_metadata'
    ]

    # Clean up
    shutil.rmtree(storage_1_path)
    shutil.rmtree(storage_2_path)


def get_connector_events() -> dict[str, list[str]]:
    # Read in connector events and reset the files.
    import glob
    event_files = glob.glob(tempfile.gettempdir() + "/connector_*_events.log")
    connector_events = {}
    for fname in event_files:
        name = fname.split("connector_")[1].split("_events.log")[0]
        try:
            with open(fname, "r+") as f:
                connector_events[name] = [
                    line.strip() for line in f if line.strip()
                ]
                f.truncate(0)
        except Exception as e:
            print(f"[ERROR] Could not read connector events for {name}: {e}")

    return connector_events
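The TestSharedStorageConnector wrapper above leans on Python's __getattribute__ hook to proxy every interface call through to the real connector while recording it. Below is a minimal standalone sketch of the same pattern, stripped of the vLLM specifics; Recorder and target are illustrative names, not part of this commit:

# Minimal sketch of the call-recording proxy pattern used in the test above.
# "Recorder" and "target" are placeholder names, not part of vLLM.
from collections import defaultdict


class Recorder:

    def __init__(self, target):
        self._target = target
        self.calls: dict[str, int] = defaultdict(int)

    def __getattribute__(self, name):
        # Access our own attributes directly to avoid infinite recursion.
        if name in ("_target", "calls", "__class__", "__dict__",
                    "__getattribute__", "__init__"):
            return object.__getattribute__(self, name)
        attr = getattr(self._target, name)
        if callable(attr):
            calls = self.calls

            def wrapper(*args, **kwargs):
                calls[name] += 1  # record the call, then delegate
                return attr(*args, **kwargs)

            return wrapper
        return attr


r = Recorder([])  # wrap a plain list
r.append(1)
r.append(2)
assert r.calls["append"] == 2

The test uses a log file rather than an in-memory dict so that events survive the worker process boundary and can be read back in the main test process.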
vllm/distributed/kv_transfer/kv_connector/factory.py

@@ -110,3 +110,8 @@ KVConnectorFactory.register_connector(
     "NixlConnector",
     "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector",
     "NixlConnector")
+
+KVConnectorFactory.register_connector(
+    "MultiConnector",
+    "vllm.distributed.kv_transfer.kv_connector.v1.multi_connector",
+    "MultiConnector")
vllm/distributed/kv_transfer/kv_connector/v1/base.py

@@ -22,7 +22,6 @@ The class provides the following primitives:

 import enum
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional

 import torch

@@ -48,7 +47,6 @@ class KVConnectorRole(enum.Enum):

     WORKER = 1


-@dataclass
 class KVConnectorMetadata:
     """
     Abstract Metadata used to communicate between the
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py (new file, 178 lines)

@@ -0,0 +1,178 @@
# SPDX-License-Identifier: Apache-2.0
import copy
from typing import TYPE_CHECKING, Any, Optional

import torch

from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.factory import (
    KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.request import Request

logger = init_logger(__name__)


class MultiKVConnectorMetadata(tuple[KVConnectorMetadata, ...],
                               KVConnectorMetadata):
    pass


class MultiConnector(KVConnectorBase_V1):
    """
    A wrapper for using multiple KVConnectors at the same time.

    The current logic is:
    - Load KV from the first connector that advertises available tokens from
      get_num_new_matched_tokens(), based on the order in the config.
    - Save to all connectors.
    """

    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)
        self._connectors = []
        ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
            "connectors")
        assert ktcs is not None
        for ktc in ktcs:
            temp_config = copy.copy(vllm_config)
            temp_config.kv_transfer_config = KVTransferConfig(**ktc)
            self._connectors.append(
                KVConnectorFactory.create_connector_v1(temp_config, role))

        # A mapping from request id to the connector that is assigned to it.
        self._requests_to_connector: dict[str, KVConnectorBase_V1] = {}

        # Keeps track of *additional* remaining async saves (beyond 1) to be
        # finished per request. Not needed for async loads since we only allow
        # a single connector to load.
        self._extra_async_saves: dict[str, int] = {}

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        for c in self._connectors:
            c.register_kv_caches(kv_caches)

    # We must override the base class method here because we need to bind
    # the metadata to each connector in the order of the connectors in the
    # MultiKVConnectorMetadata.
    def bind_connector_metadata(
            self, connector_metadata: KVConnectorMetadata) -> None:
        assert isinstance(connector_metadata, MultiKVConnectorMetadata)
        for c, cm in zip(self._connectors, connector_metadata):
            c.bind_connector_metadata(cm)

    def clear_connector_metadata(self) -> None:
        for c in self._connectors:
            c.clear_connector_metadata()

    # ==============================
    # Worker-side methods
    # ==============================
    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
        for c in self._connectors:
            c.start_load_kv(forward_context, **kwargs)

    def wait_for_layer_load(self, layer_name: str) -> None:
        for c in self._connectors:
            c.wait_for_layer_load(layer_name)

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        for c in self._connectors:
            c.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs)

    def wait_for_save(self):
        for c in self._connectors:
            c.wait_for_save()

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        finished_recving: set[str] = set()
        finished_sending: set[str] = set()
        for c in self._connectors:
            recving, sending = c.get_finished(finished_req_ids)
            if not recving and not sending:
                continue
            # Aggregate finished recving request ids.
            finished_recving.update(recving or ())
            # Aggregate finished sending request ids - only include
            # once we've drained the "extra" count (for cases where
            # more than one connector is async-saving the same request).
            for req_id in sending or ():
                extra_pending = self._extra_async_saves.get(req_id)
                if extra_pending is None:
                    finished_sending.add(req_id)
                    continue
                assert extra_pending > 0
                if extra_pending == 1:
                    del self._extra_async_saves[req_id]
                else:
                    self._extra_async_saves[req_id] = extra_pending - 1

        return finished_recving or None, finished_sending or None

    # ==============================
    # Scheduler-side methods
    # ==============================
    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int, bool]:
        for c in self._connectors:
            toks, load_async = c.get_num_new_matched_tokens(
                request, num_computed_tokens)
            # The first connector that has new matched tokens will be assigned
            # to this request.
            if toks > 0:
                self._requests_to_connector[request.request_id] = c
                return toks, load_async
        return 0, False

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        # If the request is not assigned to any connector, we do nothing.
        if request.request_id not in self._requests_to_connector:
            return
        # We assume that the request is assigned to only one connector.
        c = self._requests_to_connector.pop(request.request_id)
        c.update_state_after_alloc(request, blocks, num_external_tokens)

    def build_connector_meta(
            self,
            scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata:
        return MultiKVConnectorMetadata(
            c.build_connector_meta(scheduler_output) for c in self._connectors)

    def request_finished(
        self,
        request: "Request",
        blocks: "KVCacheBlocks",
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        async_saves = 0
        kv_txfer_params = None
        for c in self._connectors:
            async_save, txfer_params = c.request_finished(request, blocks)
            if async_save:
                async_saves += 1
            if txfer_params is not None:
                if kv_txfer_params is not None:
                    # TODO: we can probably change this to merge the dicts
                    # here, checking for key clashes.
                    raise RuntimeError(
                        "Only one connector can produce KV transfer params")
                kv_txfer_params = txfer_params
        if async_saves > 1:
            self._extra_async_saves[request.request_id] = async_saves - 1
        return async_saves > 0, kv_txfer_params
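A note on the bookkeeping above: request_finished counts how many sub-connectors report an in-flight async save for a request and stores the surplus beyond one in _extra_async_saves; get_finished then reports the request as finished-sending only after every sub-connector has drained its save. A standalone walk-through of that counting logic with illustrative values (plain Python, outside vLLM):

# Standalone walk-through of the "extra async saves" bookkeeping,
# using a plain dict and set; the request id is illustrative.
extra_async_saves = {"req-0": 1}  # 2 connectors saving req-0 => 1 extra


def drain(sending: set[str]) -> set[str]:
    """Mimics the sending-side aggregation in get_finished()."""
    finished = set()
    for req_id in sending:
        extra = extra_async_saves.get(req_id)
        if extra is None:
            finished.add(req_id)  # no other connector still saving
        elif extra == 1:
            del extra_async_saves[req_id]  # last extra drained, not done yet
        else:
            extra_async_saves[req_id] = extra - 1
    return finished


# First connector finishes saving req-0: one save is still outstanding.
assert drain({"req-0"}) == set()
# Second connector finishes: only now is req-0 reported finished-sending.
assert drain({"req-0"}) == {"req-0"}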