[CI][Nixl] Check kv cache layout during handshake (#22745)
Signed-off-by: NickLucche <nlucches@redhat.com>
commit 422f22e012 (parent 6bd8ebf026)
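
In condensed form, this commit threads the locally detected KV-cache layout through the NIXL handshake metadata and rejects a remote agent whose layout differs whenever the TP sizes are heterogeneous. The sketch below restates that logic outside the connector; AgentMeta, check_remote_agent, and the backend-name value are illustrative stand-ins rather than vLLM APIs, and only the fields shown in the hunks below are mirrored.

    # Minimal sketch of the handshake-time check introduced by this commit.
    # AgentMeta mirrors only the NixlAgentMetadata fields touched here.
    from dataclasses import dataclass


    @dataclass
    class AgentMeta:
        attn_backend_name: str
        kv_cache_layout: str  # e.g. "NHD" or "HND"


    def check_remote_agent(local_backend: str, local_layout: str,
                           remote: AgentMeta, local_tp: int,
                           remote_tp: int) -> None:
        # The same attention backend is always required.
        assert remote.attn_backend_name == local_backend
        tp_ratio = local_tp // remote_tp
        if tp_ratio > 1:
            # Heterogeneous TP reads a slice of each remote block, so both
            # sides must agree on the KV-cache layout.
            assert remote.kv_cache_layout == local_layout


    # Matching layouts (or homogeneous TP) pass; a mismatch under
    # heterogeneous TP raises AssertionError.
    check_remote_agent("FLASH_ATTN_VLLM_V1", "NHD",
                       AgentMeta("FLASH_ATTN_VLLM_V1", "NHD"),
                       local_tp=2, remote_tp=1)

A mismatched layout with local_tp=2 and remote_tp=1 would trip the second assert, which is exactly the path the new unit test below exercises.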
@@ -419,6 +419,52 @@ class TestNixlHandshake:
                 return
         raise TimeoutError("Took too long to complete async handshake.")

+    @patch(
+        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+        FakeNixlWrapper)
+    def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
+        """
+        Verify that adding a remote agent fails if kv_cache_layout differs.
+        This test is only relevant for heterogeneous TP.
+        """
+        vllm_config = create_vllm_config()
+
+        # Mock TP world size to 2 to force heterogeneous TP when
+        # remote_tp_size=1
+        with patch(
+                "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size",  # noqa: E501
+                return_value=2):
+            # Initialize connector and worker (with fake NIXL wrapper)
+            connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+            connector.connector_worker = FakeNixlConnectorWorker(
+                vllm_config, connector.engine_id, hand_shake_latency=0)
+            worker = connector.connector_worker
+
+            # Minimal local registration params used by add_remote_agent
+            worker.slot_size_bytes = 4096
+            worker.block_len = worker.slot_size_bytes * worker.block_size
+            worker.num_blocks = 1
+            worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
+
+            # Metadata with different kv_cache_layout than local worker
+            mismatched_layout = "HND" if worker.kv_cache_layout != "HND" \
+                else "NHD"
+            meta = NixlAgentMetadata(
+                engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
+                agent_metadata=FakeNixlWrapper.AGENT_METADATA,
+                kv_caches_base_addr=[0],
+                num_blocks=1,
+                block_len=worker.block_len,
+                attn_backend_name=worker.backend_name,
+                kv_cache_layout=mismatched_layout,
+            )
+
+            # We don't check layout for homogeneous TP and MLA for now, as the
+            # whole block is moved.
+            worker.add_remote_agent(meta, remote_tp_size=2)
+            with pytest.raises(AssertionError):
+                worker.add_remote_agent(meta, remote_tp_size=1)
+

 # NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
 # we put here is important. First run ray, it will clean up the resources, then
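
For reference, the two add_remote_agent calls at the end of the test differ only in the resulting TP ratio. Assuming tp_ratio is the local TP size divided by the remote one (its definition is not part of these hunks), the arithmetic works out as in this small sketch:

    local_tp = 2  # mocked get_tensor_model_parallel_world_size() in the test

    for remote_tp in (2, 1):
        tp_ratio = local_tp // remote_tp
        # The layout check only applies to heterogeneous TP (tp_ratio > 1),
        # so remote_tp=2 passes despite the mismatched layout, while
        # remote_tp=1 hits the new assertion in add_remote_agent.
        print(f"remote_tp={remote_tp} -> tp_ratio={tp_ratio}, "
              f"layout {'checked' if tp_ratio > 1 else 'not checked'}")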
@@ -30,6 +30,7 @@ from vllm.forward_context import ForwardContext
 from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import make_zmq_path, make_zmq_socket
+from vllm.v1.attention.backends.utils import get_kv_cache_layout
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import RequestStatus

@@ -73,6 +74,7 @@ class NixlAgentMetadata(
     num_blocks: int
     block_len: int
     attn_backend_name: str
+    kv_cache_layout: str


 @dataclass
@@ -538,7 +540,9 @@ class NixlConnectorWorker:
         attn_backend = backend_name_to_enum(self.backend_name)
         self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
         self._use_pallas_v1 = attn_backend == _Backend.PALLAS_VLLM_V1
+        self.kv_cache_layout = get_kv_cache_layout()
         logger.debug("Detected attention backend %s", self.backend_name)
+        logger.debug("Detected kv cache layout %s", self.kv_cache_layout)

         self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
         # With heterogeneous TP, P must wait for all assigned D TP workers to
@@ -839,7 +843,8 @@ class NixlConnectorWorker:
             kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
             num_blocks=self.num_blocks,
             block_len=self.block_len,
-            attn_backend_name=self.backend_name)
+            attn_backend_name=self.backend_name,
+            kv_cache_layout=self.kv_cache_layout)
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
             target=self._nixl_handshake_listener,
@@ -900,8 +905,7 @@ class NixlConnectorWorker:
             self._tp_size[engine_id] = remote_tp_size
         else:
             assert self._tp_size[engine_id] == remote_tp_size
-        # We may eventually enable this after asserting equality in cache
-        # layout and close outputs.
+        # TODO We may eventually want to skip enforcing the same attn backend.
         assert nixl_agent_meta.attn_backend_name == self.backend_name

         remote_agent_name = self.nixl_wrapper.add_remote_agent(
@@ -930,6 +934,9 @@ class NixlConnectorWorker:
         if self._use_flashinfer:
             # Account for joint KV in FlashInfer.
             remote_block_size //= 2
+        if tp_ratio > 1:
+            # Heterogeneous TP expects same kv_cache_layout.
+            assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout

         assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
             "Remote P worker KV layer cache must be of shape [2, N, "
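
As a reminder of what the layout string encodes: "NHD" and "HND" follow the usual ordering of tokens (N), KV heads (H), and head size (D) within a block, which is why a partial-block read under heterogeneous TP lands on different bytes if the two sides disagree. The shapes below are an assumption based on that convention rather than something stated in this diff, and the numeric values are arbitrary.

    # Illustrative per-block shapes for the two layout strings used above.
    block_size, num_kv_heads, head_dim = 16, 8, 128  # arbitrary example sizes

    shapes = {
        "NHD": (block_size, num_kv_heads, head_dim),
        "HND": (num_kv_heads, block_size, head_dim),
    }

    # Both layouts hold the same number of elements per block, but the offset
    # of a given head's data differs, so copying a head subset between caches
    # only works when both agents use the same layout.
    assert len({s[0] * s[1] * s[2] for s in shapes.values()}) == 1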