mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-23 06:17:02 +08:00
Move KVTransferConfig from config/__init__.py to config/kv_transfer.py (#24434)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
562663a044
commit
3e0d4a3475
@ -9,7 +9,6 @@ import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import textwrap
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import Mapping
|
||||
from contextlib import contextmanager
|
||||
@ -34,6 +33,7 @@ from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
|
||||
from vllm.config.compilation import (CompilationConfig, CompilationLevel,
|
||||
CUDAGraphMode, PassConfig)
|
||||
from vllm.config.kv_events import KVEventsConfig
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
|
||||
ParallelConfig)
|
||||
from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
|
||||
@ -3210,107 +3210,6 @@ class ObservabilityConfig:
|
||||
self.collect_detailed_traces[0].split(","))
|
||||
|
||||
|
||||
KVProducer = Literal["kv_producer", "kv_both"]
|
||||
KVConsumer = Literal["kv_consumer", "kv_both"]
|
||||
KVRole = Literal[KVProducer, KVConsumer]
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class KVTransferConfig:
|
||||
"""Configuration for distributed KV cache transfer."""
|
||||
|
||||
kv_connector: Optional[str] = None
|
||||
"""The KV connector for vLLM to transmit KV caches between vLLM instances.
|
||||
"""
|
||||
|
||||
engine_id: Optional[str] = None
|
||||
"""The engine id for KV transfers."""
|
||||
|
||||
kv_buffer_device: Optional[str] = "cuda"
|
||||
"""The device used by kv connector to buffer the KV cache.
|
||||
Currently only support 'cuda'."""
|
||||
|
||||
kv_buffer_size: float = 1e9
|
||||
"""The buffer size for TorchDistributedConnector. Measured in number of
|
||||
bytes. Recommended value: 1e9 (about 1GB)."""
|
||||
|
||||
kv_role: Optional[KVRole] = None
|
||||
"""Whether this vLLM instance produces, consumes KV cache, or both. Choices
|
||||
are 'kv_producer', 'kv_consumer', and 'kv_both'."""
|
||||
|
||||
kv_rank: Optional[int] = None
|
||||
"""The rank of this vLLM instance in the KV cache transfer. Typical value:
|
||||
0 for prefill instance, 1 for decode instance.
|
||||
Currently only 1P1D is supported."""
|
||||
|
||||
kv_parallel_size: int = 1
|
||||
"""The number of parallel instances for KV cache transfer. For
|
||||
P2pNcclConnector, this should be 2."""
|
||||
|
||||
kv_ip: str = "127.0.0.1"
|
||||
"""The KV connector ip, used to build distributed connection."""
|
||||
|
||||
kv_port: int = 14579
|
||||
"""The KV connector port, used to build distributed connection."""
|
||||
|
||||
kv_connector_extra_config: dict[str, Any] = field(default_factory=dict)
|
||||
"""any extra config that the connector may need."""
|
||||
|
||||
kv_connector_module_path: Optional[str] = None
|
||||
"""The Python module path to dynamically load the KV connector from.
|
||||
Only supported in V1."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(),
|
||||
usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.engine_id is None:
|
||||
self.engine_id = str(uuid.uuid4())
|
||||
|
||||
if self.kv_role is not None and self.kv_role not in get_args(KVRole):
|
||||
raise ValueError(f"Unsupported kv_role: {self.kv_role}. "
|
||||
f"Supported roles are {get_args(KVRole)}")
|
||||
|
||||
if self.kv_connector is not None and self.kv_role is None:
|
||||
raise ValueError("Please specify kv_disagg_role when kv_connector "
|
||||
f"is set, supported roles are {get_args(KVRole)}")
|
||||
|
||||
@property
|
||||
def is_kv_transfer_instance(self) -> bool:
|
||||
return self.kv_connector is not None and \
|
||||
self.kv_role in get_args(KVRole)
|
||||
|
||||
@property
|
||||
def is_kv_producer(self) -> bool:
|
||||
return self.kv_connector is not None and \
|
||||
self.kv_role in get_args(KVProducer)
|
||||
|
||||
@property
|
||||
def is_kv_consumer(self) -> bool:
|
||||
return self.kv_connector is not None and \
|
||||
self.kv_role in get_args(KVConsumer)
|
||||
|
||||
def get_from_extra_config(self, key, default) -> Any:
|
||||
return self.kv_connector_extra_config.get(key, default)
|
||||
|
||||
|
||||
@config
|
||||
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
||||
class VllmConfig:
|
||||
|
||||
111
vllm/config/kv_transfer.py
Normal file
111
vllm/config/kv_transfer.py
Normal file
@ -0,0 +1,111 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
import uuid
|
||||
from dataclasses import field
|
||||
from typing import Any, Literal, Optional, get_args
|
||||
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
|
||||
KVProducer = Literal["kv_producer", "kv_both"]
|
||||
KVConsumer = Literal["kv_consumer", "kv_both"]
|
||||
KVRole = Literal[KVProducer, KVConsumer]
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class KVTransferConfig:
|
||||
"""Configuration for distributed KV cache transfer."""
|
||||
|
||||
kv_connector: Optional[str] = None
|
||||
"""The KV connector for vLLM to transmit KV caches between vLLM instances.
|
||||
"""
|
||||
|
||||
engine_id: Optional[str] = None
|
||||
"""The engine id for KV transfers."""
|
||||
|
||||
kv_buffer_device: Optional[str] = "cuda"
|
||||
"""The device used by kv connector to buffer the KV cache.
|
||||
Currently only support 'cuda'."""
|
||||
|
||||
kv_buffer_size: float = 1e9
|
||||
"""The buffer size for TorchDistributedConnector. Measured in number of
|
||||
bytes. Recommended value: 1e9 (about 1GB)."""
|
||||
|
||||
kv_role: Optional[KVRole] = None
|
||||
"""Whether this vLLM instance produces, consumes KV cache, or both. Choices
|
||||
are 'kv_producer', 'kv_consumer', and 'kv_both'."""
|
||||
|
||||
kv_rank: Optional[int] = None
|
||||
"""The rank of this vLLM instance in the KV cache transfer. Typical value:
|
||||
0 for prefill instance, 1 for decode instance.
|
||||
Currently only 1P1D is supported."""
|
||||
|
||||
kv_parallel_size: int = 1
|
||||
"""The number of parallel instances for KV cache transfer. For
|
||||
P2pNcclConnector, this should be 2."""
|
||||
|
||||
kv_ip: str = "127.0.0.1"
|
||||
"""The KV connector ip, used to build distributed connection."""
|
||||
|
||||
kv_port: int = 14579
|
||||
"""The KV connector port, used to build distributed connection."""
|
||||
|
||||
kv_connector_extra_config: dict[str, Any] = field(default_factory=dict)
|
||||
"""any extra config that the connector may need."""
|
||||
|
||||
kv_connector_module_path: Optional[str] = None
|
||||
"""The Python module path to dynamically load the KV connector from.
|
||||
Only supported in V1."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(),
|
||||
usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.engine_id is None:
|
||||
self.engine_id = str(uuid.uuid4())
|
||||
|
||||
if self.kv_role is not None and self.kv_role not in get_args(KVRole):
|
||||
raise ValueError(f"Unsupported kv_role: {self.kv_role}. "
|
||||
f"Supported roles are {get_args(KVRole)}")
|
||||
|
||||
if self.kv_connector is not None and self.kv_role is None:
|
||||
raise ValueError("Please specify kv_disagg_role when kv_connector "
|
||||
f"is set, supported roles are {get_args(KVRole)}")
|
||||
|
||||
@property
|
||||
def is_kv_transfer_instance(self) -> bool:
|
||||
return self.kv_connector is not None and \
|
||||
self.kv_role in get_args(KVRole)
|
||||
|
||||
@property
|
||||
def is_kv_producer(self) -> bool:
|
||||
return self.kv_connector is not None and \
|
||||
self.kv_role in get_args(KVProducer)
|
||||
|
||||
@property
|
||||
def is_kv_consumer(self) -> bool:
|
||||
return self.kv_connector is not None and \
|
||||
self.kv_role in get_args(KVConsumer)
|
||||
|
||||
def get_from_extra_config(self, key, default) -> Any:
|
||||
return self.kv_connector_extra_config.get(key, default)
|
||||
@ -14,7 +14,8 @@ from vllm.logger import init_logger
|
||||
# yapf: enable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import KVTransferConfig, VllmConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -7,7 +7,8 @@ from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import KVTransferConfig, VllmConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||
KVConnectorFactory)
|
||||
|
||||
@ -15,7 +15,7 @@ import msgpack
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.distributed.device_communicators.pynccl_wrapper import (
|
||||
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501
|
||||
|
||||
@ -13,7 +13,7 @@ import zmq
|
||||
from safetensors.torch import load as safetensors_load
|
||||
from safetensors.torch import save as safetensors_save
|
||||
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import join_host_port, make_zmq_path, split_host_port
|
||||
|
||||
@ -20,7 +20,7 @@ from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
|
||||
@ -204,7 +204,7 @@ class LLM:
|
||||
|
||||
if "kv_transfer_config" in kwargs and isinstance(
|
||||
kwargs["kv_transfer_config"], dict):
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
raw_config_dict = kwargs["kv_transfer_config"]
|
||||
try:
|
||||
kwargs["kv_transfer_config"] = KVTransferConfig(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user