[NIXL] Add compatibility checking to NIXL KV connector handshake (#29503)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
Mark McLoughlin 2025-12-05 14:52:45 +00:00 committed by GitHub
parent 2c174420f5
commit 949a6a19d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 380 additions and 26 deletions

View File

@ -9,8 +9,10 @@ import textwrap
import time
import uuid
from collections import defaultdict
from unittest.mock import patch
from typing import Any
from unittest.mock import MagicMock, patch
import msgspec
import pytest
import ray
import torch
@ -18,6 +20,7 @@ import torch
from vllm import LLM
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.distributed.kv_transfer.kv_connector.v1 import nixl_connector
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
MultiKVConnectorStats,
@ -29,7 +32,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorMetadata,
NixlConnectorScheduler,
NixlConnectorWorker,
NixlHandshakePayload,
NixlKVConnectorStats,
compute_nixl_compatibility_hash,
)
from vllm.distributed.kv_transfer.kv_transfer_state import (
ensure_kv_transfer_shutdown,
@ -317,13 +322,19 @@ def test_kv_transfer_handshake(dist_init):
}
prefill_connector.register_kv_caches(kv_caches)
# Simulate EngineCore initialization that would
# gather connector metadata from all workers, the scheduler connector
# expects metadata to be in dict[int, KVConnectorHandshakeMetadata],
# where the first key is the dp_rank, the second key is the tp_rank.
metadata = {0: prefill_connector.get_handshake_metadata()}
# Simulate EngineCore initialization that would gather connector
# metadata from all workers
metadata = prefill_connector.get_handshake_metadata()
# metadata is a NixlHandshakePayload, decode it to get NixlAgentMetadata
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
expected_agent_metadata = decoder.decode(metadata.agent_metadata_bytes)
# The scheduler connector expects metadata to be in
# dict[int, KVConnectorHandshakeMetadata], where the first key is
# the dp_rank, the second key is the tp_rank.
scheduler_connector = scheduler.get_kv_connector()
scheduler_connector.set_xfer_handshake_metadata(metadata)
scheduler_connector.set_xfer_handshake_metadata({0: metadata})
# Simulate a request that finishes prefill, which returns
# corresponding NixlConnectorMetadata for decode instance.
@ -362,9 +373,9 @@ def test_kv_transfer_handshake(dist_init):
)
received_metadata = mock_add_remote_agent.call_args.args
assert received_metadata[0] == expected_agent_metadata
assert received_metadata[1] == 0 # remote_tp_rank
assert received_metadata[2] == 1 # remote_tp_size
assert metadata[0] == received_metadata[0]
# Need to shutdown the background thread to release NIXL side channel port
scheduler_connector.shutdown()
@ -403,7 +414,6 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
device_id=0,
num_blocks=1,
block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout="HND",
@ -651,7 +661,6 @@ class TestNixlHandshake:
device_id=0,
num_blocks=1,
block_lens=worker.block_len_per_layer,
attn_backend_name=worker.backend_name,
kv_cache_layout=mismatched_layout,
block_size=worker.block_size,
)
@ -706,7 +715,6 @@ class TestNixlHandshake:
num_blocks=1,
# prefill TP=1, decode TP=2, remote block_lens is double to local
block_lens=[i * 2 for i in worker.block_len_per_layer],
attn_backend_name=worker.backend_name,
kv_cache_layout="HND",
block_size=worker.block_size,
)
@ -1168,6 +1176,9 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
mock_wrapper_instance = mock_nixl_wrapper.return_value
connector.connector_worker.nixl_wrapper = mock_wrapper_instance
# Appease NixlHandshakePayload encoding with some bytes
mock_wrapper_instance.get_agent_metadata.return_value = b"fake_agent_metadata"
# Reassure the shutdown() check that the thread is terminated
mock_thread.return_value.is_alive.return_value = False
@ -1534,3 +1545,194 @@ def test_transfer_setup_failure_returns_finished(dist_init):
# ensure request appears in get_finished
_, done_recving = connector.get_finished(finished_req_ids=set())
assert request_id in done_recving
@pytest.mark.parametrize(
"mismatch_type,config_overrides,version_override,should_fail,enforce_handshake_compat",
[
("vllm_version", {}, {"vllm_version": "0.6.1"}, True, True),
("nixl_connector_version", {}, {"connector_version": 37}, True, True),
("model_name", {"model": "facebook/opt-350m"}, {}, True, True),
("dtype", {"dtype": "bfloat16"}, {}, True, True),
("cache_dtype", {"cache_dtype": "fp8"}, {}, True, True),
("num_kv_heads", {"hf_overrides": {"num_key_value_heads": 8}}, {}, True, True),
(
"num_hidden_layers",
{"hf_overrides": {"num_hidden_layers": 24}},
{},
True,
True,
),
("hidden_size", {"hf_overrides": {"hidden_size": 1536}}, {}, True, True),
("block_size", {"block_size": 8}, {}, False, True),
("matching_config", {}, {}, False, True),
("escape_hatch", {"model": "facebook/opt-350m"}, {}, False, False),
],
)
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
def test_compatibility_hash_validation(
dist_init,
mismatch_type,
config_overrides,
version_override,
should_fail,
enforce_handshake_compat,
):
"""
Test NIXL compatibility hash validation during handshake.
Parameters:
mismatch_type: description of what is being tested
config_overrides: dict of config to override for the remote instance
version_override: version dict e.g. {"vllm_version": "0.6.1"}
should_fail: whether the handshake should fail
enforce_handshake_compat: whether to enforce compatibility checking
"""
local_vllm_config = create_vllm_config(
model="facebook/opt-125m",
block_size=16,
kv_connector_extra_config={
"enforce_handshake_compat": enforce_handshake_compat
},
)
decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER)
decode_worker = decode_connector.connector_worker
remote_config_params: dict[str, Any] = {
"model": "facebook/opt-125m",
"block_size": 16,
**config_overrides,
}
remote_vllm_config = create_vllm_config(**remote_config_params)
with contextlib.ExitStack() as stack:
if "vllm_version" in version_override:
stack.enter_context(
patch("vllm.__version__", version_override["vllm_version"])
)
elif "connector_version" in version_override:
stack.enter_context(
patch.object(
nixl_connector,
"NIXL_CONNECTOR_VERSION",
version_override["connector_version"],
)
)
remote_hash = compute_nixl_compatibility_hash(
remote_vllm_config, decode_worker.backend_name
)
prefill_block_size = config_overrides.get("block_size", 16)
prefill_metadata = NixlAgentMetadata(
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=0,
num_blocks=1,
block_lens=[4096 * prefill_block_size], # slot_size * block_size
kv_cache_layout="HND",
block_size=prefill_block_size,
)
handshake_payload = NixlHandshakePayload(
compatibility_hash=remote_hash,
agent_metadata_bytes=msgspec.msgpack.encode(prefill_metadata),
)
# Mock ZMQ socket to return our handshake payload
mock_socket = MagicMock()
mock_socket.recv.return_value = msgspec.msgpack.encode(handshake_payload)
# Mock add_remote_agent to avoid actual NIXL operations
# Patch zmq_ctx to return our mock socket
with (
patch.object(decode_worker, "add_remote_agent", return_value="fake_agent"),
patch.object(nixl_connector, "zmq_ctx") as mock_zmq_ctx,
):
mock_zmq_ctx.return_value.__enter__.return_value = mock_socket
if should_fail:
with pytest.raises(RuntimeError, match="compatibility hash mismatch"):
decode_worker._nixl_handshake(
host="localhost",
port=1234,
remote_tp_size=1,
expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
)
else:
result = decode_worker._nixl_handshake(
host="localhost",
port=1234,
remote_tp_size=1,
expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
)
# Verify handshake returned agent mapping
assert isinstance(result, dict)
assert len(result) == 1
@pytest.mark.parametrize(
"error_scenario",
[
"handshake_decode_error",
"handshake_validation_error",
"metadata_decode_error",
"metadata_validation_error",
],
)
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
def test_handshake_decode_errors(dist_init, error_scenario):
"""
Test that msgspec decode errors are properly handled during handshake.
Tests both DecodeError and ValidationError for both decoders:
- NixlHandshakePayload decoder
- NixlAgentMetadata decoder
"""
local_vllm_config = create_vllm_config(
model="facebook/opt-125m",
block_size=16,
)
decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER)
decode_worker = decode_connector.connector_worker
if error_scenario == "handshake_decode_error":
msg_bytes = b"this is not valid msgpack data"
elif error_scenario == "handshake_validation_error":
msg_bytes = msgspec.msgpack.encode({"wrong_field": "value"})
elif error_scenario == "metadata_decode_error":
valid_handshake = NixlHandshakePayload(
compatibility_hash=decode_worker.compat_hash,
agent_metadata_bytes=b"invalid msgpack for metadata",
)
msg_bytes = msgspec.msgpack.encode(valid_handshake)
elif error_scenario == "metadata_validation_error":
valid_handshake = NixlHandshakePayload(
compatibility_hash=decode_worker.compat_hash,
agent_metadata_bytes=msgspec.msgpack.encode({"missing": "fields"}),
)
msg_bytes = msgspec.msgpack.encode(valid_handshake)
else:
raise AssertionError(f"{error_scenario} not a valid scenario")
mock_socket = MagicMock()
mock_socket.recv.return_value = msg_bytes
with (
patch.object(decode_worker, "add_remote_agent", return_value="fake_agent"),
patch.object(nixl_connector, "zmq_ctx") as mock_zmq_ctx,
):
mock_zmq_ctx.return_value.__enter__.return_value = mock_socket
with pytest.raises(RuntimeError):
decode_worker._nixl_handshake(
host="localhost",
port=1234,
remote_tp_size=1,
expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
)

