[NIXL][OOT platform] support nixl_connector with oot platform and other nixl_backend (#25121)

Signed-off-by: Chendi Xue <Chendi.Xue@intel.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Chendi.Xue 2025-09-22 23:17:42 -05:00 committed by yewentao256
parent 675fc471bf
commit 921945c81e
5 changed files with 99 additions and 9 deletions

View File

@@ -31,6 +31,12 @@ Now supports 5 types of connectors:
--kv-transfer-config '{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both"},{"kv_connector":"SharedStorageConnector","kv_role":"kv_both","kv_connector_extra_config":{"shared_storage_path":"local_storage"}}]}}'
```
For `NixlConnector`, you may also specify one or multiple NIXL backends, for example:
```bash
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_buffer_device":"cuda","kv_connector_extra_config":{"backends":["UCX","GDS"]}}'
```
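For context, a full launch might look like the following sketch (model and port are illustrative placeholders, not part of this change):
```bash
# Illustrative: serve a model with NixlConnector using UCX and GDS backends.
vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --port 8100 \
    --kv-transfer-config \
    '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_buffer_device":"cuda","kv_connector_extra_config":{"backends":["UCX","GDS"]}}'
```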
- **OffloadingConnector**: enable offloading of KV data to CPU memory, customizing the CPU block size (in tokens) and number of blocks to allocate (per worker):
```bash

View File

@@ -193,7 +193,7 @@ For production deployments requiring strict SLA guarantees for time-to-first-tok
1. **Install gdrcopy/ucx/nixl**: For maximum performance, run the [install_gdrcopy.sh](gh-file:tools/install_gdrcopy.sh) script to install `gdrcopy` (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). If `gdrcopy` is not installed, things will still work with a plain `pip install nixl`, just with lower performance. `nixl` and `ucx` are installed as dependencies via pip.
2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'`
2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'`. Note that you may also specify one or multiple NIXL backends, e.g. `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_connector_extra_config":{"backends":["UCX","GDS"]}}'` (see the sketch after this list).
3. **Client Orchestration**: Use the client-side script below to coordinate prefill/decode operations. We are actively working on routing solutions.
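As a minimal sketch of step 2 (ports, GPU assignments, and model are illustrative; the routing from step 3 is omitted), the two instances could be launched as:
```bash
# Prefill instance (illustrative port/GPU).
CUDA_VISIBLE_DEVICES=0 vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --port 8100 \
    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' &

# Decode instance (illustrative port/GPU).
CUDA_VISIBLE_DEVICES=1 vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --port 8200 \
    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' &
```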

View File

@@ -27,6 +27,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
    KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
    NixlConnectorWorker, NixlKVConnectorStats)
from vllm.forward_context import ForwardContext
from vllm.platforms.interface import Platform
from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
@@ -56,7 +57,7 @@ class FakeNixlWrapper:

    def get_reg_descs(self, caches_data, memory_type: str) -> list:
        return [str(uuid.uuid4()) for _ in caches_data]

    def register_memory(self, descs) -> None:
    def register_memory(self, descs, backends) -> None:
        pass

    def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
@@ -855,3 +856,52 @@ def test_register_kv_caches(dist_init):
        assert block_len == expected_block_len, \
            f"Block entry {i}: Expected block len {expected_block_len}, " \
            f"got {block_len}"

class FakePlatform(Platform):
    device_type: str = "oot"

    @classmethod
    def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
        """
        Returns a mapping from device_type to a tuple of supported
        kv_buffer_device for nixl.
        """
        return {'oot': ('oot', )}

    @classmethod
    def get_nixl_memory_type(cls) -> Optional[str]:
        """
        Returns the nixl memory type for the current platform.
        """
        return 'VRAM'


@pytest.mark.parametrize("kv_buffer_device, nixl_memory_type", [
    ("oot", "VRAM"),
])
def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device,
                                        nixl_memory_type):
    """
    Test that register_kv_caches() passes the correct memory types from the
    config to the nixl_wrapper.
    """
    vllm_config = create_vllm_config()
    # Override the default memory types in the config
    vllm_config.kv_transfer_config.kv_buffer_device = kv_buffer_device
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        _NIXL_SUPPORTED_DEVICE)
    _NIXL_SUPPORTED_DEVICE.update(FakePlatform.get_nixl_supported_devices())

    with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"), \
            patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \
            patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"), \
            patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform", FakePlatform), \
            patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector._NIXL_SUPPORTED_DEVICE", _NIXL_SUPPORTED_DEVICE):  # noqa: E501
        # Create the connector; NixlWrapper and the worker threads are
        # patched out so only the device/memory-type resolution runs.
        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
        # Verify the worker resolved the buffer device and nixl memory type
        # expected for the (fake) OOT platform.
        assert connector.connector_worker.kv_buffer_device == kv_buffer_device
        assert connector.connector_worker.nixl_memory_type == nixl_memory_type

View File

@@ -58,6 +58,12 @@ except ImportError:
    logger.warning("NIXL is not available")
    NixlWrapper = None

try:
    from nixl._api import nixl_agent_config
except ImportError:
    nixl_agent_config = None
    logger.warning("NIXL agent config is not available")
# Supported platforms and types of kv transfer buffer.
# {device: tuple of supported kv buffer types}
_NIXL_SUPPORTED_DEVICE = {
@@ -65,6 +71,8 @@ _NIXL_SUPPORTED_DEVICE = {
    "tpu": ("cpu", ),
    "xpu": ("cpu", ),
}

# Support for OOT platforms, which provide their own mapping via
# current_platform.
_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
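To make the effect of this update concrete, here is a minimal sketch (the `oot` entry is hypothetical, and only the `tpu`/`xpu` entries of the built-in table are shown in the hunk above):
```python
# Hypothetical OOT platform advertising its own kv buffer devices.
class MyOotPlatform:
    @classmethod
    def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
        return {"oot": ("oot", )}

_NIXL_SUPPORTED_DEVICE = {
    "tpu": ("cpu", ),
    "xpu": ("cpu", ),
}
_NIXL_SUPPORTED_DEVICE.update(MyOotPlatform.get_nixl_supported_devices())
# _NIXL_SUPPORTED_DEVICE == {"tpu": ("cpu",), "xpu": ("cpu",), "oot": ("oot",)}
```
In-tree platforms return an empty dict from `get_nixl_supported_devices()` (see the `Platform` base class below), so the update is a no-op for them.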
class NixlAgentMetadata(
@@ -448,8 +456,15 @@ class NixlConnectorWorker:
        self.vllm_config = vllm_config
        self.block_size = vllm_config.cache_config.block_size
        self.nixl_backends = \
            vllm_config.kv_transfer_config.get_from_extra_config(
                "backends", ["UCX"])

        # Agent.
        self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
        non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
        config = nixl_agent_config(backends=self.nixl_backends) if len(
            non_ucx_backends) > 0 and nixl_agent_config is not None else None
        self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config)

        # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
        self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)
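The selection logic above only builds a `nixl_agent_config` when a non-UCX backend is requested and the import succeeded; UCX-only setups keep passing `None` as before. A standalone sketch of that decision, with the config constructor stubbed since `nixl._api` may be absent:
```python
from typing import Any, Callable, Optional

def select_agent_config(
    nixl_backends: list[str],
    nixl_agent_config: Optional[Callable[..., Any]],
) -> Optional[Any]:
    """Mirror of the config selection in NixlConnectorWorker.__init__."""
    non_ucx_backends = [b for b in nixl_backends if b != "UCX"]
    if non_ucx_backends and nixl_agent_config is not None:
        return nixl_agent_config(backends=nixl_backends)
    return None

# UCX-only requests keep the legacy behavior (no explicit agent config).
assert select_agent_config(["UCX"], lambda **kw: kw) is None
# A GDS request yields a config listing all requested backends.
assert select_agent_config(["UCX", "GDS"], lambda **kw: kw) == {
    "backends": ["UCX", "GDS"]
}
```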
@@ -486,11 +501,15 @@ class NixlConnectorWorker:
        # used when device memory can not be registered under nixl
        self.host_xfer_buffers: dict[str, torch.Tensor] = {}
        self.use_host_buffer = self.kv_buffer_device == "cpu"
        if self.kv_buffer_device == "cuda":
            self.nixl_memory_type = "VRAM"
        elif self.kv_buffer_device == "cpu":
            self.nixl_memory_type = "DRAM"
        else:
        # Support OOT platforms, whose nixl memory type cannot be derived
        # from kv_buffer_device alone.
        self.nixl_memory_type = current_platform.get_nixl_memory_type()
        if self.nixl_memory_type is None:
            if self.kv_buffer_device == "cuda":
                self.nixl_memory_type = "VRAM"
            elif self.kv_buffer_device == "cpu":
                self.nixl_memory_type = "DRAM"

        if self.nixl_memory_type is None:
            raise RuntimeError(
                f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
                "is not supported.")
@@ -766,7 +785,7 @@ class NixlConnectorWorker:
        descs = self.nixl_wrapper.get_reg_descs(caches_data,
                                                self.nixl_memory_type)
        logger.debug("Registering descs: %s", caches_data)
        self.nixl_wrapper.register_memory(descs)
        self.nixl_wrapper.register_memory(descs, backends=self.nixl_backends)
        logger.debug("Done registering descs")
        self._registered_descs.append(descs)

View File

@@ -604,6 +604,21 @@ class Platform:
        return _synced_weight_loader

    @classmethod
    def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
        """
        Returns a mapping from device_type to a tuple of supported
        kv_buffer_device for nixl.
        """
        return {}

    @classmethod
    def get_nixl_memory_type(cls) -> Optional[str]:
        """
        Returns the nixl memory type for the current platform.
        """
        return None


class UnspecifiedPlatform(Platform):
    _enum = PlatformEnum.UNSPECIFIED
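Taken together with the connector changes, an out-of-tree platform opts into NIXL by overriding the two hooks above. A hedged sketch (class name and device type are hypothetical; wiring the class up as an OOT platform goes through vLLM's usual platform-plugin mechanism, which is unchanged by this commit):
```python
from typing import Optional

from vllm.platforms.interface import Platform

class MyAcceleratorPlatform(Platform):
    """Hypothetical OOT platform opting into NixlConnector."""
    device_type: str = "myaccel"

    @classmethod
    def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
        # Allow kv_buffer_device="myaccel" for this device type.
        return {"myaccel": ("myaccel", )}

    @classmethod
    def get_nixl_memory_type(cls) -> Optional[str]:
        # Register this device's memory with NIXL as VRAM.
        return "VRAM"
```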