diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 88e1197d703a4..b7a2ca6ca9b24 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -749,7 +749,6 @@ steps: # this test fails consistently. # TODO: investigate and fix - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown - pytest -v -s models/multimodal/generation/test_maverick.py diff --git a/tests/kv_transfer/test_disagg.py b/tests/kv_transfer/test_disagg.py deleted file mode 100644 index 9f2229cc41dff..0000000000000 --- a/tests/kv_transfer/test_disagg.py +++ /dev/null @@ -1,120 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os -import subprocess -import sys -import time -from subprocess import Popen - -import pytest -import requests -import torch - - -# Fixture to set up environment variables and teardown servers after tests -@pytest.fixture(scope="module", autouse=True) -def setup_servers(): - if torch.cuda.device_count() < 2: - pytest.skip("Skipping test: fewer than 2 GPUs available") - - # Set up environment variables - VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'", - shell=True).decode().strip() - os.environ["VLLM_HOST_IP"] = VLLM_HOST_IP - - # Start prefill instance - prefill_cmd = [ - sys.executable, - "-m", - "vllm.entrypoints.openai.api_server", - "--model", - "meta-llama/Llama-3.2-1B-Instruct", - "--port", - "8100", - "--gpu-memory-utilization", - "0.5", - "--max-model-len", - "1000", - "--kv-transfer-config", - '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer",'\ - '"kv_rank":0,"kv_parallel_size":2}', - ] - prefill_env = os.environ.copy() - prefill_env["CUDA_VISIBLE_DEVICES"] = "0" - prefill_proc = Popen(prefill_cmd, env=prefill_env) - - # Start decode instance - decode_cmd = [ - sys.executable, - "-m", - "vllm.entrypoints.openai.api_server", - "--model", - "meta-llama/Llama-3.2-1B-Instruct", - "--port", - "8200", - "--gpu-memory-utilization", - "0.5", - "--max-model-len", - "1000", - "--kv-transfer-config", - '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer",'\ - '"kv_rank":1,"kv_parallel_size":2}', - ] - decode_env = os.environ.copy() - decode_env["CUDA_VISIBLE_DEVICES"] = "1" - decode_proc = Popen(decode_cmd, env=decode_env) - - # Wait for servers to be ready - assert wait_for_server(8100), "Prefill server did not start in time" - assert wait_for_server(8200), "Decode server did not start in time" - - # Yield to the test function and handle teardown after tests - yield - - # Cleanup: kill the processes - prefill_proc.terminate() - decode_proc.terminate() - - # Additional cleanup if needed - prefill_proc.wait() - decode_proc.wait() - - -# Helper function to wait for server -def wait_for_server(port, timeout=240): - start_time = time.time() - while time.time() - start_time < timeout: - try: - response = requests.get(f"http://localhost:{port}/v1/completions") - if response.status_code in [200, 405]: - return True - except requests.ConnectionError: - time.sleep(1) - return False - - -# Test function to send curl requests and validate responses -@pytest.mark.parametrize("prompt", ["San Francisco is a", "Santa Clara is a"]) -def test_disaggregated_prefilling(prompt): - # Send to prefill - response = requests.post("http://localhost:8100/v1/completions", - headers={"Content-Type": 
"application/json"}, - json={ - "model": "meta-llama/Llama-3.2-1B-Instruct", - "prompt": prompt, - "max_tokens": 1, - "temperature": 0 - }) - assert response.status_code == 200 - - # Send to decode - response = requests.post("http://localhost:8200/v1/completions", - headers={"Content-Type": "application/json"}, - json={ - "model": "meta-llama/Llama-3.2-1B-Instruct", - "prompt": prompt, - "max_tokens": 10, - "temperature": 0 - }) - assert response.status_code == 200 diff --git a/vllm/distributed/kv_transfer/kv_connector/base.py b/vllm/distributed/kv_transfer/kv_connector/base.py index 868b227fc8994..011bbb69abb08 100644 --- a/vllm/distributed/kv_transfer/kv_connector/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/base.py @@ -1,142 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -KVConnectorBase Class for Distributed KV Cache & Hidden State communication - -The class provides two primary abstract methods: -1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states -2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states -""" - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional, Union - -import torch +"""Defines the base type for KV cache connectors.""" from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 -from vllm.sequence import IntermediateTensors -if TYPE_CHECKING: - from vllm.config import VllmConfig - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata +KVConnectorBase = KVConnectorBase_V1 +KVConnectorBaseType = KVConnectorBase_V1 - -class KVConnectorBase(ABC): - """ - Abstract base class for a KV connector. - - The class provides two primary abstract methods: - 1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states - 2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states - """ - - @abstractmethod - def __init__( - self, - rank: int, - local_rank: int, - config: "VllmConfig", - ): - raise NotImplementedError - - @abstractmethod - def close(self) -> None: - """Close the buffer and release resources. - - This method is responsible for cleaning up resources related to the - connector when it is no longer needed. - - Raises: - NotImplementedError: This method must be implemented in subclasses. - """ - raise NotImplementedError - - @abstractmethod - def send_kv_caches_and_hidden_states( - self, - model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor], - hidden_or_intermediate_states: Union[torch.Tensor, - IntermediateTensors], - ) -> None: - """ - Send KV caches and hidden states to the connector. - - This method processes the input tokens, KV caches, and - hidden/intermediate states for a given model and sends the data to the - decode instance. - - Args: - model_executable (torch.nn.Module): The model executable containing - start and end layer information. - model_input (ModelInputForGPUWithSamplingMetadata): The input - metadata from vLLM. - kv_caches (list[torch.Tensor]): List of KV caches (keys and values) - for each layer. - hidden_or_intermediate_states (Union[torch.Tensor, - IntermediateTensors]): - The hidden or intermediate states associated with the tokens. 
- - Returns: - None - - """ - - raise NotImplementedError - - @abstractmethod - def recv_kv_caches_and_hidden_states( - self, model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor] - ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, - "ModelInputForGPUWithSamplingMetadata"]: - """ - Receive KV caches and hidden states from the connector. - - This method attempts to retrieve KV caches and hidden states for input - tokens. If all required KV caches and hidden states are received, it - will bypass model input, else it will fall back to normal vLLM model - forwarding. - - Args: - model_executable (torch.nn.Module): - The model executable from vLLM modelrunner. - model_input (ModelInputForGPUWithSamplingMetadata): - The model input from vLLM modelrunner. - kv_caches (list[torch.Tensor]): - List of KV caches for each layer. - - Returns: - - hidden_or_intermediate_states (torch.Tensor or - IntermediateTensors): - Concatenated hidden states if all required data is retrieved, - otherwise `None`. - - bypass_model_exec (bool): - Indicates whether the model execution can be skipped (True) or - needs to be redone (False). - - model_input (ModelInputForGPUWithSamplingMetadata): - Optionally adjusted input metadata for re-execution when - `bypass_model_exec=False`. - - """ - - raise NotImplementedError - - @classmethod - def get_required_kvcache_layout( - cls, vllm_config: "VllmConfig") -> Optional[str]: - """ - Get the required KV cache layout for this connector. - Args: - vllm_config (VllmConfig): the vllm config. - - Returns: - str: the required KV cache layout. e.g. HND, or NHD. - None if the connector does not require a specific layout. - """ - return None - - -KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1] +__all__ = ["KVConnectorBase", "KVConnectorBaseType"] diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index cf7cde2c43771..01673a0d7c876 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -5,14 +5,10 @@ import importlib from typing import TYPE_CHECKING, Callable import vllm.envs as envs -from vllm.config import KVTransferConfig -from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType -from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, - KVConnectorRole) +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole from vllm.logger import init_logger -from .base import KVConnectorBase - if TYPE_CHECKING: from vllm.config import VllmConfig @@ -20,7 +16,7 @@ logger = init_logger(__name__) class KVConnectorFactory: - _registry: dict[str, Callable[[], type[KVConnectorBaseType]]] = {} + _registry: dict[str, Callable[[], type[KVConnectorBase]]] = {} @classmethod def register_connector(cls, name: str, module_path: str, @@ -29,28 +25,23 @@ class KVConnectorFactory: if name in cls._registry: raise ValueError(f"Connector '{name}' is already registered.") - def loader() -> type[KVConnectorBaseType]: + def loader() -> type[KVConnectorBase]: module = importlib.import_module(module_path) return getattr(module, class_name) cls._registry[name] = loader @classmethod - def create_connector_v0(cls, rank: int, local_rank: int, - config: "VllmConfig") -> KVConnectorBase: - if envs.VLLM_USE_V1: - raise ValueError("Attempting to initialize a V0 
Connector, " + def create_connector( + cls, + config: "VllmConfig", + role: KVConnectorRole, + ) -> KVConnectorBase: + if not envs.VLLM_USE_V1: + raise ValueError("Attempting to initialize a V1 Connector, " f"but found {envs.VLLM_USE_V1=}") - connector_cls = cls.get_connector_class(config.kv_transfer_config) - assert issubclass(connector_cls, KVConnectorBase) - return connector_cls(rank, local_rank, config) - - @classmethod - def get_connector_class( - cls, kv_transfer_config: "KVTransferConfig" - ) -> type[KVConnectorBaseType]: - """Get the connector class by name.""" + kv_transfer_config = config.kv_transfer_config connector_name = kv_transfer_config.kv_connector if connector_name in cls._registry: connector_cls = cls._registry[connector_name]() @@ -61,21 +52,7 @@ class KVConnectorFactory: f"Unsupported connector type: {connector_name}") connector_module = importlib.import_module(connector_module_path) connector_cls = getattr(connector_module, connector_name) - return connector_cls - - @classmethod - def create_connector_v1( - cls, - config: "VllmConfig", - role: KVConnectorRole, - ) -> KVConnectorBase_V1: - if not envs.VLLM_USE_V1: - raise ValueError("Attempting to initialize a V1 Connector, " - f"but found {envs.VLLM_USE_V1=}") - - kv_transfer_config = config.kv_transfer_config - connector_cls = cls.get_connector_class(kv_transfer_config) - assert issubclass(connector_cls, KVConnectorBase_V1) + assert issubclass(connector_cls, KVConnectorBase) logger.info("Creating v1 connector with name: %s and engine_id: %s", connector_cls.__name__, kv_transfer_config.engine_id) # NOTE(Kuntai): v1 connector is explicitly separated into two roles. @@ -92,25 +69,6 @@ class KVConnectorFactory: # Register various connectors here. # The registration should not be done in each individual file, as we want to # only load the files corresponding to the current connector. -KVConnectorFactory.register_connector( - "PyNcclConnector", - "vllm.distributed.kv_transfer.kv_connector.simple_connector", - "SimpleConnector") - -KVConnectorFactory.register_connector( - "MooncakeConnector", - "vllm.distributed.kv_transfer.kv_connector.simple_connector", - "SimpleConnector") - -KVConnectorFactory.register_connector( - "LMCacheConnector", - "vllm.distributed.kv_transfer.kv_connector.lmcache_connector", - "LMCacheConnector") - -KVConnectorFactory.register_connector( - "MooncakeStoreConnector", - "vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector", - "MooncakeStoreConnector") KVConnectorFactory.register_connector( "SharedStorageConnector", diff --git a/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py deleted file mode 100644 index 78bf3095613a7..0000000000000 --- a/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py +++ /dev/null @@ -1,99 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -LMCache KV Cache Connector for Distributed Machine Learning Inference - -The LMCacheConnector can (1) transfer KV caches between prefill vLLM worker -(KV cache producer) and decode vLLM worker (KV cache consumer) using LMCache; -(2) offload and share KV caches. 
-""" - -from typing import TYPE_CHECKING, Union - -import torch - -from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -from vllm.logger import init_logger -from vllm.sequence import IntermediateTensors - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - -logger = init_logger(__name__) - - -class LMCacheConnector(KVConnectorBase): - - def __init__( - self, - rank: int, - local_rank: int, - config: VllmConfig, - ): - - self.transfer_config = config.kv_transfer_config - self.vllm_config = config - - from lmcache.experimental.cache_engine import LMCacheEngineBuilder - from lmcache.integration.vllm.utils import ENGINE_NAME - from lmcache.integration.vllm.vllm_adapter import ( - RetrieveStatus, StoreStatus, init_lmcache_engine, - lmcache_retrieve_kv, lmcache_should_retrieve, lmcache_should_store, - lmcache_store_kv) - logger.info("Initializing LMCacheConfig under kv_transfer_config %s", - self.transfer_config) - - # TODO (Jiayi): Find model_config, parallel_config, and cache_config - self.engine = init_lmcache_engine(config.model_config, - config.parallel_config, - config.cache_config) - self.lmcache_engine_name = ENGINE_NAME - self.lmcache_engine_builder = LMCacheEngineBuilder - - self.model_config = config.model_config - self.parallel_config = config.parallel_config - self.cache_config = config.cache_config - self.lmcache_retrieve_kv = lmcache_retrieve_kv - self.lmcache_store_kv = lmcache_store_kv - self.lmcache_should_retrieve = lmcache_should_retrieve - self.lmcache_should_store = lmcache_should_store - self.store_status = StoreStatus - self.retrieve_status = RetrieveStatus - - def recv_kv_caches_and_hidden_states( - self, model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor] - ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, - "ModelInputForGPUWithSamplingMetadata"]: - - retrieve_status = self.lmcache_should_retrieve(model_input) - model_input, bypass_model_exec, hidden_or_intermediate_states =\ - self.lmcache_retrieve_kv( - model_executable, model_input, self.cache_config, kv_caches, - retrieve_status) - return hidden_or_intermediate_states, bypass_model_exec, model_input - - def send_kv_caches_and_hidden_states( - self, - model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor], - hidden_or_intermediate_states: Union[torch.Tensor, - IntermediateTensors], - ) -> None: - - store_status = self.lmcache_should_store(model_input) - self.lmcache_store_kv( - self.model_config, - self.parallel_config, - self.cache_config, - model_executable, - model_input, - kv_caches, - store_status, - ) - - def close(self): - self.lmcache_engine_builder.destroy(self.lmcache_engine_name) diff --git a/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py b/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py deleted file mode 100644 index 94a7ce91acf17..0000000000000 --- a/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py +++ /dev/null @@ -1,203 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -MooncakeStore Connector for Distributed Machine Learning Inference -The MooncakeStoreConnector transfers KV caches between prefill vLLM workers -(KV cache producer) and decode vLLM workers (KV cache consumer) using a -database-style KVStore. 
-""" -import hashlib -from typing import TYPE_CHECKING, Union - -import torch - -from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -from vllm.distributed.kv_transfer.kv_connector.utils import ( - model_aware_kv_ops_helper as kv_helper) -from vllm.logger import init_logger -from vllm.sequence import IntermediateTensors - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - -logger = init_logger(__name__) - - -class MooncakeStoreConnector(KVConnectorBase): - - def __init__( - self, - rank: int, - local_rank: int, - config: VllmConfig, - ): - self.kv_transfer_config = config.kv_transfer_config - self.kv_helper = kv_helper(config) - self.local_tp_rank = local_rank - - # Init kv_store - if self.kv_transfer_config.kv_connector == "MooncakeStoreConnector": - # Check if MOONCAKE_CONFIG_PATH is set - import os - use_mooncake_store = os.getenv('MOONCAKE_CONFIG_PATH') is not None - - if not use_mooncake_store: - raise ValueError( - "To use MooncakeStoreConnector, you need to pass the ENV: " - "'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.") - else: - from vllm.distributed.kv_transfer.kv_lookup_buffer.mooncake_store import ( # noqa: E501 - MooncakeStore) - logger.info( - "Initializing KVStoreConnector under kv_transfer_config %s", - self.kv_transfer_config) - self.kv_store = MooncakeStore(config) - else: - logger.error("Can not find %s", - self.kv_transfer_config.kv_connector) - - assert self.kv_store is not None - - def close(self) -> None: - """Close the buffer and release resources. - This method is responsible for cleaning up resources related to the - connector when it is no longer needed. - Raises: - NotImplementedError: This method must be implemented in subclasses. 
- """ - self.kv_store.close() - - def send_kv_caches_and_hidden_states( - self, - model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor], - hidden_or_intermediate_states: Union[torch.Tensor, - IntermediateTensors], - ) -> None: - input_tokens_tensor = model_input.input_tokens - seq_lens = model_input.attn_metadata.seq_lens - slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() - start_layer = model_executable.model.start_layer - end_layer = model_executable.model.end_layer - num_heads, head_size = self.kv_helper.get_model_args(model_executable) - - for idx, slen in enumerate(seq_lens): - start_pos = sum(seq_lens[:idx]) - end_pos = start_pos + slen - - current_tokens = input_tokens_tensor[start_pos:end_pos] - store_key_prefix = self.tensor_hash(current_tokens) - keys, values = [], [] - - for layer_id in range(start_layer, end_layer): - kv_cache = kv_caches[layer_id - start_layer] - key_cache, value_cache = self.kv_helper.get_kv_from_cache( - kv_cache, num_heads, head_size) - current_slot_mapping = slot_mapping_flat[start_pos:end_pos] - - keys.append(key_cache[current_slot_mapping].unsqueeze(0)) - values.append(value_cache[current_slot_mapping].unsqueeze(0)) - - keys = torch.cat(keys, dim=0) - values = torch.cat(values, dim=0) - kvcache_to_sent = torch.stack((keys, values), dim=0) - store_kvcache_key = f"{store_key_prefix}_{self.local_tp_rank}" - self.kv_store.put(store_kvcache_key, kvcache_to_sent) - - hidden_key = f"{store_key_prefix}_hidden_{self.local_tp_rank}" - self.kv_store.put(hidden_key, - hidden_or_intermediate_states[start_pos:end_pos]) - - logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) - - def recv_kv_caches_and_hidden_states( - self, model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor] - ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, - "ModelInputForGPUWithSamplingMetadata"]: - bypass_model_exec = True - input_tokens_tensor = model_input.input_tokens - seq_lens = model_input.attn_metadata.seq_lens - num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens - slot_mapping = model_input.attn_metadata.slot_mapping.flatten() - start_layer = model_executable.model.start_layer - end_layer = model_executable.model.end_layer - hidden_or_intermediate_states_for_one_req = [] - - for idx, slen in enumerate(seq_lens): - start_pos = sum(seq_lens[:idx]) - end_pos = start_pos + slen - - if start_pos >= num_prefill_tokens: - # This can happen during inflight batching. See: - # vllm/worker/model_runner.py::_prepare_model_input_tensors: - # - input_tokens[:num_prefill_tokens] contains prefill tokens. - # - input_tokens[num_prefill_tokens:] contains decode tokens. - logger.warning("You should set --enable_chunked_prefill=False " - "and --max_num_batched_tokens " - "should be equal to max_seq_len_to_capture") - bypass_model_exec = False - assert start_pos == num_prefill_tokens - break - - current_tokens = input_tokens_tensor[start_pos:end_pos] - - # get roi for current seq - load_key_prefix = self.tensor_hash(current_tokens) - load_kvcache_key = f"{load_key_prefix}_{self.local_tp_rank}" - remote_kv = self.kv_store.get(load_kvcache_key) - hidden_key = f"{load_key_prefix}_hidden_{self.local_tp_rank}" - hidden = self.kv_store.get(hidden_key) - - if remote_kv is None or hidden is None: - # didn't find any match. 
- bypass_model_exec = False - continue - - num_computed_tokens = current_tokens.shape[0] - - # update the end position based on how many tokens are cached. - end_pos = start_pos + num_computed_tokens - - # call self.kv_store to get kv layer by layer - for layer_id in range(start_layer, end_layer): - layer = model_executable.model.layers[layer_id] - # get kvcache object - kv_cache = kv_caches[layer_id - start_layer] - - # get remote kvcache - remote_k, remote_v = remote_kv[0][layer_id], remote_kv[1][ - layer_id] - - self.kv_helper.put_kv_to_cache(model_executable, remote_k, - remote_v, layer, kv_cache, - slot_mapping, start_pos, - end_pos) - - hidden_or_intermediate_states_for_one_req.append(hidden) - - if not bypass_model_exec: - logger.warning( - "[rank%d]: Failed to receive all KVs and hidden " - "states, redo model forwarding.", torch.distributed.get_rank()) - hidden_or_intermediate_states = None - - else: - logger.debug( - "[rank%d]: Successfully received all KVs and hidden " - "states, skip model forwarding.", torch.distributed.get_rank()) - hidden_or_intermediate_states = torch.cat( - hidden_or_intermediate_states_for_one_req, dim=0) - - return hidden_or_intermediate_states, bypass_model_exec, model_input - - @staticmethod - def tensor_hash(tensor: torch.Tensor) -> int: - """Calculate the hash value of the tensor.""" - tensor_bytes = tensor.clone().detach().cpu().numpy().tobytes() - hash_object = hashlib.blake2b(tensor_bytes) - hash_hex = hash_object.hexdigest() - return int(hash_hex[:16], 16) diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py deleted file mode 100644 index e7c079e1f115c..0000000000000 --- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py +++ /dev/null @@ -1,329 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Simple KV Cache Connector for Distributed Machine Learning Inference - -The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache -producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or -MooncakePipe. - -But the logic can be extended to support other pipe and lookup buffer. 
-""" -from typing import TYPE_CHECKING, Optional, Union - -import torch - -from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -from vllm.distributed.kv_transfer.kv_connector.utils import ( - model_aware_kv_ops_helper as kv_helper) -from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( - SimpleBuffer) -from vllm.logger import init_logger -from vllm.sequence import IntermediateTensors - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - -logger = init_logger(__name__) - - -class SimpleConnector(KVConnectorBase): - - def __init__( - self, - rank: int, - local_rank: int, - config: VllmConfig, - ): - - self.config = config.kv_transfer_config - self.kv_helper = kv_helper(config) - - if self.config.kv_connector == "PyNcclConnector": - from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import ( - PyNcclPipe) - logger.info( - "Initializing PyNcclConfig under kv_transfer_config %s", - self.config) - elif self.config.kv_connector == "MooncakeConnector": - # Check if MOONCAKE_CONFIG_PATH is set - import os - use_mooncake_distributed_pipe = os.getenv( - 'MOONCAKE_CONFIG_PATH') is not None - - if not use_mooncake_distributed_pipe: - raise ValueError( - "To use MooncakeConnector, you need to pass the ENV: " - "'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.") - else: - from vllm.distributed.kv_transfer.kv_pipe.mooncake_pipe import ( # noqa: E501 - MooncakePipe) - logger.info( - "Initializing MooncakeConfig under kv_transfer_config %s", - self.config) - - self.lookup_buffer_size = self.config.kv_buffer_size - - self.producer_buffer: Optional[SimpleBuffer] = None - self.consumer_buffer: Optional[SimpleBuffer] = None - - self.producer_data_pipe: Union[PyNcclPipe, MooncakePipe] - self.consumer_data_pipe: Union[PyNcclPipe, MooncakePipe] - self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe] - self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe] - - # 2 pipes for every rank in the world - port_offset_base = 2 * rank - - # In disaggregated prefill, the prefill vLLM only uses send pipe - # and the decode vLLM only uses recv pipe - if self.config.is_kv_producer: - - if self.config.kv_connector == "PyNcclConnector": - self.producer_data_pipe = PyNcclPipe( - local_rank=local_rank, - config=self.config, - port_offset=port_offset_base, - ) - self.producer_signal_pipe = PyNcclPipe( - local_rank=local_rank, - config=self.config, - port_offset=port_offset_base + 1, - device="cpu", - ) - elif self.config.kv_connector == "MooncakeConnector": - self.producer_data_pipe = MooncakePipe( - local_rank=local_rank, - config=self.config, - ) - # We only need to initialize MooncakePipe once - self.producer_signal_pipe = self.producer_data_pipe - - self.producer_buffer = SimpleBuffer(self.producer_signal_pipe, - self.producer_data_pipe, - self.config.kv_buffer_size) - - else: - - # the current vLLM instance is KV consumer, so it needs to connect - # its recv pipe to the send pipe of KV producer - if self.config.kv_connector == "PyNcclConnector": - self.consumer_data_pipe = PyNcclPipe( - local_rank=local_rank, - config=self.config, - port_offset=port_offset_base, - ) - self.consumer_signal_pipe = PyNcclPipe( - local_rank=local_rank, - config=self.config, - port_offset=port_offset_base + 1, - device="cpu", - ) - elif self.config.kv_connector == "MooncakeConnector": - self.consumer_data_pipe = MooncakePipe( - local_rank=local_rank, - config=self.config, - ) - self.consumer_signal_pipe = 
self.consumer_data_pipe - - self.consumer_buffer = SimpleBuffer( - self.consumer_signal_pipe, - self.consumer_data_pipe, - self.config.kv_buffer_size, - ) - - def select(self, input_tokens: Optional[torch.Tensor], - roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]: - - assert self.consumer_buffer is not None, "Please initialize the "\ - "consumer buffer before calling select." - return self.consumer_buffer.drop_select(input_tokens, roi) - - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - hidden: torch.Tensor) -> None: - - assert self.producer_buffer is not None, "Please initialize the "\ - "producer buffer before calling insert." - - self.producer_buffer.insert(input_tokens, roi, key, value, hidden) - - def send_kv_caches_and_hidden_states( - self, - model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor], - hidden_or_intermediate_states: Union[torch.Tensor, - IntermediateTensors], - ) -> None: - - input_tokens_tensor = model_input.input_tokens - seq_lens = model_input.attn_metadata.seq_lens - slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() - num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens - start_layer = model_executable.model.start_layer - end_layer = model_executable.model.end_layer - num_heads, head_size = self.kv_helper.get_model_args(model_executable) - - # query_lens contains new KV caches that are added to vLLM. - # so we will send them to decode instance - # FIXME(Kuntai): This assume that all requests are prefill. - for idx, slen in enumerate(seq_lens): - start_pos = sum(seq_lens[:idx]) - end_pos = start_pos + slen - - if start_pos >= num_prefill_tokens: - # vllm/worker/model_runner.py::_prepare_model_input_tensors: - # - input_tokens[:num_prefill_tokens] contains prefill tokens. - # - input_tokens[num_prefill_tokens:] contains decode tokens. - logger.warning("You have some decode requests while using " - "SimpleConnector. Their KVCache won't be sent.") - break - - current_tokens = input_tokens_tensor[start_pos:end_pos] - - keys, values = [], [] - - for layer_id in range(start_layer, end_layer): - kv_cache = kv_caches[layer_id - start_layer] - key_cache, value_cache = self.kv_helper.get_kv_from_cache( - kv_cache, num_heads, head_size) - - current_slot_mapping = slot_mapping_flat[start_pos:end_pos] - - keys.append(key_cache[current_slot_mapping].unsqueeze(0)) - values.append(value_cache[current_slot_mapping].unsqueeze(0)) - - keys = torch.cat(keys, dim=0) - values = torch.cat(values, dim=0) - - self.insert(current_tokens, - torch.ones_like(current_tokens, - dtype=bool), keys, values, - hidden_or_intermediate_states[start_pos:end_pos]) - - logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) - - def recv_kv_caches_and_hidden_states( - self, model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor] - ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, - "ModelInputForGPUWithSamplingMetadata"]: - - # When bypass_model_exec is set to False, it means that at least for one - # request its corresponding KV cache or hidden state is missing. - # In this case we need to do prefilling to recompute missing KV cache - # and hidden states. 
- bypass_model_exec = True - - input_tokens_tensor = model_input.input_tokens - seq_lens = model_input.attn_metadata.seq_lens - num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens - slot_mapping = model_input.attn_metadata.slot_mapping.flatten() - start_layer = model_executable.model.start_layer - end_layer = model_executable.model.end_layer - - hidden_or_intermediate_states_for_one_req = [] - - input_tokens_list = [] - num_computed_tokens_list = [] - start_pos_list = [] - - # enumerate different requests - # FIXME(Kuntai): This impl assumes that all requests are prefill. - for idx, slen in enumerate(seq_lens): - start_pos = sum(seq_lens[:idx]) - end_pos = start_pos + slen - - if start_pos >= num_prefill_tokens: - # This can happen during inflight batching. See: - # vllm/worker/model_runner.py::_prepare_model_input_tensors: - # - input_tokens[:num_prefill_tokens] contains prefill tokens. - # - input_tokens[num_prefill_tokens:] contains decode tokens. - logger.warning("You should set --enable_chunked_prefill=False " - "and --max_num_batched_tokens " - "should be equal to --max_seq_len_to_capture") - bypass_model_exec = False - assert start_pos == num_prefill_tokens - break - - current_tokens = input_tokens_tensor[start_pos:end_pos] - num_tokens = slen - - # collecting data for rebuilding the input - input_tokens_list.append(current_tokens) - start_pos_list.append(start_pos) - - ret = self.select(current_tokens, - torch.ones_like(current_tokens, dtype=bool)) - if ret[0] is None: - # didn't find any match. - bypass_model_exec = False - num_computed_tokens_list.append(0) - continue - - roi: torch.Tensor = ret[1] - keys: torch.Tensor = ret[2] - values: torch.Tensor = ret[3] - hidden: torch.Tensor = ret[4] - - num_computed_tokens = roi.shape[0] - num_computed_tokens_list.append(num_computed_tokens) - - # check if both KV cache and the hidden states are received - # If not, need to redo the forwarding to compute missing states - if not all([(num_computed_tokens == num_tokens), hidden is not None - ]): - bypass_model_exec = False - - # update the end position based on how many tokens are cached. - end_pos = start_pos + num_computed_tokens - - # put received KV caches into paged memory - for cur_layer in range(start_layer, end_layer): - - layer_id = cur_layer - start_layer - kv_cache = kv_caches[layer_id] - layer = model_executable.model.layers[cur_layer] - - # get remote kvcache - remote_k, remote_v = keys[layer_id], values[layer_id] - - self.kv_helper.put_kv_to_cache(model_executable, remote_k, - remote_v, layer, kv_cache, - slot_mapping, start_pos, - end_pos) - - hidden_or_intermediate_states_for_one_req.append(hidden) - - if not bypass_model_exec: - # Some of the KV cache is not retrieved - # Here we will fall back to normal model forwarding - # But optionally you can adjust model_input so that you only do - # prefilling on those tokens that are missing KV caches. 
- logger.warning( - "[rank%d]: Failed to receive all KVs and hidden " - "states, redo model forwarding.", torch.distributed.get_rank()) - hidden_or_intermediate_states = None - - else: - logger.debug( - "[rank%d]: Successfully received all KVs and hidden " - "states, skip model forwarding.", torch.distributed.get_rank()) - hidden_or_intermediate_states = torch.cat( - hidden_or_intermediate_states_for_one_req, dim=0) - - return hidden_or_intermediate_states, bypass_model_exec, model_input - - def close(self): - self.producer_data_pipe.close() - self.consumer_data_pipe.close() - if self.config.kv_connector == "PyNcclConnector": - self.producer_signal_pipe.close() - self.consumer_signal_pipe.close() - elif self.config.kv_connector == "MooncakeConnector": - # MooncakePipe reuses data_pipe for signal_pipe, so we only have to - # close the data_pipe. - pass diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 1a11cb6d0189a..1da41790f9fb1 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -13,8 +13,8 @@ import torch import vllm.envs as envs from vllm import _custom_ops as ops from vllm.config import VllmConfig, get_current_vllm_config -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1) from vllm.logger import init_logger from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput @@ -106,9 +106,8 @@ def get_kv_connector_cache_layout(): vllm_config = get_current_vllm_config() kv_config = vllm_config.kv_transfer_config if kv_config is not None: - connector_cls = KVConnectorFactory.get_connector_class(kv_config) - required_kvcache_layout = connector_cls.get_required_kvcache_layout( - vllm_config) + required_kvcache_layout = ( + KVConnectorBase_V1.get_required_kvcache_layout(vllm_config)) if required_kvcache_layout is not None: return required_kvcache_layout logger.info_once("Connectors do not specify a " \ diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 934a03a12ee5e..62a4980bff975 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -52,7 +52,7 @@ class MultiConnector(KVConnectorBase_V1): temp_config.kv_transfer_config = KVTransferConfig( **ktc, engine_id=engine_id) self._connectors.append( - KVConnectorFactory.create_connector_v1(temp_config, role)) + KVConnectorFactory.create_connector(temp_config, role)) # A mapping from request id to the index of the connector chosen to # load the request from (if any). 
@@ -223,9 +223,9 @@ class MultiConnector(KVConnectorBase_V1): for ktc in ktcs: kv_transfer_config = KVTransferConfig(**ktc) temp_vllm_config.kv_transfer_config = kv_transfer_config - required_kvcache_layout = KVConnectorFactory.get_connector_class( - kv_transfer_config).get_required_kvcache_layout( - temp_vllm_config) + required_kvcache_layout = ( + KVConnectorBase_V1.get_required_kvcache_layout( + temp_vllm_config)) if required_kvcache_layout is not None: layouts.add(required_kvcache_layout) diff --git a/vllm/distributed/kv_transfer/kv_connector_agent.py b/vllm/distributed/kv_transfer/kv_connector_agent.py deleted file mode 100644 index 8633fdaf59f8b..0000000000000 --- a/vllm/distributed/kv_transfer/kv_connector_agent.py +++ /dev/null @@ -1,77 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""A centralized entrypoint to perform distributed KV cache transfer. - -This implementation is a shim wrapper on two APIs exposed by `kv_connector`: -1. `send_kv_caches_and_hidden_states` -2. `recv_kv_caches_and_hidden_states -""" -from typing import TYPE_CHECKING, Union - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - from vllm.config import VllmConfig - -import torch - -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) -from vllm.logger import init_logger -from vllm.sequence import IntermediateTensors - -logger = init_logger(__name__) - - -class KVTransferAgent: - """ - A class designated for distributed KV transfer - - Target use cases: - 1. Disaggregated prefill - 2. Remote KV cache storage - """ - - def __init__( - self, - rank: int, - local_rank: int, - config: "VllmConfig", - ): - - self.config = config - - if config.kv_transfer_config is None: - raise ValueError("KVTransferConfig is not set in the VllmConfig," - " cannot initialize KVConnector.") - - assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\ - "TransferAgent should only be used when kv_connector is set." 
- - self.connector = KVConnectorFactory.create_connector_v0( - rank, local_rank, config) - - def send_kv_caches_and_hidden_states( - self, - model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor], - hidden_or_intermediate_states: Union[torch.Tensor, - IntermediateTensors], - ) -> None: - - self.connector.send_kv_caches_and_hidden_states( - model_executable, model_input, kv_caches, - hidden_or_intermediate_states) - - def close(self) -> None: - self.connector.close() - - def recv_kv_caches_and_hidden_states( - self, model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor] - ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, - "ModelInputForGPUWithSamplingMetadata"]: - - return self.connector.recv_kv_caches_and_hidden_states( - model_executable, model_input, kv_caches) diff --git a/vllm/distributed/kv_transfer/kv_transfer_state.py b/vllm/distributed/kv_transfer/kv_transfer_state.py index 60f1d5d8bca75..5e0f64fca220c 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_state.py +++ b/vllm/distributed/kv_transfer/kv_transfer_state.py @@ -8,7 +8,6 @@ from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, KVConnectorRole) -from vllm.distributed.parallel_state import get_world_group if TYPE_CHECKING: from vllm.config import VllmConfig @@ -61,11 +60,7 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: if (vllm_config.kv_transfer_config.is_kv_transfer_instance and _KV_CONNECTOR_AGENT is None): if envs.VLLM_USE_V1: - _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1( + _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector( config=vllm_config, role=KVConnectorRole.WORKER) else: - _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0( - rank=get_world_group().rank, - local_rank=get_world_group().local_rank, - config=vllm_config, - ) + raise ValueError("V0 is no longer supported") diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 49a744cfec69a..d39aea1f2d116 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -83,7 +83,7 @@ class Scheduler(SchedulerInterface): assert len(self.kv_cache_config.kv_cache_groups) == 1, ( "Multiple KV cache groups are not currently supported " "with KV connectors") - self.connector = KVConnectorFactory.create_connector_v1( + self.connector = KVConnectorFactory.create_connector( config=self.vllm_config, role=KVConnectorRole.SCHEDULER) self.kv_event_publisher = EventPublisherFactory.create( diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index 343befe176797..a03ebe35d8e0a 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Optional from vllm.config import VllmConfig from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) -from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase from vllm.forward_context import get_forward_context, set_forward_context from vllm.logger import init_logger from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput, @@ -31,7 +31,7 @@ class KVConnectorModelRunnerMixin: # 
Update KVConnector with the KVConnector metadata forward(). if has_kv_transfer_group(): kv_connector = get_kv_transfer_group() - assert isinstance(kv_connector, KVConnectorBase_V1) + assert isinstance(kv_connector, KVConnectorBase) assert scheduler_output.kv_connector_metadata is not None kv_connector.bind_connector_metadata( scheduler_output.kv_connector_metadata) @@ -93,7 +93,7 @@ class KVConnectorModelRunnerMixin: # Update KVConnector with the KVConnector metadata forward(). kv_connector = get_kv_transfer_group() - assert isinstance(kv_connector, KVConnectorBase_V1) + assert isinstance(kv_connector, KVConnectorBase) assert scheduler_output.kv_connector_metadata is not None kv_connector.bind_connector_metadata( scheduler_output.kv_connector_metadata)
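
A note on the consolidated API this patch leaves behind: `KVConnectorBase` is now an alias for `KVConnectorBase_V1`, and `KVConnectorFactory.create_connector_v1` is renamed to `create_connector`, which raises if `VLLM_USE_V1` is disabled. The sketch below is illustrative only and mirrors the updated call sites in `vllm/v1/core/sched/scheduler.py` and `vllm/distributed/kv_transfer/kv_transfer_state.py`; the helper name `build_connectors` is hypothetical, and it assumes a `VllmConfig` whose `kv_transfer_config` names a registered V1 connector (for example `SharedStorageConnector`).

```python
# Minimal sketch (not part of this patch): obtaining both halves of a
# V1 KV connector through the single remaining factory entry point.
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.factory import (
    KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole


def build_connectors(vllm_config) -> tuple[KVConnectorBase, KVConnectorBase]:
    """Create the scheduler-side and worker-side connector instances.

    Assumes VLLM_USE_V1 is enabled and vllm_config.kv_transfer_config is
    set; create_connector() raises ValueError otherwise.
    """
    # Scheduler-side instance, as created in vllm/v1/core/sched/scheduler.py.
    scheduler_connector = KVConnectorFactory.create_connector(
        config=vllm_config, role=KVConnectorRole.SCHEDULER)
    # Worker-side instance, as created in
    # kv_transfer_state.ensure_kv_transfer_initialized().
    worker_connector = KVConnectorFactory.create_connector(
        config=vllm_config, role=KVConnectorRole.WORKER)
    return scheduler_connector, worker_connector
```

The two-role split is deliberate: as the factory's NOTE comment says, a V1 connector is explicitly separated into a scheduler role and a worker role, so both call sites go through the same `create_connector` entry point and differ only in the `KVConnectorRole` they pass.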