mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-09 23:40:20 +08:00
[NIXL] Add compatibility checking to NIXL KV connector handshake (#29503)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
parent
2c174420f5
commit
949a6a19d2
@ -9,8 +9,10 @@ import textwrap
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from unittest.mock import patch
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import msgspec
|
||||
import pytest
|
||||
import ray
|
||||
import torch
|
||||
@ -18,6 +20,7 @@ import torch
|
||||
from vllm import LLM
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import nixl_connector
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
|
||||
MultiKVConnectorStats,
|
||||
@ -29,7 +32,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
|
||||
NixlConnectorMetadata,
|
||||
NixlConnectorScheduler,
|
||||
NixlConnectorWorker,
|
||||
NixlHandshakePayload,
|
||||
NixlKVConnectorStats,
|
||||
compute_nixl_compatibility_hash,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_transfer_state import (
|
||||
ensure_kv_transfer_shutdown,
|
||||
@ -317,13 +322,19 @@ def test_kv_transfer_handshake(dist_init):
|
||||
}
|
||||
prefill_connector.register_kv_caches(kv_caches)
|
||||
|
||||
# Simulate EngineCore initialization that would
|
||||
# gather connector metadata from all workers, the scheduler connector
|
||||
# expects metadata to be in dict[int, KVConnectorHandshakeMetadata],
|
||||
# where the first key is the dp_rank, the second key is the tp_rank.
|
||||
metadata = {0: prefill_connector.get_handshake_metadata()}
|
||||
# Simulate EngineCore initialization that would gather connector
|
||||
# metadata from all workers
|
||||
metadata = prefill_connector.get_handshake_metadata()
|
||||
|
||||
# metadata is a NixlHandshakePayload, decode it to get NixlAgentMetadata
|
||||
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
|
||||
expected_agent_metadata = decoder.decode(metadata.agent_metadata_bytes)
|
||||
|
||||
# The scheduler connector expects metadata to be in
|
||||
# dict[int, KVConnectorHandshakeMetadata], where the first key is
|
||||
# the dp_rank, the second key is the tp_rank.
|
||||
scheduler_connector = scheduler.get_kv_connector()
|
||||
scheduler_connector.set_xfer_handshake_metadata(metadata)
|
||||
scheduler_connector.set_xfer_handshake_metadata({0: metadata})
|
||||
|
||||
# Simulate a request that finishes prefill, which returns
|
||||
# corresponding NixlConnectorMetadata for decode instance.
|
||||
@ -362,9 +373,9 @@ def test_kv_transfer_handshake(dist_init):
|
||||
)
|
||||
|
||||
received_metadata = mock_add_remote_agent.call_args.args
|
||||
assert received_metadata[0] == expected_agent_metadata
|
||||
assert received_metadata[1] == 0 # remote_tp_rank
|
||||
assert received_metadata[2] == 1 # remote_tp_size
|
||||
assert metadata[0] == received_metadata[0]
|
||||
|
||||
# Need to shutdown the background thread to release NIXL side channel port
|
||||
scheduler_connector.shutdown()
|
||||
@ -403,7 +414,6 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
|
||||
device_id=0,
|
||||
num_blocks=1,
|
||||
block_lens=self.block_len_per_layer,
|
||||
attn_backend_name=self.backend_name,
|
||||
# `self.kv_cache_layout` is only forced to HND when vllm engine
|
||||
# is started. We mock HND here.
|
||||
kv_cache_layout="HND",
|
||||
@ -651,7 +661,6 @@ class TestNixlHandshake:
|
||||
device_id=0,
|
||||
num_blocks=1,
|
||||
block_lens=worker.block_len_per_layer,
|
||||
attn_backend_name=worker.backend_name,
|
||||
kv_cache_layout=mismatched_layout,
|
||||
block_size=worker.block_size,
|
||||
)
|
||||
@ -706,7 +715,6 @@ class TestNixlHandshake:
|
||||
num_blocks=1,
|
||||
# prefill TP=1, decode TP=2, remote block_lens is double to local
|
||||
block_lens=[i * 2 for i in worker.block_len_per_layer],
|
||||
attn_backend_name=worker.backend_name,
|
||||
kv_cache_layout="HND",
|
||||
block_size=worker.block_size,
|
||||
)
|
||||
@ -1168,6 +1176,9 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
|
||||
mock_wrapper_instance = mock_nixl_wrapper.return_value
|
||||
connector.connector_worker.nixl_wrapper = mock_wrapper_instance
|
||||
|
||||
# Appease NixlHandshakePayload encoding with some bytes
|
||||
mock_wrapper_instance.get_agent_metadata.return_value = b"fake_agent_metadata"
|
||||
|
||||
# Reassure the shutdown() check that the thread is terminated
|
||||
mock_thread.return_value.is_alive.return_value = False
|
||||
|
||||
@ -1534,3 +1545,194 @@ def test_transfer_setup_failure_returns_finished(dist_init):
|
||||
# ensure request appears in get_finished
|
||||
_, done_recving = connector.get_finished(finished_req_ids=set())
|
||||
assert request_id in done_recving
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mismatch_type,config_overrides,version_override,should_fail,enforce_handshake_compat",
|
||||
[
|
||||
("vllm_version", {}, {"vllm_version": "0.6.1"}, True, True),
|
||||
("nixl_connector_version", {}, {"connector_version": 37}, True, True),
|
||||
("model_name", {"model": "facebook/opt-350m"}, {}, True, True),
|
||||
("dtype", {"dtype": "bfloat16"}, {}, True, True),
|
||||
("cache_dtype", {"cache_dtype": "fp8"}, {}, True, True),
|
||||
("num_kv_heads", {"hf_overrides": {"num_key_value_heads": 8}}, {}, True, True),
|
||||
(
|
||||
"num_hidden_layers",
|
||||
{"hf_overrides": {"num_hidden_layers": 24}},
|
||||
{},
|
||||
True,
|
||||
True,
|
||||
),
|
||||
("hidden_size", {"hf_overrides": {"hidden_size": 1536}}, {}, True, True),
|
||||
("block_size", {"block_size": 8}, {}, False, True),
|
||||
("matching_config", {}, {}, False, True),
|
||||
("escape_hatch", {"model": "facebook/opt-350m"}, {}, False, False),
|
||||
],
|
||||
)
|
||||
@patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||
FakeNixlWrapper,
|
||||
)
|
||||
def test_compatibility_hash_validation(
|
||||
dist_init,
|
||||
mismatch_type,
|
||||
config_overrides,
|
||||
version_override,
|
||||
should_fail,
|
||||
enforce_handshake_compat,
|
||||
):
|
||||
"""
|
||||
Test NIXL compatibility hash validation during handshake.
|
||||
|
||||
Parameters:
|
||||
mismatch_type: description of what is being tested
|
||||
config_overrides: dict of config to override for the remote instance
|
||||
version_override: version dict e.g. {"vllm_version": "0.6.1"}
|
||||
should_fail: whether the handshake should fail
|
||||
enforce_handshake_compat: whether to enforce compatibility checking
|
||||
"""
|
||||
local_vllm_config = create_vllm_config(
|
||||
model="facebook/opt-125m",
|
||||
block_size=16,
|
||||
kv_connector_extra_config={
|
||||
"enforce_handshake_compat": enforce_handshake_compat
|
||||
},
|
||||
)
|
||||
decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER)
|
||||
decode_worker = decode_connector.connector_worker
|
||||
|
||||
remote_config_params: dict[str, Any] = {
|
||||
"model": "facebook/opt-125m",
|
||||
"block_size": 16,
|
||||
**config_overrides,
|
||||
}
|
||||
remote_vllm_config = create_vllm_config(**remote_config_params)
|
||||
|
||||
with contextlib.ExitStack() as stack:
|
||||
if "vllm_version" in version_override:
|
||||
stack.enter_context(
|
||||
patch("vllm.__version__", version_override["vllm_version"])
|
||||
)
|
||||
elif "connector_version" in version_override:
|
||||
stack.enter_context(
|
||||
patch.object(
|
||||
nixl_connector,
|
||||
"NIXL_CONNECTOR_VERSION",
|
||||
version_override["connector_version"],
|
||||
)
|
||||
)
|
||||
remote_hash = compute_nixl_compatibility_hash(
|
||||
remote_vllm_config, decode_worker.backend_name
|
||||
)
|
||||
|
||||
prefill_block_size = config_overrides.get("block_size", 16)
|
||||
prefill_metadata = NixlAgentMetadata(
|
||||
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
||||
kv_caches_base_addr=[0],
|
||||
device_id=0,
|
||||
num_blocks=1,
|
||||
block_lens=[4096 * prefill_block_size], # slot_size * block_size
|
||||
kv_cache_layout="HND",
|
||||
block_size=prefill_block_size,
|
||||
)
|
||||
handshake_payload = NixlHandshakePayload(
|
||||
compatibility_hash=remote_hash,
|
||||
agent_metadata_bytes=msgspec.msgpack.encode(prefill_metadata),
|
||||
)
|
||||
|
||||
# Mock ZMQ socket to return our handshake payload
|
||||
mock_socket = MagicMock()
|
||||
mock_socket.recv.return_value = msgspec.msgpack.encode(handshake_payload)
|
||||
|
||||
# Mock add_remote_agent to avoid actual NIXL operations
|
||||
# Patch zmq_ctx to return our mock socket
|
||||
with (
|
||||
patch.object(decode_worker, "add_remote_agent", return_value="fake_agent"),
|
||||
patch.object(nixl_connector, "zmq_ctx") as mock_zmq_ctx,
|
||||
):
|
||||
mock_zmq_ctx.return_value.__enter__.return_value = mock_socket
|
||||
|
||||
if should_fail:
|
||||
with pytest.raises(RuntimeError, match="compatibility hash mismatch"):
|
||||
decode_worker._nixl_handshake(
|
||||
host="localhost",
|
||||
port=1234,
|
||||
remote_tp_size=1,
|
||||
expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||
)
|
||||
else:
|
||||
result = decode_worker._nixl_handshake(
|
||||
host="localhost",
|
||||
port=1234,
|
||||
remote_tp_size=1,
|
||||
expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||
)
|
||||
# Verify handshake returned agent mapping
|
||||
assert isinstance(result, dict)
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"error_scenario",
|
||||
[
|
||||
"handshake_decode_error",
|
||||
"handshake_validation_error",
|
||||
"metadata_decode_error",
|
||||
"metadata_validation_error",
|
||||
],
|
||||
)
|
||||
@patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||
FakeNixlWrapper,
|
||||
)
|
||||
def test_handshake_decode_errors(dist_init, error_scenario):
|
||||
"""
|
||||
Test that msgspec decode errors are properly handled during handshake.
|
||||
|
||||
Tests both DecodeError and ValidationError for both decoders:
|
||||
- NixlHandshakePayload decoder
|
||||
- NixlAgentMetadata decoder
|
||||
"""
|
||||
local_vllm_config = create_vllm_config(
|
||||
model="facebook/opt-125m",
|
||||
block_size=16,
|
||||
)
|
||||
decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER)
|
||||
decode_worker = decode_connector.connector_worker
|
||||
|
||||
if error_scenario == "handshake_decode_error":
|
||||
msg_bytes = b"this is not valid msgpack data"
|
||||
elif error_scenario == "handshake_validation_error":
|
||||
msg_bytes = msgspec.msgpack.encode({"wrong_field": "value"})
|
||||
elif error_scenario == "metadata_decode_error":
|
||||
valid_handshake = NixlHandshakePayload(
|
||||
compatibility_hash=decode_worker.compat_hash,
|
||||
agent_metadata_bytes=b"invalid msgpack for metadata",
|
||||
)
|
||||
msg_bytes = msgspec.msgpack.encode(valid_handshake)
|
||||
|
||||
elif error_scenario == "metadata_validation_error":
|
||||
valid_handshake = NixlHandshakePayload(
|
||||
compatibility_hash=decode_worker.compat_hash,
|
||||
agent_metadata_bytes=msgspec.msgpack.encode({"missing": "fields"}),
|
||||
)
|
||||
msg_bytes = msgspec.msgpack.encode(valid_handshake)
|
||||
else:
|
||||
raise AssertionError(f"{error_scenario} not a valid scenario")
|
||||
|
||||
mock_socket = MagicMock()
|
||||
mock_socket.recv.return_value = msg_bytes
|
||||
with (
|
||||
patch.object(decode_worker, "add_remote_agent", return_value="fake_agent"),
|
||||
patch.object(nixl_connector, "zmq_ctx") as mock_zmq_ctx,
|
||||
):
|
||||
mock_zmq_ctx.return_value.__enter__.return_value = mock_socket
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
decode_worker._nixl_handshake(
|
||||
host="localhost",
|
||||
port=1234,
|
||||
remote_tp_size=1,
|
||||
expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||
)
|
||||
|
||||
@ -90,13 +90,18 @@ def create_vllm_config(
|
||||
max_model_len: int = 10000,
|
||||
enable_chunked_prefill: bool = True,
|
||||
enable_permute_local_kv: bool = False,
|
||||
kv_connector_extra_config: dict[str, Any] | None = None,
|
||||
dtype: str = "float16",
|
||||
cache_dtype: str = "auto",
|
||||
hf_overrides: dict[str, Any] | None = None,
|
||||
) -> VllmConfig:
|
||||
"""Initialize VllmConfig For Testing."""
|
||||
model_config = ModelConfig(
|
||||
model=model,
|
||||
trust_remote_code=True,
|
||||
dtype="float16",
|
||||
dtype=dtype,
|
||||
seed=42,
|
||||
hf_overrides=hf_overrides or {},
|
||||
)
|
||||
scheduler_config = SchedulerConfig(
|
||||
max_num_seqs=max_num_seqs,
|
||||
@ -110,13 +115,14 @@ def create_vllm_config(
|
||||
block_size=block_size,
|
||||
gpu_memory_utilization=0.9,
|
||||
swap_space=0,
|
||||
cache_dtype="auto",
|
||||
cache_dtype=cache_dtype,
|
||||
enable_prefix_caching=True,
|
||||
)
|
||||
kv_transfer_config = KVTransferConfig(
|
||||
kv_connector="NixlConnector",
|
||||
kv_role="kv_both",
|
||||
enable_permute_local_kv=enable_permute_local_kv,
|
||||
kv_connector_extra_config=kv_connector_extra_config or {},
|
||||
)
|
||||
return VllmConfig(
|
||||
scheduler_config=scheduler_config,
|
||||
|
||||
@ -59,6 +59,21 @@ Transfer = tuple[int, float] # (xfer_handle, start_time)
|
||||
EngineId = str
|
||||
ReqId = str
|
||||
|
||||
#
|
||||
# NIXL Connector Version
|
||||
#
|
||||
# Increment this version whenever there is an incompatible change to:
|
||||
# - NixlAgentMetadata schema
|
||||
# - kv_transfer_params schema or semantics
|
||||
# - NIXL transfer protocol or wire format
|
||||
# - KV cache memory layout or block organization
|
||||
# - Any other change that breaks P/D interoperability
|
||||
#
|
||||
# Version History:
|
||||
# 1: Initial version with compatibility checking
|
||||
#
|
||||
NIXL_CONNECTOR_VERSION: int = 1
|
||||
|
||||
GET_META_MSG = b"get_meta_msg"
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -97,18 +112,95 @@ _NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
|
||||
|
||||
|
||||
@dataclass
|
||||
class NixlAgentMetadata(KVConnectorHandshakeMetadata):
|
||||
class NixlAgentMetadata:
|
||||
engine_id: str
|
||||
agent_metadata: bytes
|
||||
kv_caches_base_addr: list[int]
|
||||
device_id: int
|
||||
num_blocks: int
|
||||
block_lens: list[int]
|
||||
attn_backend_name: str
|
||||
kv_cache_layout: str
|
||||
block_size: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class NixlHandshakePayload(KVConnectorHandshakeMetadata):
|
||||
"""
|
||||
Wrapper for NIXL handshake sent over the wire.
|
||||
|
||||
Enables two-phase decoding for graceful compatibility checking:
|
||||
1. Decode NixlHandshakePayload to get compatibility_hash
|
||||
2. Compute local hash and compare
|
||||
3. Only if hashes match, decode agent_metadata_bytes
|
||||
|
||||
This prevents decoder errors when NixlAgentMetadata schema is
|
||||
incompatible, allowing graceful failure with clear error message.
|
||||
"""
|
||||
|
||||
compatibility_hash: str
|
||||
agent_metadata_bytes: bytes # NixlAgentMetadata encoded
|
||||
|
||||
|
||||
def compute_nixl_compatibility_hash(
|
||||
vllm_config: VllmConfig, attn_backend_name: str
|
||||
) -> str:
|
||||
"""
|
||||
Compute compatibility hash for NIXL KV transfer.
|
||||
|
||||
Hash only the factors that affect whether two NIXL instances can
|
||||
successfully transfer KV cache data.
|
||||
|
||||
Factors included:
|
||||
- vLLM version and NIXL connector version
|
||||
- Model architecture (name, dtype, KV heads, layers)
|
||||
- KV cache format (dtype, sliding window)
|
||||
- Attention backend
|
||||
|
||||
Note: Factors like tensor_parallel_size, block_size, and kv_cache_layout
|
||||
are validated at runtime in _validate_remote_agent_handshake and are not
|
||||
included in this hash to support heterogeneous deployments.
|
||||
|
||||
Note - the set of factors are likely to evolve significantly over
|
||||
time to be more or less permissive.
|
||||
|
||||
Returns:
|
||||
SHA-256 hex digest
|
||||
"""
|
||||
from vllm import __version__ as vllm_version
|
||||
from vllm.config.utils import hash_factors
|
||||
|
||||
model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
|
||||
factors = {
|
||||
# Version compatibility
|
||||
"vllm_version": vllm_version,
|
||||
"nixl_connector_version": NIXL_CONNECTOR_VERSION,
|
||||
# Model architecture - affects KV cache shape
|
||||
"model": model_config.model,
|
||||
"dtype": str(model_config.dtype),
|
||||
"num_kv_heads": model_config.get_total_num_kv_heads(),
|
||||
"head_size": model_config.get_head_size(),
|
||||
"num_hidden_layers": model_config.get_total_num_hidden_layers(),
|
||||
# Attention backend and KV cache dtype affect memory layout
|
||||
"attn_backend_name": attn_backend_name,
|
||||
"cache_dtype": str(cache_config.cache_dtype),
|
||||
}
|
||||
|
||||
compat_hash = hash_factors(factors)
|
||||
logger.info(
|
||||
"NIXL compatibility hash: %s (model=%s, dtype=%s, num_kv_heads=%d, "
|
||||
"cache_dtype=%s, attn_backend=%s)",
|
||||
compat_hash,
|
||||
factors["model"],
|
||||
factors["dtype"],
|
||||
factors["num_kv_heads"],
|
||||
factors["cache_dtype"],
|
||||
attn_backend_name,
|
||||
)
|
||||
return compat_hash
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
local_block_ids: list[int]
|
||||
@ -396,14 +488,14 @@ class NixlConnectorScheduler:
|
||||
encoded_data: dict[int, bytes] = {}
|
||||
encoder = msgspec.msgpack.Encoder()
|
||||
for tp_rank, rank_metadata in metadata.items():
|
||||
if not isinstance(rank_metadata, NixlAgentMetadata):
|
||||
if not isinstance(rank_metadata, NixlHandshakePayload):
|
||||
raise ValueError(
|
||||
"NixlConnectorScheduler expects NixlAgentMetadata for "
|
||||
"NixlConnectorScheduler expects NixlHandshakePayload for "
|
||||
"handshake metadata."
|
||||
)
|
||||
encoded_data[tp_rank] = encoder.encode(rank_metadata)
|
||||
logger.debug(
|
||||
"Tp rank %d: encoded NixlAgentMetadata size: %s bytes",
|
||||
"Tp rank %d: encoded NixlHandshakePayload size: %s bytes",
|
||||
tp_rank,
|
||||
str(len(encoded_data[tp_rank])),
|
||||
)
|
||||
@ -794,7 +886,7 @@ class NixlConnectorWorker:
|
||||
self._failed_recv_reqs: set[ReqId] = set()
|
||||
|
||||
# Handshake metadata of this worker for NIXL transfers.
|
||||
self.xfer_handshake_metadata: NixlAgentMetadata | None = None
|
||||
self.xfer_handshake_metadata: NixlHandshakePayload | None = None
|
||||
# Background thread for initializing new NIXL handshakes.
|
||||
self._handshake_initiation_executor = ThreadPoolExecutor(
|
||||
# NIXL is not guaranteed to be thread-safe, limit 1 worker.
|
||||
@ -829,6 +921,13 @@ class NixlConnectorWorker:
|
||||
logger.debug("Detected attention backend %s", self.backend_name)
|
||||
logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
|
||||
|
||||
self.compat_hash = compute_nixl_compatibility_hash(
|
||||
self.vllm_config, self.backend_name
|
||||
)
|
||||
self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config(
|
||||
"enforce_handshake_compat", True
|
||||
)
|
||||
|
||||
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
|
||||
self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
|
||||
# With heterogeneous TP, P must wait for all assigned D TP workers to
|
||||
@ -877,14 +976,58 @@ class NixlConnectorWorker:
|
||||
# Set receive timeout to 5 seconds to avoid hanging on dead server
|
||||
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
|
||||
sock.send(msg)
|
||||
metadata_bytes = sock.recv()
|
||||
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
|
||||
metadata = decoder.decode(metadata_bytes)
|
||||
handshake_bytes = sock.recv()
|
||||
|
||||
# Decode handshake payload to get compatibility hash
|
||||
handshake_decoder = msgspec.msgpack.Decoder(NixlHandshakePayload)
|
||||
try:
|
||||
handshake_payload = handshake_decoder.decode(handshake_bytes)
|
||||
except (msgspec.DecodeError, msgspec.ValidationError) as e:
|
||||
raise RuntimeError(
|
||||
f"Failed to decode NixlHandshakePayload. This likely indicates "
|
||||
f"an incompatibility between connector version. Error: {e}"
|
||||
) from e
|
||||
|
||||
got_metadata_time = time.perf_counter()
|
||||
logger.debug(
|
||||
"NIXL handshake: get metadata took: %s", got_metadata_time - start_time
|
||||
)
|
||||
|
||||
# Check compatibility hash BEFORE decoding agent metadata
|
||||
if (
|
||||
self.enforce_compat_hash
|
||||
and handshake_payload.compatibility_hash != self.compat_hash
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"NIXL compatibility hash mismatch. "
|
||||
f"Local: {self.compat_hash}, "
|
||||
f"Remote: {handshake_payload.compatibility_hash}. "
|
||||
f"Prefill and decode instances have incompatible configurations. "
|
||||
f"This may be due to: different vLLM versions, models, dtypes, "
|
||||
f"KV cache layouts, attention backends, etc. "
|
||||
f"Both instances must use identical configurations."
|
||||
f"Disable this check using "
|
||||
f'--kv-transfer-config \'{{"kv_connector_extra_config": '
|
||||
f'{{"enforce_handshake_compat": false}}}}\''
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"NIXL compatibility check passed (hash: %s)",
|
||||
handshake_payload.compatibility_hash,
|
||||
)
|
||||
|
||||
# Decode agent metadata
|
||||
metadata_decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
|
||||
try:
|
||||
metadata = metadata_decoder.decode(
|
||||
handshake_payload.agent_metadata_bytes
|
||||
)
|
||||
except (msgspec.DecodeError, msgspec.ValidationError) as e:
|
||||
# This should not happen if hash matched
|
||||
raise RuntimeError(
|
||||
f"Failed to decode NixlAgentMetadata. Error: {e}"
|
||||
) from e
|
||||
|
||||
# Ensure engine id matches.
|
||||
if metadata.engine_id != expected_engine_id:
|
||||
raise RuntimeError(
|
||||
@ -1175,19 +1318,24 @@ class NixlConnectorWorker:
|
||||
assert len(self.block_window_per_layer) == self.num_layers
|
||||
|
||||
# After KV Caches registered, listen for new connections.
|
||||
self.xfer_handshake_metadata = NixlAgentMetadata(
|
||||
agent_metadata = NixlAgentMetadata(
|
||||
engine_id=self.engine_id,
|
||||
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
|
||||
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
|
||||
device_id=self.device_id,
|
||||
num_blocks=self.num_blocks,
|
||||
block_lens=self.block_len_per_layer,
|
||||
attn_backend_name=self.backend_name,
|
||||
kv_cache_layout=self.kv_cache_layout
|
||||
if not self.use_host_buffer
|
||||
else self.host_buffer_kv_cache_layout,
|
||||
block_size=self.block_size,
|
||||
)
|
||||
# Wrap metadata in payload with hash for defensive decoding
|
||||
encoder = msgspec.msgpack.Encoder()
|
||||
self.xfer_handshake_metadata = NixlHandshakePayload(
|
||||
compatibility_hash=self.compat_hash,
|
||||
agent_metadata_bytes=encoder.encode(agent_metadata),
|
||||
)
|
||||
|
||||
def register_local_xfer_handler(
|
||||
self,
|
||||
@ -1402,8 +1550,6 @@ class NixlConnectorWorker:
|
||||
remote_engine_id = nixl_agent_meta.engine_id
|
||||
|
||||
assert self._tp_size[remote_engine_id] == remote_tp_size
|
||||
# TODO We may eventually want to skip enforcing the same attn backend.
|
||||
assert nixl_agent_meta.attn_backend_name == self.backend_name
|
||||
|
||||
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
|
||||
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user