From 3e0d4a34759333ad59069336a3691d8cecc8e0b9 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 9 Sep 2025 04:30:32 +0100 Subject: [PATCH] Move `KVTransferConfig` from `config/__init__.py` to `config/kv_transfer.py` (#24434) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config/__init__.py | 103 +--------------- vllm/config/kv_transfer.py | 111 ++++++++++++++++++ .../kv_transfer/kv_connector/factory.py | 3 +- .../kv_connector/v1/multi_connector.py | 3 +- .../kv_connector/v1/p2p/p2p_nccl_engine.py | 2 +- .../kv_transfer/kv_pipe/mooncake_pipe.py | 2 +- .../kv_transfer/kv_pipe/pynccl_pipe.py | 2 +- vllm/entrypoints/llm.py | 2 +- 8 files changed, 120 insertions(+), 108 deletions(-) create mode 100644 vllm/config/kv_transfer.py diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index e3ce1987fe9ee..e84b924ada0a7 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -9,7 +9,6 @@ import hashlib import inspect import json import textwrap -import uuid import warnings from collections.abc import Mapping from contextlib import contextmanager @@ -34,6 +33,7 @@ from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType, from vllm.config.compilation import (CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig) from vllm.config.kv_events import KVEventsConfig +from vllm.config.kv_transfer import KVTransferConfig from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig, ParallelConfig) from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy @@ -3210,107 +3210,6 @@ class ObservabilityConfig: self.collect_detailed_traces[0].split(",")) -KVProducer = Literal["kv_producer", "kv_both"] -KVConsumer = Literal["kv_consumer", "kv_both"] -KVRole = Literal[KVProducer, KVConsumer] - - -@config -@dataclass -class KVTransferConfig: - """Configuration for distributed KV cache transfer.""" - - kv_connector: Optional[str] = None - """The KV connector for vLLM to transmit KV caches between vLLM instances. - """ - - engine_id: Optional[str] = None - """The engine id for KV transfers.""" - - kv_buffer_device: Optional[str] = "cuda" - """The device used by kv connector to buffer the KV cache. - Currently only support 'cuda'.""" - - kv_buffer_size: float = 1e9 - """The buffer size for TorchDistributedConnector. Measured in number of - bytes. Recommended value: 1e9 (about 1GB).""" - - kv_role: Optional[KVRole] = None - """Whether this vLLM instance produces, consumes KV cache, or both. Choices - are 'kv_producer', 'kv_consumer', and 'kv_both'.""" - - kv_rank: Optional[int] = None - """The rank of this vLLM instance in the KV cache transfer. Typical value: - 0 for prefill instance, 1 for decode instance. - Currently only 1P1D is supported.""" - - kv_parallel_size: int = 1 - """The number of parallel instances for KV cache transfer. For - P2pNcclConnector, this should be 2.""" - - kv_ip: str = "127.0.0.1" - """The KV connector ip, used to build distributed connection.""" - - kv_port: int = 14579 - """The KV connector port, used to build distributed connection.""" - - kv_connector_extra_config: dict[str, Any] = field(default_factory=dict) - """any extra config that the connector may need.""" - - kv_connector_module_path: Optional[str] = None - """The Python module path to dynamically load the KV connector from. - Only supported in V1.""" - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - # no factors to consider. - # this config will not affect the computation graph. - factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def __post_init__(self) -> None: - if self.engine_id is None: - self.engine_id = str(uuid.uuid4()) - - if self.kv_role is not None and self.kv_role not in get_args(KVRole): - raise ValueError(f"Unsupported kv_role: {self.kv_role}. " - f"Supported roles are {get_args(KVRole)}") - - if self.kv_connector is not None and self.kv_role is None: - raise ValueError("Please specify kv_disagg_role when kv_connector " - f"is set, supported roles are {get_args(KVRole)}") - - @property - def is_kv_transfer_instance(self) -> bool: - return self.kv_connector is not None and \ - self.kv_role in get_args(KVRole) - - @property - def is_kv_producer(self) -> bool: - return self.kv_connector is not None and \ - self.kv_role in get_args(KVProducer) - - @property - def is_kv_consumer(self) -> bool: - return self.kv_connector is not None and \ - self.kv_role in get_args(KVConsumer) - - def get_from_extra_config(self, key, default) -> Any: - return self.kv_connector_extra_config.get(key, default) - - @config @dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class VllmConfig: diff --git a/vllm/config/kv_transfer.py b/vllm/config/kv_transfer.py new file mode 100644 index 0000000000000..9abf4acacfe81 --- /dev/null +++ b/vllm/config/kv_transfer.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +import uuid +from dataclasses import field +from typing import Any, Literal, Optional, get_args + +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config + +KVProducer = Literal["kv_producer", "kv_both"] +KVConsumer = Literal["kv_consumer", "kv_both"] +KVRole = Literal[KVProducer, KVConsumer] + + +@config +@dataclass +class KVTransferConfig: + """Configuration for distributed KV cache transfer.""" + + kv_connector: Optional[str] = None + """The KV connector for vLLM to transmit KV caches between vLLM instances. + """ + + engine_id: Optional[str] = None + """The engine id for KV transfers.""" + + kv_buffer_device: Optional[str] = "cuda" + """The device used by kv connector to buffer the KV cache. + Currently only support 'cuda'.""" + + kv_buffer_size: float = 1e9 + """The buffer size for TorchDistributedConnector. Measured in number of + bytes. Recommended value: 1e9 (about 1GB).""" + + kv_role: Optional[KVRole] = None + """Whether this vLLM instance produces, consumes KV cache, or both. Choices + are 'kv_producer', 'kv_consumer', and 'kv_both'.""" + + kv_rank: Optional[int] = None + """The rank of this vLLM instance in the KV cache transfer. Typical value: + 0 for prefill instance, 1 for decode instance. + Currently only 1P1D is supported.""" + + kv_parallel_size: int = 1 + """The number of parallel instances for KV cache transfer. For + P2pNcclConnector, this should be 2.""" + + kv_ip: str = "127.0.0.1" + """The KV connector ip, used to build distributed connection.""" + + kv_port: int = 14579 + """The KV connector port, used to build distributed connection.""" + + kv_connector_extra_config: dict[str, Any] = field(default_factory=dict) + """any extra config that the connector may need.""" + + kv_connector_module_path: Optional[str] = None + """The Python module path to dynamically load the KV connector from. + Only supported in V1.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self) -> None: + if self.engine_id is None: + self.engine_id = str(uuid.uuid4()) + + if self.kv_role is not None and self.kv_role not in get_args(KVRole): + raise ValueError(f"Unsupported kv_role: {self.kv_role}. " + f"Supported roles are {get_args(KVRole)}") + + if self.kv_connector is not None and self.kv_role is None: + raise ValueError("Please specify kv_disagg_role when kv_connector " + f"is set, supported roles are {get_args(KVRole)}") + + @property + def is_kv_transfer_instance(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in get_args(KVRole) + + @property + def is_kv_producer(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in get_args(KVProducer) + + @property + def is_kv_consumer(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in get_args(KVConsumer) + + def get_from_extra_config(self, key, default) -> Any: + return self.kv_connector_extra_config.get(key, default) diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 584fc1d655951..670f9c26b2104 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -14,7 +14,8 @@ from vllm.logger import init_logger # yapf: enable if TYPE_CHECKING: - from vllm.config import KVTransferConfig, VllmConfig + from vllm.config import VllmConfig + from vllm.config.kv_transfer import KVTransferConfig logger = init_logger(__name__) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 65bcb4d93b1e1..edbff4e4340f4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -7,7 +7,8 @@ from typing import TYPE_CHECKING, Any, Optional import torch -from vllm.config import KVTransferConfig, VllmConfig +from vllm.config import VllmConfig +from vllm.config.kv_transfer import KVTransferConfig from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py index dfd95548c4632..fa7cc66ab654d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py @@ -15,7 +15,7 @@ import msgpack import torch import zmq -from vllm.config import KVTransferConfig +from vllm.config.kv_transfer import KVTransferConfig from vllm.distributed.device_communicators.pynccl_wrapper import ( NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum) from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501 diff --git a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py index 0b560d1b3b3ce..2a434e280179e 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py @@ -13,7 +13,7 @@ import zmq from safetensors.torch import load as safetensors_load from safetensors.torch import save as safetensors_save -from vllm.config import KVTransferConfig +from vllm.config.kv_transfer import KVTransferConfig from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from vllm.logger import init_logger from vllm.utils import join_host_port, make_zmq_path, split_host_port diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py index 09de0b682efca..66120e9a0a1a0 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py @@ -20,7 +20,7 @@ from typing import Callable, Optional import torch -from vllm.config import KVTransferConfig +from vllm.config.kv_transfer import KVTransferConfig from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from vllm.distributed.utils import StatelessProcessGroup diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index d33fd0ec0b494..e6fd61ae1aad9 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -204,7 +204,7 @@ class LLM: if "kv_transfer_config" in kwargs and isinstance( kwargs["kv_transfer_config"], dict): - from vllm.config import KVTransferConfig + from vllm.config.kv_transfer import KVTransferConfig raw_config_dict = kwargs["kv_transfer_config"] try: kwargs["kv_transfer_config"] = KVTransferConfig(