[NIXL][OOT platform] support nixl_connector with oot platform and other nixl_backend (#25121)

Signed-off-by: Chendi Xue <Chendi.Xue@intel.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>

parent 675fc471bf
commit 921945c81e
@@ -31,6 +31,12 @@ Now supports 5 types of connectors:
 --kv-transfer-config '{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both"},{"kv_connector":"SharedStorageConnector","kv_role":"kv_both","kv_connector_extra_config":{"shared_storage_path":"local_storage"}}]}}'
 ```

+For NixlConnector, you may also specify one or more NIXL backends, for example:
+
+```bash
+--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_buffer_device":"cuda","kv_connector_extra_config":{"backends":["UCX","GDS"]}}'
+```
+
 - **OffloadingConnector**: enable offloading of KV data to CPU memory, customizing the CPU block size (in tokens) and number of blocks to allocate (per worker):

 ```bash
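The same connector setup can also be expressed through the offline Python API. A minimal sketch, assuming `KVTransferConfig` is importable from `vllm.config` as in vLLM's offline disaggregated-prefill examples (the model name is illustrative):

```python
from vllm import LLM
from vllm.config import KVTransferConfig

# Equivalent of the CLI flag above: NixlConnector with two NIXL backends.
ktc = KVTransferConfig(
    kv_connector="NixlConnector",
    kv_role="kv_both",
    kv_buffer_device="cuda",
    kv_connector_extra_config={"backends": ["UCX", "GDS"]},
)
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", kv_transfer_config=ktc)
```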
@@ -193,7 +193,7 @@ For production deployments requiring strict SLA guarantees for time-to-first-token

 1. **Install gdrcopy/ucx/nixl**: For maximum performance, run the [install_gdrcopy.sh](gh-file:tools/install_gdrcopy.sh) script to install `gdrcopy` (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). If `gdrcopy` is not installed, things will still work with a plain `pip install nixl`, just with lower performance. `nixl` and `ucx` are installed as dependencies via pip.

-2. **Configure Both Instances**: Add this flag to both prefill and decode instances: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'`
+2. **Configure Both Instances**: Add this flag to both prefill and decode instances: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'`. Note that you may also specify one or more NIXL backends, for example: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"backends":["UCX","GDS"]}}'`

 3. **Client Orchestration**: Use the client-side script below to coordinate prefill/decode operations. We are actively working on routing solutions.
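For step 2, a minimal launcher sketch: the model name and ports are illustrative, and any additional NIXL environment setup a real disaggregated deployment needs is omitted here.

```python
import subprocess

# Both prefill and decode instances get the same connector flag; the
# "backends" list is optional and defaults to ["UCX"].
KV_CFG = ('{"kv_connector":"NixlConnector","kv_role":"kv_both",'
          '"kv_connector_extra_config":{"backends":["UCX"]}}')

for port in (8100, 8200):  # illustrative: prefill on 8100, decode on 8200
    subprocess.Popen([
        "vllm", "serve", "meta-llama/Llama-3.1-8B-Instruct",
        "--port", str(port),
        "--kv-transfer-config", KV_CFG,
    ])
```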
@@ -27,6 +27,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
     KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
     NixlConnectorWorker, NixlKVConnectorStats)
 from vllm.forward_context import ForwardContext
+from vllm.platforms.interface import Platform
 from vllm.sampling_params import SamplingParams
 from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
 from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
@@ -56,7 +57,7 @@ class FakeNixlWrapper:
     def get_reg_descs(self, caches_data, memory_type: str) -> list:
         return [str(uuid.uuid4()) for _ in caches_data]

-    def register_memory(self, descs) -> None:
+    def register_memory(self, descs, backends) -> None:
         pass

     def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
@@ -855,3 +856,52 @@ def test_register_kv_caches(dist_init):
         assert block_len == expected_block_len, \
             f"Block entry {i}: Expected block len {expected_block_len}, " \
             f"got {block_len}"
+
+
+class FakePlatform(Platform):
+    device_type: str = "oot"
+
+    @classmethod
+    def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
+        """
+        Returns a mapping from device_type to a tuple of supported
+        kv_buffer_device for nixl.
+        """
+        return {'oot': ('oot', )}
+
+    @classmethod
+    def get_nixl_memory_type(cls) -> Optional[str]:
+        """
+        Returns the nixl memory type for the current platform.
+        """
+        return 'VRAM'
+
+
+@pytest.mark.parametrize("kv_buffer_device, nixl_memory_type", [
+    ("oot", "VRAM"),
+])
+def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device,
+                                        nixl_memory_type):
+    """
+    Test that register_kv_caches() passes the correct memory types from the
+    config to the nixl_wrapper.
+    """
+    vllm_config = create_vllm_config()
+    # Override the default memory types in the config
+    vllm_config.kv_transfer_config.kv_buffer_device = kv_buffer_device
+    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
+        _NIXL_SUPPORTED_DEVICE)
+    _NIXL_SUPPORTED_DEVICE.update(FakePlatform.get_nixl_supported_devices())
+
+    with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"), \
+            patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \
+            patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"), \
+            patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform", FakePlatform), \
+            patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector._NIXL_SUPPORTED_DEVICE", _NIXL_SUPPORTED_DEVICE):  # noqa: E501
+
+        # Create the connector; NixlWrapper and threading are patched out above
+        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+
+        # Verify the worker resolved the expected buffer device and memory type
+        assert connector.connector_worker.kv_buffer_device == kv_buffer_device
+        assert connector.connector_worker.nixl_memory_type == nixl_memory_type
@@ -58,6 +58,12 @@ except ImportError:
     logger.warning("NIXL is not available")
     NixlWrapper = None

+try:
+    from nixl._api import nixl_agent_config
+except ImportError:
+    nixl_agent_config = None
+    logger.warning("NIXL agent config is not available")
+
 # Supported platforms and types of kv transfer buffer.
 # {device: tuple of supported kv buffer types}
 _NIXL_SUPPORTED_DEVICE = {
@@ -65,6 +71,8 @@ _NIXL_SUPPORTED_DEVICE = {
     "tpu": ("cpu", ),
     "xpu": ("cpu", ),
 }
+# Support for oot platforms by merging the mapping provided by current_platform
+_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())


 class NixlAgentMetadata(
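To make the merge concrete: a hypothetical out-of-tree platform returning `{'oot': ('oot',)}` (as the FakePlatform in the tests above does) would extend the table roughly as follows. This is an illustrative literal only; the `"cuda"` entry is assumed, since it sits above the lines shown in this hunk.

```python
# Hypothetical merged result after the update() call above.
_NIXL_SUPPORTED_DEVICE = {
    "cuda": ("cuda", ),  # assumed in-tree entry (not shown in this hunk)
    "tpu": ("cpu", ),
    "xpu": ("cpu", ),
    "oot": ("oot", ),    # merged in via get_nixl_supported_devices()
}
```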
@@ -448,8 +456,15 @@ class NixlConnectorWorker:
         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size

+        self.nixl_backends = \
+            vllm_config.kv_transfer_config.get_from_extra_config(
+                "backends", ["UCX"])
         # Agent.
-        self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
+        non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
+        config = nixl_agent_config(backends=self.nixl_backends) if len(
+            non_ucx_backends) > 0 and nixl_agent_config is not None else None
+
+        self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config)
         # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
         self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)
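The backend-selection rule added above can be read as a small pure function. A sketch with a hypothetical helper name, mirroring the inline logic:

```python
def make_agent_config(nixl_backends, nixl_agent_config):
    """Build a NIXL agent config only when a non-UCX backend is requested."""
    non_ucx_backends = [b for b in nixl_backends if b != "UCX"]
    if non_ucx_backends and nixl_agent_config is not None:
        return nixl_agent_config(backends=nixl_backends)
    # Default (UCX-only) setups keep the previous behavior: no explicit
    # agent config is passed to NixlWrapper.
    return None
```

The effect is that plain UCX setups behave exactly as before this change, while requesting e.g. GDS routes through `nixl_agent_config` when that API is importable.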
@@ -486,11 +501,15 @@ class NixlConnectorWorker:
         # used when device memory can not be registered under nixl
         self.host_xfer_buffers: dict[str, torch.Tensor] = {}
         self.use_host_buffer = self.kv_buffer_device == "cpu"
-        if self.kv_buffer_device == "cuda":
-            self.nixl_memory_type = "VRAM"
-        elif self.kv_buffer_device == "cpu":
-            self.nixl_memory_type = "DRAM"
-        else:
+        # Support for oot platforms which can't derive the nixl memory
+        # type from kv_buffer_device alone.
+        self.nixl_memory_type = current_platform.get_nixl_memory_type()
+        if self.nixl_memory_type is None:
+            if self.kv_buffer_device == "cuda":
+                self.nixl_memory_type = "VRAM"
+            elif self.kv_buffer_device == "cpu":
+                self.nixl_memory_type = "DRAM"
+        if self.nixl_memory_type is None:
             raise RuntimeError(
                 f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
                 "is not supported.")
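The resolution order above, platform hook first, then the cuda/cpu defaults, then a hard failure, condensed into a standalone sketch (the function name is hypothetical; the real logic lives inline in `NixlConnectorWorker.__init__`):

```python
def resolve_nixl_memory_type(platform, device_type, kv_buffer_device):
    """Mirror of the inline resolution: platform override, then defaults."""
    memory_type = platform.get_nixl_memory_type()  # oot platforms may override
    if memory_type is None:
        memory_type = {"cuda": "VRAM", "cpu": "DRAM"}.get(kv_buffer_device)
    if memory_type is None:
        raise RuntimeError(
            f"{device_type} with {kv_buffer_device} kv_buffer "
            "is not supported.")
    return memory_type
```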
@@ -766,7 +785,7 @@ class NixlConnectorWorker:
         descs = self.nixl_wrapper.get_reg_descs(caches_data,
                                                 self.nixl_memory_type)
         logger.debug("Registering descs: %s", caches_data)
-        self.nixl_wrapper.register_memory(descs)
+        self.nixl_wrapper.register_memory(descs, backends=self.nixl_backends)
        logger.debug("Done registering descs")
         self._registered_descs.append(descs)
@@ -604,6 +604,21 @@ class Platform:

         return _synced_weight_loader

+    @classmethod
+    def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
+        """
+        Returns a mapping from device_type to a tuple of supported
+        kv_buffer_device for nixl.
+        """
+        return {}
+
+    @classmethod
+    def get_nixl_memory_type(cls) -> Optional[str]:
+        """
+        Returns the nixl memory type for the current platform.
+        """
+        return None
+

 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
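Taken together, the two hooks give an out-of-tree platform everything it needs to opt into NixlConnector. A hypothetical plugin sketch, with class and device names that are purely illustrative, modeled on the FakePlatform used in the tests above:

```python
from typing import Optional

from vllm.platforms.interface import Platform


class MyAcceleratorPlatform(Platform):
    device_type: str = "myaccel"

    @classmethod
    def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
        # Advertise which kv_buffer_device values NIXL may use on this
        # platform; merged into _NIXL_SUPPORTED_DEVICE at import time.
        return {"myaccel": ("myaccel", )}

    @classmethod
    def get_nixl_memory_type(cls) -> Optional[str]:
        # Tell NIXL how to register this device's memory.
        return "VRAM"
```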