View File

@ -90,13 +90,18 @@ def create_vllm_config(
max_model_len: int = 10000,
enable_chunked_prefill: bool = True,
enable_permute_local_kv: bool = False,
kv_connector_extra_config: dict[str, Any] | None = None,
dtype: str = "float16",
cache_dtype: str = "auto",
hf_overrides: dict[str, Any] | None = None,
) -> VllmConfig:
"""Initialize VllmConfig For Testing."""
model_config = ModelConfig(
model=model,
trust_remote_code=True,
dtype="float16",
dtype=dtype,
seed=42,
hf_overrides=hf_overrides or {},
)
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
@ -110,13 +115,14 @@ def create_vllm_config(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
cache_dtype=cache_dtype,
enable_prefix_caching=True,
)
kv_transfer_config = KVTransferConfig(
kv_connector="NixlConnector",
kv_role="kv_both",
enable_permute_local_kv=enable_permute_local_kv,
kv_connector_extra_config=kv_connector_extra_config or {},
)
return VllmConfig(
scheduler_config=scheduler_config,

View File

@ -59,6 +59,21 @@ Transfer = tuple[int, float] # (xfer_handle, start_time)
EngineId = str
ReqId = str
#
# NIXL Connector Version
#
# Increment this version whenever there is an incompatible change to:
# - NixlAgentMetadata schema
# - kv_transfer_params schema or semantics
# - NIXL transfer protocol or wire format
# - KV cache memory layout or block organization
# - Any other change that breaks P/D interoperability
#
# Version History:
# 1: Initial version with compatibility checking
#
NIXL_CONNECTOR_VERSION: int = 1
GET_META_MSG = b"get_meta_msg"
logger = init_logger(__name__)
@ -97,18 +112,95 @@ _NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
@dataclass
class NixlAgentMetadata(KVConnectorHandshakeMetadata):
class NixlAgentMetadata:
engine_id: str
agent_metadata: bytes
kv_caches_base_addr: list[int]
device_id: int
num_blocks: int
block_lens: list[int]
attn_backend_name: str
kv_cache_layout: str
block_size: int
@dataclass
class NixlHandshakePayload(KVConnectorHandshakeMetadata):
"""
Wrapper for NIXL handshake sent over the wire.
Enables two-phase decoding for graceful compatibility checking:
1. Decode NixlHandshakePayload to get compatibility_hash
2. Compute local hash and compare
3. Only if hashes match, decode agent_metadata_bytes
This prevents decoder errors when NixlAgentMetadata schema is
incompatible, allowing graceful failure with clear error message.
"""
compatibility_hash: str
agent_metadata_bytes: bytes # NixlAgentMetadata encoded
def compute_nixl_compatibility_hash(
vllm_config: VllmConfig, attn_backend_name: str
) -> str:
"""
Compute compatibility hash for NIXL KV transfer.
Hash only the factors that affect whether two NIXL instances can
successfully transfer KV cache data.
Factors included:
- vLLM version and NIXL connector version
- Model architecture (name, dtype, KV heads, layers)
- KV cache format (dtype, sliding window)
- Attention backend
Note: Factors like tensor_parallel_size, block_size, and kv_cache_layout
are validated at runtime in _validate_remote_agent_handshake and are not
included in this hash to support heterogeneous deployments.
Note - the set of factors are likely to evolve significantly over
time to be more or less permissive.
Returns:
SHA-256 hex digest
"""
from vllm import __version__ as vllm_version
from vllm.config.utils import hash_factors
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
factors = {
# Version compatibility
"vllm_version": vllm_version,
"nixl_connector_version": NIXL_CONNECTOR_VERSION,
# Model architecture - affects KV cache shape
"model": model_config.model,
"dtype": str(model_config.dtype),
"num_kv_heads": model_config.get_total_num_kv_heads(),
"head_size": model_config.get_head_size(),
"num_hidden_layers": model_config.get_total_num_hidden_layers(),
# Attention backend and KV cache dtype affect memory layout
"attn_backend_name": attn_backend_name,
"cache_dtype": str(cache_config.cache_dtype),
}
compat_hash = hash_factors(factors)
logger.info(
"NIXL compatibility hash: %s (model=%s, dtype=%s, num_kv_heads=%d, "
"cache_dtype=%s, attn_backend=%s)",
compat_hash,
factors["model"],
factors["dtype"],
factors["num_kv_heads"],
factors["cache_dtype"],
attn_backend_name,
)
return compat_hash
@dataclass
class ReqMeta:
local_block_ids: list[int]
@ -396,14 +488,14 @@ class NixlConnectorScheduler:
encoded_data: dict[int, bytes] = {}
encoder = msgspec.msgpack.Encoder()
for tp_rank, rank_metadata in metadata.items():
if not isinstance(rank_metadata, NixlAgentMetadata):
if not isinstance(rank_metadata, NixlHandshakePayload):
raise ValueError(
"NixlConnectorScheduler expects NixlAgentMetadata for "
"NixlConnectorScheduler expects NixlHandshakePayload for "
"handshake metadata."
)
encoded_data[tp_rank] = encoder.encode(rank_metadata)
logger.debug(
"Tp rank %d: encoded NixlAgentMetadata size: %s bytes",
"Tp rank %d: encoded NixlHandshakePayload size: %s bytes",
tp_rank,
str(len(encoded_data[tp_rank])),
)
@ -794,7 +886,7 @@ class NixlConnectorWorker:
self._failed_recv_reqs: set[ReqId] = set()
# Handshake metadata of this worker for NIXL transfers.
self.xfer_handshake_metadata: NixlAgentMetadata | None = None
self.xfer_handshake_metadata: NixlHandshakePayload | None = None
# Background thread for initializing new NIXL handshakes.
self._handshake_initiation_executor = ThreadPoolExecutor(
# NIXL is not guaranteed to be thread-safe, limit 1 worker.
@ -829,6 +921,13 @@ class NixlConnectorWorker:
logger.debug("Detected attention backend %s", self.backend_name)
logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
self.compat_hash = compute_nixl_compatibility_hash(
self.vllm_config, self.backend_name
)
self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config(
"enforce_handshake_compat", True
)
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
# With heterogeneous TP, P must wait for all assigned D TP workers to
@ -877,14 +976,58 @@ class NixlConnectorWorker:
# Set receive timeout to 5 seconds to avoid hanging on dead server
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
sock.send(msg)
metadata_bytes = sock.recv()
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
metadata = decoder.decode(metadata_bytes)
handshake_bytes = sock.recv()
# Decode handshake payload to get compatibility hash
handshake_decoder = msgspec.msgpack.Decoder(NixlHandshakePayload)
try:
handshake_payload = handshake_decoder.decode(handshake_bytes)
except (msgspec.DecodeError, msgspec.ValidationError) as e:
raise RuntimeError(
f"Failed to decode NixlHandshakePayload. This likely indicates "
f"an incompatibility between connector version. Error: {e}"
) from e
got_metadata_time = time.perf_counter()
logger.debug(
"NIXL handshake: get metadata took: %s", got_metadata_time - start_time
)
# Check compatibility hash BEFORE decoding agent metadata
if (
self.enforce_compat_hash
and handshake_payload.compatibility_hash != self.compat_hash
):
raise RuntimeError(
f"NIXL compatibility hash mismatch. "
f"Local: {self.compat_hash}, "
f"Remote: {handshake_payload.compatibility_hash}. "
f"Prefill and decode instances have incompatible configurations. "
f"This may be due to: different vLLM versions, models, dtypes, "
f"KV cache layouts, attention backends, etc. "
f"Both instances must use identical configurations."
f"Disable this check using "
f'--kv-transfer-config \'{{"kv_connector_extra_config": '
f'{{"enforce_handshake_compat": false}}}}\''
)
logger.info(
"NIXL compatibility check passed (hash: %s)",
handshake_payload.compatibility_hash,
)
# Decode agent metadata
metadata_decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
try:
metadata = metadata_decoder.decode(
handshake_payload.agent_metadata_bytes
)
except (msgspec.DecodeError, msgspec.ValidationError) as e:
# This should not happen if hash matched
raise RuntimeError(
f"Failed to decode NixlAgentMetadata. Error: {e}"
) from e
# Ensure engine id matches.
if metadata.engine_id != expected_engine_id:
raise RuntimeError(
@ -1175,19 +1318,24 @@ class NixlConnectorWorker:
assert len(self.block_window_per_layer) == self.num_layers
# After KV Caches registered, listen for new connections.
self.xfer_handshake_metadata = NixlAgentMetadata(
agent_metadata = NixlAgentMetadata(
engine_id=self.engine_id,
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
device_id=self.device_id,
num_blocks=self.num_blocks,
block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name,
kv_cache_layout=self.kv_cache_layout
if not self.use_host_buffer
else self.host_buffer_kv_cache_layout,
block_size=self.block_size,
)
# Wrap metadata in payload with hash for defensive decoding
encoder = msgspec.msgpack.Encoder()
self.xfer_handshake_metadata = NixlHandshakePayload(
compatibility_hash=self.compat_hash,
agent_metadata_bytes=encoder.encode(agent_metadata),
)
def register_local_xfer_handler(
self,
@ -1402,8 +1550,6 @@ class NixlConnectorWorker:
remote_engine_id = nixl_agent_meta.engine_id
assert self._tp_size[remote_engine_id] == remote_tp_size
# TODO We may eventually want to skip enforcing the same attn backend.
assert nixl_agent_meta.attn_backend_name == self.backend_name
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(