mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-08 08:36:50 +08:00
[V0 deprecation][P/D] Deprecate v0 KVConnectorBase code (1/2) (#21785)
Signed-off-by: Linkun Chen <github@lkchen.net>
This commit is contained in:
parent
5ea71ff46f
commit
f4f4e7ef27
@ -749,7 +749,6 @@ steps:
|
|||||||
# this test fails consistently.
|
# this test fails consistently.
|
||||||
# TODO: investigate and fix
|
# TODO: investigate and fix
|
||||||
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
||||||
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
|
|
||||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
|
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
|
||||||
- pytest -v -s models/multimodal/generation/test_maverick.py
|
- pytest -v -s models/multimodal/generation/test_maverick.py
|
||||||
|
|
||||||
|
|||||||
@ -1,120 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
import os
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from subprocess import Popen
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import requests
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
# Fixture to set up environment variables and teardown servers after tests
|
|
||||||
@pytest.fixture(scope="module", autouse=True)
|
|
||||||
def setup_servers():
|
|
||||||
if torch.cuda.device_count() < 2:
|
|
||||||
pytest.skip("Skipping test: fewer than 2 GPUs available")
|
|
||||||
|
|
||||||
# Set up environment variables
|
|
||||||
VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'",
|
|
||||||
shell=True).decode().strip()
|
|
||||||
os.environ["VLLM_HOST_IP"] = VLLM_HOST_IP
|
|
||||||
|
|
||||||
# Start prefill instance
|
|
||||||
prefill_cmd = [
|
|
||||||
sys.executable,
|
|
||||||
"-m",
|
|
||||||
"vllm.entrypoints.openai.api_server",
|
|
||||||
"--model",
|
|
||||||
"meta-llama/Llama-3.2-1B-Instruct",
|
|
||||||
"--port",
|
|
||||||
"8100",
|
|
||||||
"--gpu-memory-utilization",
|
|
||||||
"0.5",
|
|
||||||
"--max-model-len",
|
|
||||||
"1000",
|
|
||||||
"--kv-transfer-config",
|
|
||||||
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer",'\
|
|
||||||
'"kv_rank":0,"kv_parallel_size":2}',
|
|
||||||
]
|
|
||||||
prefill_env = os.environ.copy()
|
|
||||||
prefill_env["CUDA_VISIBLE_DEVICES"] = "0"
|
|
||||||
prefill_proc = Popen(prefill_cmd, env=prefill_env)
|
|
||||||
|
|
||||||
# Start decode instance
|
|
||||||
decode_cmd = [
|
|
||||||
sys.executable,
|
|
||||||
"-m",
|
|
||||||
"vllm.entrypoints.openai.api_server",
|
|
||||||
"--model",
|
|
||||||
"meta-llama/Llama-3.2-1B-Instruct",
|
|
||||||
"--port",
|
|
||||||
"8200",
|
|
||||||
"--gpu-memory-utilization",
|
|
||||||
"0.5",
|
|
||||||
"--max-model-len",
|
|
||||||
"1000",
|
|
||||||
"--kv-transfer-config",
|
|
||||||
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer",'\
|
|
||||||
'"kv_rank":1,"kv_parallel_size":2}',
|
|
||||||
]
|
|
||||||
decode_env = os.environ.copy()
|
|
||||||
decode_env["CUDA_VISIBLE_DEVICES"] = "1"
|
|
||||||
decode_proc = Popen(decode_cmd, env=decode_env)
|
|
||||||
|
|
||||||
# Wait for servers to be ready
|
|
||||||
assert wait_for_server(8100), "Prefill server did not start in time"
|
|
||||||
assert wait_for_server(8200), "Decode server did not start in time"
|
|
||||||
|
|
||||||
# Yield to the test function and handle teardown after tests
|
|
||||||
yield
|
|
||||||
|
|
||||||
# Cleanup: kill the processes
|
|
||||||
prefill_proc.terminate()
|
|
||||||
decode_proc.terminate()
|
|
||||||
|
|
||||||
# Additional cleanup if needed
|
|
||||||
prefill_proc.wait()
|
|
||||||
decode_proc.wait()
|
|
||||||
|
|
||||||
|
|
||||||
# Helper function to wait for server
|
|
||||||
def wait_for_server(port, timeout=240):
|
|
||||||
start_time = time.time()
|
|
||||||
while time.time() - start_time < timeout:
|
|
||||||
try:
|
|
||||||
response = requests.get(f"http://localhost:{port}/v1/completions")
|
|
||||||
if response.status_code in [200, 405]:
|
|
||||||
return True
|
|
||||||
except requests.ConnectionError:
|
|
||||||
time.sleep(1)
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
# Test function to send curl requests and validate responses
|
|
||||||
@pytest.mark.parametrize("prompt", ["San Francisco is a", "Santa Clara is a"])
|
|
||||||
def test_disaggregated_prefilling(prompt):
|
|
||||||
# Send to prefill
|
|
||||||
response = requests.post("http://localhost:8100/v1/completions",
|
|
||||||
headers={"Content-Type": "application/json"},
|
|
||||||
json={
|
|
||||||
"model": "meta-llama/Llama-3.2-1B-Instruct",
|
|
||||||
"prompt": prompt,
|
|
||||||
"max_tokens": 1,
|
|
||||||
"temperature": 0
|
|
||||||
})
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
# Send to decode
|
|
||||||
response = requests.post("http://localhost:8200/v1/completions",
|
|
||||||
headers={"Content-Type": "application/json"},
|
|
||||||
json={
|
|
||||||
"model": "meta-llama/Llama-3.2-1B-Instruct",
|
|
||||||
"prompt": prompt,
|
|
||||||
"max_tokens": 10,
|
|
||||||
"temperature": 0
|
|
||||||
})
|
|
||||||
assert response.status_code == 200
|
|
||||||
@ -1,142 +1,10 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
"""
|
"""Defines the base type for KV cache connectors."""
|
||||||
KVConnectorBase Class for Distributed KV Cache & Hidden State communication
|
|
||||||
|
|
||||||
The class provides two primary abstract methods:
|
|
||||||
1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
|
|
||||||
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
|
|
||||||
"""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import TYPE_CHECKING, Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
||||||
from vllm.sequence import IntermediateTensors
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
KVConnectorBase = KVConnectorBase_V1
|
||||||
from vllm.config import VllmConfig
|
KVConnectorBaseType = KVConnectorBase_V1
|
||||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
|
||||||
|
|
||||||
|
__all__ = ["KVConnectorBase", "KVConnectorBaseType"]
|
||||||
class KVConnectorBase(ABC):
|
|
||||||
"""
|
|
||||||
Abstract base class for a KV connector.
|
|
||||||
|
|
||||||
The class provides two primary abstract methods:
|
|
||||||
1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
|
|
||||||
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
rank: int,
|
|
||||||
local_rank: int,
|
|
||||||
config: "VllmConfig",
|
|
||||||
):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def close(self) -> None:
|
|
||||||
"""Close the buffer and release resources.
|
|
||||||
|
|
||||||
This method is responsible for cleaning up resources related to the
|
|
||||||
connector when it is no longer needed.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
NotImplementedError: This method must be implemented in subclasses.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def send_kv_caches_and_hidden_states(
|
|
||||||
self,
|
|
||||||
model_executable: torch.nn.Module,
|
|
||||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
|
||||||
kv_caches: list[torch.Tensor],
|
|
||||||
hidden_or_intermediate_states: Union[torch.Tensor,
|
|
||||||
IntermediateTensors],
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Send KV caches and hidden states to the connector.
|
|
||||||
|
|
||||||
This method processes the input tokens, KV caches, and
|
|
||||||
hidden/intermediate states for a given model and sends the data to the
|
|
||||||
decode instance.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_executable (torch.nn.Module): The model executable containing
|
|
||||||
start and end layer information.
|
|
||||||
model_input (ModelInputForGPUWithSamplingMetadata): The input
|
|
||||||
metadata from vLLM.
|
|
||||||
kv_caches (list[torch.Tensor]): List of KV caches (keys and values)
|
|
||||||
for each layer.
|
|
||||||
hidden_or_intermediate_states (Union[torch.Tensor,
|
|
||||||
IntermediateTensors]):
|
|
||||||
The hidden or intermediate states associated with the tokens.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def recv_kv_caches_and_hidden_states(
|
|
||||||
self, model_executable: torch.nn.Module,
|
|
||||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
|
||||||
kv_caches: list[torch.Tensor]
|
|
||||||
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
|
|
||||||
"ModelInputForGPUWithSamplingMetadata"]:
|
|
||||||
"""
|
|
||||||
Receive KV caches and hidden states from the connector.
|
|
||||||
|
|
||||||
This method attempts to retrieve KV caches and hidden states for input
|
|
||||||
tokens. If all required KV caches and hidden states are received, it
|
|
||||||
will bypass model input, else it will fall back to normal vLLM model
|
|
||||||
forwarding.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_executable (torch.nn.Module):
|
|
||||||
The model executable from vLLM modelrunner.
|
|
||||||
model_input (ModelInputForGPUWithSamplingMetadata):
|
|
||||||
The model input from vLLM modelrunner.
|
|
||||||
kv_caches (list[torch.Tensor]):
|
|
||||||
List of KV caches for each layer.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- hidden_or_intermediate_states (torch.Tensor or
|
|
||||||
IntermediateTensors):
|
|
||||||
Concatenated hidden states if all required data is retrieved,
|
|
||||||
otherwise `None`.
|
|
||||||
- bypass_model_exec (bool):
|
|
||||||
Indicates whether the model execution can be skipped (True) or
|
|
||||||
needs to be redone (False).
|
|
||||||
- model_input (ModelInputForGPUWithSamplingMetadata):
|
|
||||||
Optionally adjusted input metadata for re-execution when
|
|
||||||
`bypass_model_exec=False`.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_required_kvcache_layout(
|
|
||||||
cls, vllm_config: "VllmConfig") -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Get the required KV cache layout for this connector.
|
|
||||||
Args:
|
|
||||||
vllm_config (VllmConfig): the vllm config.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: the required KV cache layout. e.g. HND, or NHD.
|
|
||||||
None if the connector does not require a specific layout.
|
|
||||||
"""
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1]
|
|
||||||
|
|||||||
@ -5,14 +5,10 @@ import importlib
|
|||||||
from typing import TYPE_CHECKING, Callable
|
from typing import TYPE_CHECKING, Callable
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.config import KVTransferConfig
|
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
|
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
|
|
||||||
KVConnectorRole)
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
from .base import KVConnectorBase
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
|
||||||
@ -20,7 +16,7 @@ logger = init_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class KVConnectorFactory:
|
class KVConnectorFactory:
|
||||||
_registry: dict[str, Callable[[], type[KVConnectorBaseType]]] = {}
|
_registry: dict[str, Callable[[], type[KVConnectorBase]]] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_connector(cls, name: str, module_path: str,
|
def register_connector(cls, name: str, module_path: str,
|
||||||
@ -29,28 +25,23 @@ class KVConnectorFactory:
|
|||||||
if name in cls._registry:
|
if name in cls._registry:
|
||||||
raise ValueError(f"Connector '{name}' is already registered.")
|
raise ValueError(f"Connector '{name}' is already registered.")
|
||||||
|
|
||||||
def loader() -> type[KVConnectorBaseType]:
|
def loader() -> type[KVConnectorBase]:
|
||||||
module = importlib.import_module(module_path)
|
module = importlib.import_module(module_path)
|
||||||
return getattr(module, class_name)
|
return getattr(module, class_name)
|
||||||
|
|
||||||
cls._registry[name] = loader
|
cls._registry[name] = loader
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_connector_v0(cls, rank: int, local_rank: int,
|
def create_connector(
|
||||||
config: "VllmConfig") -> KVConnectorBase:
|
cls,
|
||||||
if envs.VLLM_USE_V1:
|
config: "VllmConfig",
|
||||||
raise ValueError("Attempting to initialize a V0 Connector, "
|
role: KVConnectorRole,
|
||||||
|
) -> KVConnectorBase:
|
||||||
|
if not envs.VLLM_USE_V1:
|
||||||
|
raise ValueError("Attempting to initialize a V1 Connector, "
|
||||||
f"but found {envs.VLLM_USE_V1=}")
|
f"but found {envs.VLLM_USE_V1=}")
|
||||||
|
|
||||||
connector_cls = cls.get_connector_class(config.kv_transfer_config)
|
kv_transfer_config = config.kv_transfer_config
|
||||||
assert issubclass(connector_cls, KVConnectorBase)
|
|
||||||
return connector_cls(rank, local_rank, config)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_connector_class(
|
|
||||||
cls, kv_transfer_config: "KVTransferConfig"
|
|
||||||
) -> type[KVConnectorBaseType]:
|
|
||||||
"""Get the connector class by name."""
|
|
||||||
connector_name = kv_transfer_config.kv_connector
|
connector_name = kv_transfer_config.kv_connector
|
||||||
if connector_name in cls._registry:
|
if connector_name in cls._registry:
|
||||||
connector_cls = cls._registry[connector_name]()
|
connector_cls = cls._registry[connector_name]()
|
||||||
@ -61,21 +52,7 @@ class KVConnectorFactory:
|
|||||||
f"Unsupported connector type: {connector_name}")
|
f"Unsupported connector type: {connector_name}")
|
||||||
connector_module = importlib.import_module(connector_module_path)
|
connector_module = importlib.import_module(connector_module_path)
|
||||||
connector_cls = getattr(connector_module, connector_name)
|
connector_cls = getattr(connector_module, connector_name)
|
||||||
return connector_cls
|
assert issubclass(connector_cls, KVConnectorBase)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create_connector_v1(
|
|
||||||
cls,
|
|
||||||
config: "VllmConfig",
|
|
||||||
role: KVConnectorRole,
|
|
||||||
) -> KVConnectorBase_V1:
|
|
||||||
if not envs.VLLM_USE_V1:
|
|
||||||
raise ValueError("Attempting to initialize a V1 Connector, "
|
|
||||||
f"but found {envs.VLLM_USE_V1=}")
|
|
||||||
|
|
||||||
kv_transfer_config = config.kv_transfer_config
|
|
||||||
connector_cls = cls.get_connector_class(kv_transfer_config)
|
|
||||||
assert issubclass(connector_cls, KVConnectorBase_V1)
|
|
||||||
logger.info("Creating v1 connector with name: %s and engine_id: %s",
|
logger.info("Creating v1 connector with name: %s and engine_id: %s",
|
||||||
connector_cls.__name__, kv_transfer_config.engine_id)
|
connector_cls.__name__, kv_transfer_config.engine_id)
|
||||||
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
|
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
|
||||||
@ -92,25 +69,6 @@ class KVConnectorFactory:
|
|||||||
# Register various connectors here.
|
# Register various connectors here.
|
||||||
# The registration should not be done in each individual file, as we want to
|
# The registration should not be done in each individual file, as we want to
|
||||||
# only load the files corresponding to the current connector.
|
# only load the files corresponding to the current connector.
|
||||||
KVConnectorFactory.register_connector(
|
|
||||||
"PyNcclConnector",
|
|
||||||
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
|
|
||||||
"SimpleConnector")
|
|
||||||
|
|
||||||
KVConnectorFactory.register_connector(
|
|
||||||
"MooncakeConnector",
|
|
||||||
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
|
|
||||||
"SimpleConnector")
|
|
||||||
|
|
||||||
KVConnectorFactory.register_connector(
|
|
||||||
"LMCacheConnector",
|
|
||||||
"vllm.distributed.kv_transfer.kv_connector.lmcache_connector",
|
|
||||||
"LMCacheConnector")
|
|
||||||
|
|
||||||
KVConnectorFactory.register_connector(
|
|
||||||
"MooncakeStoreConnector",
|
|
||||||
"vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector",
|
|
||||||
"MooncakeStoreConnector")
|
|
||||||
|
|
||||||
KVConnectorFactory.register_connector(
|
KVConnectorFactory.register_connector(
|
||||||
"SharedStorageConnector",
|
"SharedStorageConnector",
|
||||||
|
|||||||
@ -1,99 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
"""
|
|
||||||
LMCache KV Cache Connector for Distributed Machine Learning Inference
|
|
||||||
|
|
||||||
The LMCacheConnector can (1) transfer KV caches between prefill vLLM worker
|
|
||||||
(KV cache producer) and decode vLLM worker (KV cache consumer) using LMCache;
|
|
||||||
(2) offload and share KV caches.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
|
||||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.sequence import IntermediateTensors
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class LMCacheConnector(KVConnectorBase):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
rank: int,
|
|
||||||
local_rank: int,
|
|
||||||
config: VllmConfig,
|
|
||||||
):
|
|
||||||
|
|
||||||
self.transfer_config = config.kv_transfer_config
|
|
||||||
self.vllm_config = config
|
|
||||||
|
|
||||||
from lmcache.experimental.cache_engine import LMCacheEngineBuilder
|
|
||||||
from lmcache.integration.vllm.utils import ENGINE_NAME
|
|
||||||
from lmcache.integration.vllm.vllm_adapter import (
|
|
||||||
RetrieveStatus, StoreStatus, init_lmcache_engine,
|
|
||||||
lmcache_retrieve_kv, lmcache_should_retrieve, lmcache_should_store,
|
|
||||||
lmcache_store_kv)
|
|
||||||
logger.info("Initializing LMCacheConfig under kv_transfer_config %s",
|
|
||||||
self.transfer_config)
|
|
||||||
|
|
||||||
# TODO (Jiayi): Find model_config, parallel_config, and cache_config
|
|
||||||
self.engine = init_lmcache_engine(config.model_config,
|
|
||||||
config.parallel_config,
|
|
||||||
config.cache_config)
|
|
||||||
self.lmcache_engine_name = ENGINE_NAME
|
|
||||||
self.lmcache_engine_builder = LMCacheEngineBuilder
|
|
||||||
|
|
||||||
self.model_config = config.model_config
|
|
||||||
self.parallel_config = config.parallel_config
|
|
||||||
self.cache_config = config.cache_config
|
|
||||||
self.lmcache_retrieve_kv = lmcache_retrieve_kv
|
|
||||||
self.lmcache_store_kv = lmcache_store_kv
|
|
||||||
self.lmcache_should_retrieve = lmcache_should_retrieve
|
|
||||||
self.lmcache_should_store = lmcache_should_store
|
|
||||||
self.store_status = StoreStatus
|
|
||||||
self.retrieve_status = RetrieveStatus
|
|
||||||
|
|
||||||
def recv_kv_caches_and_hidden_states(
|
|
||||||
self, model_executable: torch.nn.Module,
|
|
||||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
|
||||||
kv_caches: list[torch.Tensor]
|
|
||||||
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
|
|
||||||
"ModelInputForGPUWithSamplingMetadata"]:
|
|
||||||
|
|
||||||
retrieve_status = self.lmcache_should_retrieve(model_input)
|
|
||||||
model_input, bypass_model_exec, hidden_or_intermediate_states =\
|
|
||||||
self.lmcache_retrieve_kv(
|
|
||||||
model_executable, model_input, self.cache_config, kv_caches,
|
|
||||||
retrieve_status)
|
|
||||||
return hidden_or_intermediate_states, bypass_model_exec, model_input
|
|
||||||
|
|
||||||
def send_kv_caches_and_hidden_states(
|
|
||||||
self,
|
|
||||||
model_executable: torch.nn.Module,
|
|
||||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
|
||||||
kv_caches: list[torch.Tensor],
|
|
||||||
hidden_or_intermediate_states: Union[torch.Tensor,
|
|
||||||
IntermediateTensors],
|
|
||||||
) -> None:
|
|
||||||
|
|
||||||
store_status = self.lmcache_should_store(model_input)
|
|
||||||
self.lmcache_store_kv(
|
|
||||||
self.model_config,
|
|
||||||
self.parallel_config,
|
|
||||||
self.cache_config,
|
|
||||||
model_executable,
|
|
||||||
model_input,
|
|
||||||
kv_caches,
|
|
||||||
store_status,
|
|
||||||
)
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
self.lmcache_engine_builder.destroy(self.lmcache_engine_name)
|
|
||||||
@ -1,203 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
"""
|
|
||||||
MooncakeStore Connector for Distributed Machine Learning Inference
|
|
||||||
The MooncakeStoreConnector transfers KV caches between prefill vLLM workers
|
|
||||||
(KV cache producer) and decode vLLM workers (KV cache consumer) using a
|
|
||||||
database-style KVStore.
|
|
||||||
"""
|
|
||||||
import hashlib
|
|
||||||
from typing import TYPE_CHECKING, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
|
||||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
|
||||||
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
|
||||||
model_aware_kv_ops_helper as kv_helper)
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.sequence import IntermediateTensors
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class MooncakeStoreConnector(KVConnectorBase):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
rank: int,
|
|
||||||
local_rank: int,
|
|
||||||
config: VllmConfig,
|
|
||||||
):
|
|
||||||
self.kv_transfer_config = config.kv_transfer_config
|
|
||||||
self.kv_helper = kv_helper(config)
|
|
||||||
self.local_tp_rank = local_rank
|
|
||||||
|
|
||||||
# Init kv_store
|
|
||||||
if self.kv_transfer_config.kv_connector == "MooncakeStoreConnector":
|
|
||||||
# Check if MOONCAKE_CONFIG_PATH is set
|
|
||||||
import os
|
|
||||||
use_mooncake_store = os.getenv('MOONCAKE_CONFIG_PATH') is not None
|
|
||||||
|
|
||||||
if not use_mooncake_store:
|
|
||||||
raise ValueError(
|
|
||||||
"To use MooncakeStoreConnector, you need to pass the ENV: "
|
|
||||||
"'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.")
|
|
||||||
else:
|
|
||||||
from vllm.distributed.kv_transfer.kv_lookup_buffer.mooncake_store import ( # noqa: E501
|
|
||||||
MooncakeStore)
|
|
||||||
logger.info(
|
|
||||||
"Initializing KVStoreConnector under kv_transfer_config %s",
|
|
||||||
self.kv_transfer_config)
|
|
||||||
self.kv_store = MooncakeStore(config)
|
|
||||||
else:
|
|
||||||
logger.error("Can not find %s",
|
|
||||||
self.kv_transfer_config.kv_connector)
|
|
||||||
|
|
||||||
assert self.kv_store is not None
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
"""Close the buffer and release resources.
|
|
||||||
This method is responsible for cleaning up resources related to the
|
|
||||||
connector when it is no longer needed.
|
|
||||||
Raises:
|
|
||||||
NotImplementedError: This method must be implemented in subclasses.
|
|
||||||
"""
|
|
||||||
self.kv_store.close()
|
|
||||||
|
|
||||||
def send_kv_caches_and_hidden_states(
|
|
||||||
self,
|
|
||||||
model_executable: torch.nn.Module,
|
|
||||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
|
||||||
kv_caches: list[torch.Tensor],
|
|
||||||
hidden_or_intermediate_states: Union[torch.Tensor,
|
|
||||||
IntermediateTensors],
|
|
||||||
) -> None:
|
|
||||||
input_tokens_tensor = model_input.input_tokens
|
|
||||||
seq_lens = model_input.attn_metadata.seq_lens
|
|
||||||
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
|
|
||||||
start_layer = model_executable.model.start_layer
|
|
||||||
end_layer = model_executable.model.end_layer
|
|
||||||
num_heads, head_size = self.kv_helper.get_model_args(model_executable)
|
|
||||||
|
|
||||||
for idx, slen in enumerate(seq_lens):
|
|
||||||
start_pos = sum(seq_lens[:idx])
|
|
||||||
end_pos = start_pos + slen
|
|
||||||
|
|
||||||
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
|
||||||
store_key_prefix = self.tensor_hash(current_tokens)
|
|
||||||
keys, values = [], []
|
|
||||||
|
|
||||||
for layer_id in range(start_layer, end_layer):
|
|
||||||
kv_cache = kv_caches[layer_id - start_layer]
|
|
||||||
key_cache, value_cache = self.kv_helper.get_kv_from_cache(
|
|
||||||
kv_cache, num_heads, head_size)
|
|
||||||
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
|
|
||||||
|
|
||||||
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
|
|
||||||
values.append(value_cache[current_slot_mapping].unsqueeze(0))
|
|
||||||
|
|
||||||
keys = torch.cat(keys, dim=0)
|
|
||||||
values = torch.cat(values, dim=0)
|
|
||||||
kvcache_to_sent = torch.stack((keys, values), dim=0)
|
|
||||||
store_kvcache_key = f"{store_key_prefix}_{self.local_tp_rank}"
|
|
||||||
self.kv_store.put(store_kvcache_key, kvcache_to_sent)
|
|
||||||
|
|
||||||
hidden_key = f"{store_key_prefix}_hidden_{self.local_tp_rank}"
|
|
||||||
self.kv_store.put(hidden_key,
|
|
||||||
hidden_or_intermediate_states[start_pos:end_pos])
|
|
||||||
|
|
||||||
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
|
|
||||||
|
|
||||||
def recv_kv_caches_and_hidden_states(
|
|
||||||
self, model_executable: torch.nn.Module,
|
|
||||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
|
||||||
kv_caches: list[torch.Tensor]
|
|
||||||
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
|
|
||||||
"ModelInputForGPUWithSamplingMetadata"]:
|
|
||||||
bypass_model_exec = True
|
|
||||||
input_tokens_tensor = model_input.input_tokens
|
|
||||||
seq_lens = model_input.attn_metadata.seq_lens
|
|
||||||
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
|
|
||||||
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
|
|
||||||
start_layer = model_executable.model.start_layer
|
|
||||||
end_layer = model_executable.model.end_layer
|
|
||||||
hidden_or_intermediate_states_for_one_req = []
|
|
||||||
|
|
||||||
for idx, slen in enumerate(seq_lens):
|
|
||||||
start_pos = sum(seq_lens[:idx])
|
|
||||||
end_pos = start_pos + slen
|
|
||||||
|
|
||||||
if start_pos >= num_prefill_tokens:
|
|
||||||
# This can happen during inflight batching. See:
|
|
||||||
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
|
|
||||||
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
|
|
||||||
# - input_tokens[num_prefill_tokens:] contains decode tokens.
|
|
||||||
logger.warning("You should set --enable_chunked_prefill=False "
|
|
||||||
"and --max_num_batched_tokens "
|
|
||||||
"should be equal to max_seq_len_to_capture")
|
|
||||||
bypass_model_exec = False
|
|
||||||
assert start_pos == num_prefill_tokens
|
|
||||||
break
|
|
||||||
|
|
||||||
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
|
||||||
|
|
||||||
# get roi for current seq
|
|
||||||
load_key_prefix = self.tensor_hash(current_tokens)
|
|
||||||
load_kvcache_key = f"{load_key_prefix}_{self.local_tp_rank}"
|
|
||||||
remote_kv = self.kv_store.get(load_kvcache_key)
|
|
||||||
hidden_key = f"{load_key_prefix}_hidden_{self.local_tp_rank}"
|
|
||||||
hidden = self.kv_store.get(hidden_key)
|
|
||||||
|
|
||||||
if remote_kv is None or hidden is None:
|
|
||||||
# didn't find any match.
|
|
||||||
bypass_model_exec = False
|
|
||||||
continue
|
|
||||||
|
|
||||||
num_computed_tokens = current_tokens.shape[0]
|
|
||||||
|
|
||||||
# update the end position based on how many tokens are cached.
|
|
||||||
end_pos = start_pos + num_computed_tokens
|
|
||||||
|
|
||||||
# call self.kv_store to get kv layer by layer
|
|
||||||
for layer_id in range(start_layer, end_layer):
|
|
||||||
layer = model_executable.model.layers[layer_id]
|
|
||||||
# get kvcache object
|
|
||||||
kv_cache = kv_caches[layer_id - start_layer]
|
|
||||||
|
|
||||||
# get remote kvcache
|
|
||||||
remote_k, remote_v = remote_kv[0][layer_id], remote_kv[1][
|
|
||||||
layer_id]
|
|
||||||
|
|
||||||
self.kv_helper.put_kv_to_cache(model_executable, remote_k,
|
|
||||||
remote_v, layer, kv_cache,
|
|
||||||
slot_mapping, start_pos,
|
|
||||||
end_pos)
|
|
||||||
|
|
||||||
hidden_or_intermediate_states_for_one_req.append(hidden)
|
|
||||||
|
|
||||||
if not bypass_model_exec:
|
|
||||||
logger.warning(
|
|
||||||
"[rank%d]: Failed to receive all KVs and hidden "
|
|
||||||
"states, redo model forwarding.", torch.distributed.get_rank())
|
|
||||||
hidden_or_intermediate_states = None
|
|
||||||
|
|
||||||
else:
|
|
||||||
logger.debug(
|
|
||||||
"[rank%d]: Successfully received all KVs and hidden "
|
|
||||||
"states, skip model forwarding.", torch.distributed.get_rank())
|
|
||||||
hidden_or_intermediate_states = torch.cat(
|
|
||||||
hidden_or_intermediate_states_for_one_req, dim=0)
|
|
||||||
|
|
||||||
return hidden_or_intermediate_states, bypass_model_exec, model_input
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def tensor_hash(tensor: torch.Tensor) -> int:
|
|
||||||
"""Calculate the hash value of the tensor."""
|
|
||||||
tensor_bytes = tensor.clone().detach().cpu().numpy().tobytes()
|
|
||||||
hash_object = hashlib.blake2b(tensor_bytes)
|
|
||||||
hash_hex = hash_object.hexdigest()
|
|
||||||
return int(hash_hex[:16], 16)
|
|
||||||
@ -1,329 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
"""
|
|
||||||
Simple KV Cache Connector for Distributed Machine Learning Inference
|
|
||||||
|
|
||||||
The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
|
|
||||||
producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or
|
|
||||||
MooncakePipe.
|
|
||||||
|
|
||||||
But the logic can be extended to support other pipe and lookup buffer.
|
|
||||||
"""
|
|
||||||
from typing import TYPE_CHECKING, Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
|
||||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
|
||||||
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
|
||||||
model_aware_kv_ops_helper as kv_helper)
|
|
||||||
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
|
|
||||||
SimpleBuffer)
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.sequence import IntermediateTensors
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleConnector(KVConnectorBase):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
rank: int,
|
|
||||||
local_rank: int,
|
|
||||||
config: VllmConfig,
|
|
||||||
):
|
|
||||||
|
|
||||||
self.config = config.kv_transfer_config
|
|
||||||
self.kv_helper = kv_helper(config)
|
|
||||||
|
|
||||||
if self.config.kv_connector == "PyNcclConnector":
|
|
||||||
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
|
|
||||||
PyNcclPipe)
|
|
||||||
logger.info(
|
|
||||||
"Initializing PyNcclConfig under kv_transfer_config %s",
|
|
||||||
self.config)
|
|
||||||
elif self.config.kv_connector == "MooncakeConnector":
|
|
||||||
# Check if MOONCAKE_CONFIG_PATH is set
|
|
||||||
import os
|
|
||||||
use_mooncake_distributed_pipe = os.getenv(
|
|
||||||
'MOONCAKE_CONFIG_PATH') is not None
|
|
||||||
|
|
||||||
if not use_mooncake_distributed_pipe:
|
|
||||||
raise ValueError(
|
|
||||||
"To use MooncakeConnector, you need to pass the ENV: "
|
|
||||||
"'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.")
|
|
||||||
else:
|
|
||||||
from vllm.distributed.kv_transfer.kv_pipe.mooncake_pipe import ( # noqa: E501
|
|
||||||
MooncakePipe)
|
|
||||||
logger.info(
|
|
||||||
"Initializing MooncakeConfig under kv_transfer_config %s",
|
|
||||||
self.config)
|
|
||||||
|
|
||||||
self.lookup_buffer_size = self.config.kv_buffer_size
|
|
||||||
|
|
||||||
self.producer_buffer: Optional[SimpleBuffer] = None
|
|
||||||
self.consumer_buffer: Optional[SimpleBuffer] = None
|
|
||||||
|
|
||||||
self.producer_data_pipe: Union[PyNcclPipe, MooncakePipe]
|
|
||||||
self.consumer_data_pipe: Union[PyNcclPipe, MooncakePipe]
|
|
||||||
self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
|
|
||||||
self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
|
|
||||||
|
|
||||||
# 2 pipes for every rank in the world
|
|
||||||
port_offset_base = 2 * rank
|
|
||||||
|
|
||||||
# In disaggregated prefill, the prefill vLLM only uses send pipe
|
|
||||||
# and the decode vLLM only uses recv pipe
|
|
||||||
if self.config.is_kv_producer:
|
|
||||||
|
|
||||||
if self.config.kv_connector == "PyNcclConnector":
|
|
||||||
self.producer_data_pipe = PyNcclPipe(
|
|
||||||
local_rank=local_rank,
|
|
||||||
config=self.config,
|
|
||||||
port_offset=port_offset_base,
|
|
||||||
)
|
|
||||||
self.producer_signal_pipe = PyNcclPipe(
|
|
||||||
local_rank=local_rank,
|
|
||||||
config=self.config,
|
|
||||||
port_offset=port_offset_base + 1,
|
|
||||||
device="cpu",
|
|
||||||
)
|
|
||||||
elif self.config.kv_connector == "MooncakeConnector":
|
|
||||||
self.producer_data_pipe = MooncakePipe(
|
|
||||||
local_rank=local_rank,
|
|
||||||
config=self.config,
|
|
||||||
)
|
|
||||||
# We only need to initialize MooncakePipe once
|
|
||||||
self.producer_signal_pipe = self.producer_data_pipe
|
|
||||||
|
|
||||||
self.producer_buffer = SimpleBuffer(self.producer_signal_pipe,
|
|
||||||
self.producer_data_pipe,
|
|
||||||
self.config.kv_buffer_size)
|
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
# the current vLLM instance is KV consumer, so it needs to connect
|
|
||||||
# its recv pipe to the send pipe of KV producer
|
|
||||||
if self.config.kv_connector == "PyNcclConnector":
|
|
||||||
self.consumer_data_pipe = PyNcclPipe(
|
|
||||||
local_rank=local_rank,
|
|
||||||
config=self.config,
|
|
||||||
port_offset=port_offset_base,
|
|
||||||
)
|
|
||||||
self.consumer_signal_pipe = PyNcclPipe(
|
|
||||||
local_rank=local_rank,
|
|
||||||
config=self.config,
|
|
||||||
port_offset=port_offset_base + 1,
|
|
||||||
device="cpu",
|
|
||||||
)
|
|
||||||
elif self.config.kv_connector == "MooncakeConnector":
|
|
||||||
self.consumer_data_pipe = MooncakePipe(
|
|
||||||
local_rank=local_rank,
|
|
||||||
config=self.config,
|
|
||||||
)
|
|
||||||
self.consumer_signal_pipe = self.consumer_data_pipe
|
|
||||||
|
|
||||||
self.consumer_buffer = SimpleBuffer(
|
|
||||||
self.consumer_signal_pipe,
|
|
||||||
self.consumer_data_pipe,
|
|
||||||
self.config.kv_buffer_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
def select(self, input_tokens: Optional[torch.Tensor],
|
|
||||||
roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]:
|
|
||||||
|
|
||||||
assert self.consumer_buffer is not None, "Please initialize the "\
|
|
||||||
"consumer buffer before calling select."
|
|
||||||
return self.consumer_buffer.drop_select(input_tokens, roi)
|
|
||||||
|
|
||||||
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
|
|
||||||
key: torch.Tensor, value: torch.Tensor,
|
|
||||||
hidden: torch.Tensor) -> None:
|
|
||||||
|
|
||||||
assert self.producer_buffer is not None, "Please initialize the "\
|
|
||||||
"producer buffer before calling insert."
|
|
||||||
|
|
||||||
self.producer_buffer.insert(input_tokens, roi, key, value, hidden)
|
|
||||||
|
|
||||||
def send_kv_caches_and_hidden_states(
|
|
||||||
self,
|
|
||||||
model_executable: torch.nn.Module,
|
|
||||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
|
||||||
kv_caches: list[torch.Tensor],
|
|
||||||
hidden_or_intermediate_states: Union[torch.Tensor,
|
|
||||||
IntermediateTensors],
|
|
||||||
) -> None:
|
|
||||||
|
|
||||||
input_tokens_tensor = model_input.input_tokens
|
|
||||||
seq_lens = model_input.attn_metadata.seq_lens
|
|
||||||
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
|
|
||||||
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
|
|
||||||
start_layer = model_executable.model.start_layer
|
|
||||||
end_layer = model_executable.model.end_layer
|
|
||||||
num_heads, head_size = self.kv_helper.get_model_args(model_executable)
|
|
||||||
|
|
||||||
# query_lens contains new KV caches that are added to vLLM.
|
|
||||||
# so we will send them to decode instance
|
|
||||||
# FIXME(Kuntai): This assume that all requests are prefill.
|
|
||||||
for idx, slen in enumerate(seq_lens):
|
|
||||||
start_pos = sum(seq_lens[:idx])
|
|
||||||
end_pos = start_pos + slen
|
|
||||||
|
|
||||||
if start_pos >= num_prefill_tokens:
|
|
||||||
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
|
|
||||||
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
|
|
||||||
# - input_tokens[num_prefill_tokens:] contains decode tokens.
|
|
||||||
logger.warning("You have some decode requests while using "
|
|
||||||
"SimpleConnector. Their KVCache won't be sent.")
|
|
||||||
break
|
|
||||||
|
|
||||||
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
|
||||||
|
|
||||||
keys, values = [], []
|
|
||||||
|
|
||||||
for layer_id in range(start_layer, end_layer):
|
|
||||||
kv_cache = kv_caches[layer_id - start_layer]
|
|
||||||
key_cache, value_cache = self.kv_helper.get_kv_from_cache(
|
|
||||||
kv_cache, num_heads, head_size)
|
|
||||||
|
|
||||||
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
|
|
||||||
|
|
||||||
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
|
|
||||||
values.append(value_cache[current_slot_mapping].unsqueeze(0))
|
|
||||||
|
|
||||||
keys = torch.cat(keys, dim=0)
|
|
||||||
values = torch.cat(values, dim=0)
|
|
||||||
|
|
||||||
self.insert(current_tokens,
|
|
||||||
torch.ones_like(current_tokens,
|
|
||||||
dtype=bool), keys, values,
|
|
||||||
hidden_or_intermediate_states[start_pos:end_pos])
|
|
||||||
|
|
||||||
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
|
|
||||||
|
|
||||||
def recv_kv_caches_and_hidden_states(
|
|
||||||
self, model_executable: torch.nn.Module,
|
|
||||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
|
||||||
kv_caches: list[torch.Tensor]
|
|
||||||
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
|
|
||||||
"ModelInputForGPUWithSamplingMetadata"]:
|
|
||||||
|
|
||||||
# When bypass_model_exec is set to False, it means that at least for one
|
|
||||||
# request its corresponding KV cache or hidden state is missing.
|
|
||||||
# In this case we need to do prefilling to recompute missing KV cache
|
|
||||||
# and hidden states.
|
|
||||||
bypass_model_exec = True
|
|
||||||
|
|
||||||
input_tokens_tensor = model_input.input_tokens
|
|
||||||
seq_lens = model_input.attn_metadata.seq_lens
|
|
||||||
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
|
|
||||||
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
|
|
||||||
start_layer = model_executable.model.start_layer
|
|
||||||
end_layer = model_executable.model.end_layer
|
|
||||||
|
|
||||||
hidden_or_intermediate_states_for_one_req = []
|
|
||||||
|
|
||||||
input_tokens_list = []
|
|
||||||
num_computed_tokens_list = []
|
|
||||||
start_pos_list = []
|
|
||||||
|
|
||||||
# enumerate different requests
|
|
||||||
# FIXME(Kuntai): This impl assumes that all requests are prefill.
|
|
||||||
for idx, slen in enumerate(seq_lens):
|
|
||||||
start_pos = sum(seq_lens[:idx])
|
|
||||||
end_pos = start_pos + slen
|
|
||||||
|
|
||||||
if start_pos >= num_prefill_tokens:
|
|
||||||
# This can happen during inflight batching. See:
|
|
||||||
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
|
|
||||||
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
|
|
||||||
# - input_tokens[num_prefill_tokens:] contains decode tokens.
|
|
||||||
logger.warning("You should set --enable_chunked_prefill=False "
|
|
||||||
"and --max_num_batched_tokens "
|
|
||||||
"should be equal to --max_seq_len_to_capture")
|
|
||||||
bypass_model_exec = False
|
|
||||||
assert start_pos == num_prefill_tokens
|
|
||||||
break
|
|
||||||
|
|
||||||
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
|
||||||
num_tokens = slen
|
|
||||||
|
|
||||||
# collecting data for rebuilding the input
|
|
||||||
input_tokens_list.append(current_tokens)
|
|
||||||
start_pos_list.append(start_pos)
|
|
||||||
|
|
||||||
ret = self.select(current_tokens,
|
|
||||||
torch.ones_like(current_tokens, dtype=bool))
|
|
||||||
if ret[0] is None:
|
|
||||||
# didn't find any match.
|
|
||||||
bypass_model_exec = False
|
|
||||||
num_computed_tokens_list.append(0)
|
|
||||||
continue
|
|
||||||
|
|
||||||
roi: torch.Tensor = ret[1]
|
|
||||||
keys: torch.Tensor = ret[2]
|
|
||||||
values: torch.Tensor = ret[3]
|
|
||||||
hidden: torch.Tensor = ret[4]
|
|
||||||
|
|
||||||
num_computed_tokens = roi.shape[0]
|
|
||||||
num_computed_tokens_list.append(num_computed_tokens)
|
|
||||||
|
|
||||||
# check if both KV cache and the hidden states are received
|
|
||||||
# If not, need to redo the forwarding to compute missing states
|
|
||||||
if not all([(num_computed_tokens == num_tokens), hidden is not None
|
|
||||||
]):
|
|
||||||
bypass_model_exec = False
|
|
||||||
|
|
||||||
# update the end position based on how many tokens are cached.
|
|
||||||
end_pos = start_pos + num_computed_tokens
|
|
||||||
|
|
||||||
# put received KV caches into paged memory
|
|
||||||
for cur_layer in range(start_layer, end_layer):
|
|
||||||
|
|
||||||
layer_id = cur_layer - start_layer
|
|
||||||
kv_cache = kv_caches[layer_id]
|
|
||||||
layer = model_executable.model.layers[cur_layer]
|
|
||||||
|
|
||||||
# get remote kvcache
|
|
||||||
remote_k, remote_v = keys[layer_id], values[layer_id]
|
|
||||||
|
|
||||||
self.kv_helper.put_kv_to_cache(model_executable, remote_k,
|
|
||||||
remote_v, layer, kv_cache,
|
|
||||||
slot_mapping, start_pos,
|
|
||||||
end_pos)
|
|
||||||
|
|
||||||
hidden_or_intermediate_states_for_one_req.append(hidden)
|
|
||||||
|
|
||||||
if not bypass_model_exec:
|
|
||||||
# Some of the KV cache is not retrieved
|
|
||||||
# Here we will fall back to normal model forwarding
|
|
||||||
# But optionally you can adjust model_input so that you only do
|
|
||||||
# prefilling on those tokens that are missing KV caches.
|
|
||||||
logger.warning(
|
|
||||||
"[rank%d]: Failed to receive all KVs and hidden "
|
|
||||||
"states, redo model forwarding.", torch.distributed.get_rank())
|
|
||||||
hidden_or_intermediate_states = None
|
|
||||||
|
|
||||||
else:
|
|
||||||
logger.debug(
|
|
||||||
"[rank%d]: Successfully received all KVs and hidden "
|
|
||||||
"states, skip model forwarding.", torch.distributed.get_rank())
|
|
||||||
hidden_or_intermediate_states = torch.cat(
|
|
||||||
hidden_or_intermediate_states_for_one_req, dim=0)
|
|
||||||
|
|
||||||
return hidden_or_intermediate_states, bypass_model_exec, model_input
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
self.producer_data_pipe.close()
|
|
||||||
self.consumer_data_pipe.close()
|
|
||||||
if self.config.kv_connector == "PyNcclConnector":
|
|
||||||
self.producer_signal_pipe.close()
|
|
||||||
self.consumer_signal_pipe.close()
|
|
||||||
elif self.config.kv_connector == "MooncakeConnector":
|
|
||||||
# MooncakePipe reuses data_pipe for signal_pipe, so we only have to
|
|
||||||
# close the data_pipe.
|
|
||||||
pass
|
|
||||||
@ -13,8 +13,8 @@ import torch
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.config import VllmConfig, get_current_vllm_config
|
from vllm.config import VllmConfig, get_current_vllm_config
|
||||||
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
KVConnectorFactory)
|
KVConnectorBase_V1)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
|
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
|
||||||
|
|
||||||
@ -106,9 +106,8 @@ def get_kv_connector_cache_layout():
|
|||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
kv_config = vllm_config.kv_transfer_config
|
kv_config = vllm_config.kv_transfer_config
|
||||||
if kv_config is not None:
|
if kv_config is not None:
|
||||||
connector_cls = KVConnectorFactory.get_connector_class(kv_config)
|
required_kvcache_layout = (
|
||||||
required_kvcache_layout = connector_cls.get_required_kvcache_layout(
|
KVConnectorBase_V1.get_required_kvcache_layout(vllm_config))
|
||||||
vllm_config)
|
|
||||||
if required_kvcache_layout is not None:
|
if required_kvcache_layout is not None:
|
||||||
return required_kvcache_layout
|
return required_kvcache_layout
|
||||||
logger.info_once("Connectors do not specify a " \
|
logger.info_once("Connectors do not specify a " \
|
||||||
|
|||||||
@ -52,7 +52,7 @@ class MultiConnector(KVConnectorBase_V1):
|
|||||||
temp_config.kv_transfer_config = KVTransferConfig(
|
temp_config.kv_transfer_config = KVTransferConfig(
|
||||||
**ktc, engine_id=engine_id)
|
**ktc, engine_id=engine_id)
|
||||||
self._connectors.append(
|
self._connectors.append(
|
||||||
KVConnectorFactory.create_connector_v1(temp_config, role))
|
KVConnectorFactory.create_connector(temp_config, role))
|
||||||
|
|
||||||
# A mapping from request id to the index of the connector chosen to
|
# A mapping from request id to the index of the connector chosen to
|
||||||
# load the request from (if any).
|
# load the request from (if any).
|
||||||
@ -223,9 +223,9 @@ class MultiConnector(KVConnectorBase_V1):
|
|||||||
for ktc in ktcs:
|
for ktc in ktcs:
|
||||||
kv_transfer_config = KVTransferConfig(**ktc)
|
kv_transfer_config = KVTransferConfig(**ktc)
|
||||||
temp_vllm_config.kv_transfer_config = kv_transfer_config
|
temp_vllm_config.kv_transfer_config = kv_transfer_config
|
||||||
required_kvcache_layout = KVConnectorFactory.get_connector_class(
|
required_kvcache_layout = (
|
||||||
kv_transfer_config).get_required_kvcache_layout(
|
KVConnectorBase_V1.get_required_kvcache_layout(
|
||||||
temp_vllm_config)
|
temp_vllm_config))
|
||||||
if required_kvcache_layout is not None:
|
if required_kvcache_layout is not None:
|
||||||
layouts.add(required_kvcache_layout)
|
layouts.add(required_kvcache_layout)
|
||||||
|
|
||||||
|
|||||||
@ -1,77 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
"""A centralized entrypoint to perform distributed KV cache transfer.
|
|
||||||
|
|
||||||
This implementation is a shim wrapper on two APIs exposed by `kv_connector`:
|
|
||||||
1. `send_kv_caches_and_hidden_states`
|
|
||||||
2. `recv_kv_caches_and_hidden_states
|
|
||||||
"""
|
|
||||||
from typing import TYPE_CHECKING, Union
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
|
||||||
from vllm.config import VllmConfig
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
|
||||||
KVConnectorFactory)
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.sequence import IntermediateTensors
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class KVTransferAgent:
|
|
||||||
"""
|
|
||||||
A class designated for distributed KV transfer
|
|
||||||
|
|
||||||
Target use cases:
|
|
||||||
1. Disaggregated prefill
|
|
||||||
2. Remote KV cache storage
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
rank: int,
|
|
||||||
local_rank: int,
|
|
||||||
config: "VllmConfig",
|
|
||||||
):
|
|
||||||
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
if config.kv_transfer_config is None:
|
|
||||||
raise ValueError("KVTransferConfig is not set in the VllmConfig,"
|
|
||||||
" cannot initialize KVConnector.")
|
|
||||||
|
|
||||||
assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\
|
|
||||||
"TransferAgent should only be used when kv_connector is set."
|
|
||||||
|
|
||||||
self.connector = KVConnectorFactory.create_connector_v0(
|
|
||||||
rank, local_rank, config)
|
|
||||||
|
|
||||||
def send_kv_caches_and_hidden_states(
|
|
||||||
self,
|
|
||||||
model_executable: torch.nn.Module,
|
|
||||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
|
||||||
kv_caches: list[torch.Tensor],
|
|
||||||
hidden_or_intermediate_states: Union[torch.Tensor,
|
|
||||||
IntermediateTensors],
|
|
||||||
) -> None:
|
|
||||||
|
|
||||||
self.connector.send_kv_caches_and_hidden_states(
|
|
||||||
model_executable, model_input, kv_caches,
|
|
||||||
hidden_or_intermediate_states)
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
self.connector.close()
|
|
||||||
|
|
||||||
def recv_kv_caches_and_hidden_states(
|
|
||||||
self, model_executable: torch.nn.Module,
|
|
||||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
|
||||||
kv_caches: list[torch.Tensor]
|
|
||||||
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
|
|
||||||
"ModelInputForGPUWithSamplingMetadata"]:
|
|
||||||
|
|
||||||
return self.connector.recv_kv_caches_and_hidden_states(
|
|
||||||
model_executable, model_input, kv_caches)
|
|
||||||
@ -8,7 +8,6 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
|
|||||||
KVConnectorFactory)
|
KVConnectorFactory)
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
|
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
|
||||||
KVConnectorRole)
|
KVConnectorRole)
|
||||||
from vllm.distributed.parallel_state import get_world_group
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
@ -61,11 +60,7 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
|
|||||||
if (vllm_config.kv_transfer_config.is_kv_transfer_instance
|
if (vllm_config.kv_transfer_config.is_kv_transfer_instance
|
||||||
and _KV_CONNECTOR_AGENT is None):
|
and _KV_CONNECTOR_AGENT is None):
|
||||||
if envs.VLLM_USE_V1:
|
if envs.VLLM_USE_V1:
|
||||||
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1(
|
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector(
|
||||||
config=vllm_config, role=KVConnectorRole.WORKER)
|
config=vllm_config, role=KVConnectorRole.WORKER)
|
||||||
else:
|
else:
|
||||||
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0(
|
raise ValueError("V0 is no longer supported")
|
||||||
rank=get_world_group().rank,
|
|
||||||
local_rank=get_world_group().local_rank,
|
|
||||||
config=vllm_config,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -83,7 +83,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
|
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
|
||||||
"Multiple KV cache groups are not currently supported "
|
"Multiple KV cache groups are not currently supported "
|
||||||
"with KV connectors")
|
"with KV connectors")
|
||||||
self.connector = KVConnectorFactory.create_connector_v1(
|
self.connector = KVConnectorFactory.create_connector(
|
||||||
config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
|
config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
|
||||||
|
|
||||||
self.kv_event_publisher = EventPublisherFactory.create(
|
self.kv_event_publisher = EventPublisherFactory.create(
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Optional
|
|||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||||
has_kv_transfer_group)
|
has_kv_transfer_group)
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||||
from vllm.forward_context import get_forward_context, set_forward_context
|
from vllm.forward_context import get_forward_context, set_forward_context
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput,
|
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput,
|
||||||
@ -31,7 +31,7 @@ class KVConnectorModelRunnerMixin:
|
|||||||
# Update KVConnector with the KVConnector metadata forward().
|
# Update KVConnector with the KVConnector metadata forward().
|
||||||
if has_kv_transfer_group():
|
if has_kv_transfer_group():
|
||||||
kv_connector = get_kv_transfer_group()
|
kv_connector = get_kv_transfer_group()
|
||||||
assert isinstance(kv_connector, KVConnectorBase_V1)
|
assert isinstance(kv_connector, KVConnectorBase)
|
||||||
assert scheduler_output.kv_connector_metadata is not None
|
assert scheduler_output.kv_connector_metadata is not None
|
||||||
kv_connector.bind_connector_metadata(
|
kv_connector.bind_connector_metadata(
|
||||||
scheduler_output.kv_connector_metadata)
|
scheduler_output.kv_connector_metadata)
|
||||||
@ -93,7 +93,7 @@ class KVConnectorModelRunnerMixin:
|
|||||||
|
|
||||||
# Update KVConnector with the KVConnector metadata forward().
|
# Update KVConnector with the KVConnector metadata forward().
|
||||||
kv_connector = get_kv_transfer_group()
|
kv_connector = get_kv_transfer_group()
|
||||||
assert isinstance(kv_connector, KVConnectorBase_V1)
|
assert isinstance(kv_connector, KVConnectorBase)
|
||||||
assert scheduler_output.kv_connector_metadata is not None
|
assert scheduler_output.kv_connector_metadata is not None
|
||||||
kv_connector.bind_connector_metadata(
|
kv_connector.bind_connector_metadata(
|
||||||
scheduler_output.kv_connector_metadata)
|
scheduler_output.kv_connector_metadata)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user