[CI][Nixl] Check kv cache layout during handshake (#22745)

Signed-off-by: NickLucche <nlucches@redhat.com>
Nicolò Lucchesi 2025-08-12 21:53:52 +02:00 committed by GitHub
parent 6bd8ebf026
commit 422f22e012
2 changed files with 56 additions and 3 deletions
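
In short: the NIXL handshake metadata now carries the local KV cache layout, and the receiving worker refuses to add a remote agent whose layout differs when tensor parallelism is heterogeneous. A minimal sketch of the idea, using simplified stand-in types (the real NixlAgentMetadata is a msgspec struct with more fields):

from dataclasses import dataclass

@dataclass
class AgentMetadata:          # stand-in for NixlAgentMetadata
    engine_id: str
    attn_backend_name: str
    kv_cache_layout: str      # e.g. "NHD" or "HND"

def check_remote_layout(local_layout: str, meta: AgentMetadata,
                        heterogeneous_tp: bool) -> None:
    if heterogeneous_tp:
        # Only a slice of each remote block is read, so both sides must
        # agree on how heads/tokens are ordered within a block.
        assert meta.kv_cache_layout == local_layout, "kv_cache_layout mismatch"

# A decode worker laid out as "NHD" handshaking with an "HND" prefill worker
# under heterogeneous TP is rejected:
try:
    check_remote_layout("NHD", AgentMetadata("remote", "FLASH_ATTN", "HND"),
                        heterogeneous_tp=True)
except AssertionError:
    print("rejected: kv_cache_layout mismatch")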


@@ -419,6 +419,52 @@ class TestNixlHandshake:
                return
        raise TimeoutError("Took too long to complete async handshake.")

    @patch(
        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
        FakeNixlWrapper)
    def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
        """
        Verify that adding a remote agent fails if kv_cache_layout differs.
        This test is only relevant for heterogeneous TP.
        """
        vllm_config = create_vllm_config()

        # Mock TP world size to 2 to force heterogeneous TP when
        # remote_tp_size=1
        with patch(
                "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size",  # noqa: E501
                return_value=2):
            # Initialize connector and worker (with fake NIXL wrapper)
            connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
            connector.connector_worker = FakeNixlConnectorWorker(
                vllm_config, connector.engine_id, hand_shake_latency=0)
            worker = connector.connector_worker

            # Minimal local registration params used by add_remote_agent
            worker.slot_size_bytes = 4096
            worker.block_len = worker.slot_size_bytes * worker.block_size
            worker.num_blocks = 1
            worker.dst_num_blocks[worker.engine_id] = worker.num_blocks

            # Metadata with different kv_cache_layout than local worker
            mismatched_layout = "HND" if worker.kv_cache_layout != "HND" \
                else "NHD"
            meta = NixlAgentMetadata(
                engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
                agent_metadata=FakeNixlWrapper.AGENT_METADATA,
                kv_caches_base_addr=[0],
                num_blocks=1,
                block_len=worker.block_len,
                attn_backend_name=worker.backend_name,
                kv_cache_layout=mismatched_layout,
            )

            # We don't check layout for homogeneous TP and MLA for now, as the
            # whole block is moved.
            worker.add_remote_agent(meta, remote_tp_size=2)
            with pytest.raises(AssertionError):
                worker.add_remote_agent(meta, remote_tp_size=1)


# NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
# we put here is important. First run ray, it will clean up the resources, then
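
The homogeneous/heterogeneous distinction the test relies on boils down to the ratio between the local (decode) and remote (prefill) TP sizes. A small sketch of that gate, with the ratio shown as plain integer division for illustration:

def layout_check_required(local_tp_size: int, remote_tp_size: int) -> bool:
    # tp_ratio > 1 means each decode worker reads only a slice of a larger
    # remote block, so the intra-block layout must match on both sides.
    tp_ratio = local_tp_size // remote_tp_size
    return tp_ratio > 1

# With the test's mocked local TP world size of 2:
assert layout_check_required(2, remote_tp_size=1)      # heterogeneous: checked
assert not layout_check_required(2, remote_tp_size=2)  # homogeneous: skipped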


@@ -30,6 +30,7 @@ from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.platforms import _Backend, current_platform
from vllm.utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
@@ -73,6 +74,7 @@ class NixlAgentMetadata(
    num_blocks: int
    block_len: int
    attn_backend_name: str
    kv_cache_layout: str


@dataclass
@@ -538,7 +540,9 @@ class NixlConnectorWorker:
        attn_backend = backend_name_to_enum(self.backend_name)
        self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
        self._use_pallas_v1 = attn_backend == _Backend.PALLAS_VLLM_V1
        self.kv_cache_layout = get_kv_cache_layout()
        logger.debug("Detected attention backend %s", self.backend_name)
        logger.debug("Detected kv cache layout %s", self.kv_cache_layout)

        self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
        # With heterogeneous TP, P must wait for all assigned D TP workers to
@@ -839,7 +843,8 @@ class NixlConnectorWorker:
            kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
            num_blocks=self.num_blocks,
            block_len=self.block_len,
            attn_backend_name=self.backend_name)
            attn_backend_name=self.backend_name,
            kv_cache_layout=self.kv_cache_layout)
        ready_event = threading.Event()
        self._nixl_handshake_listener_t = threading.Thread(
            target=self._nixl_handshake_listener,
@@ -900,8 +905,7 @@ class NixlConnectorWorker:
            self._tp_size[engine_id] = remote_tp_size
        else:
            assert self._tp_size[engine_id] == remote_tp_size
        # We may eventually enable this after asserting equality in cache
        # layout and close outputs.
        # TODO We may eventually want to skip enforcing the same attn backend.
        assert nixl_agent_meta.attn_backend_name == self.backend_name

        remote_agent_name = self.nixl_wrapper.add_remote_agent(
@@ -930,6 +934,9 @@
        if self._use_flashinfer:
            # Account for joint KV in FlashInfer.
            remote_block_size //= 2
        if tp_ratio > 1:
            # Heterogeneous TP expects same kv_cache_layout.
            assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout

        assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
            "Remote P worker KV layer cache must be of shape [2, N, "