[Bugfix] Missing NIXL metadata for handshake initialization if instance spans multi-node (#26338)

Signed-off-by: Guan Luo <gluo@nvidia.com>
Signed-off-by: GuanLuo <41310872+GuanLuo@users.noreply.github.com>
Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
GuanLuo 2025-11-01 01:16:00 +08:00 committed by GitHub
parent 7e06c40e63
commit d6517be3cd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 321 additions and 95 deletions

View File

@ -81,7 +81,7 @@ python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
- Default: 5600
- **Required for both prefiller and decoder instances**
- Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine
- For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank (e.g., with `--tensor-parallel-size=4` and base_port=5600, tp_rank 0..3 use ports 5600, 5601, 5602, 5603 on that node).
- For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank (e.g., with `--data-parallel-size=2` and base_port=5600, dp_rank 0..1 use port 5600, 5601 on that node).
- Used for the initial NIXL handshake between the prefiller and the decoder
- `VLLM_NIXL_SIDE_CHANNEL_HOST`: Host for side channel communication

View File

@ -27,6 +27,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlAgentMetadata,
NixlConnector,
NixlConnectorMetadata,
NixlConnectorScheduler,
NixlConnectorWorker,
NixlKVConnectorStats,
)
@ -283,6 +284,92 @@ def test_prompt_less_than_block_size():
assert len(scheduler_output.scheduled_new_reqs) == 0
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
def test_kv_transfer_handshake(dist_init):
"""Unit test for basic NixlConnector interface functionality."""
# Test setup, we creates a scheduler that contains a NixlConnector
# of role SCHEDULER, and expect it to be serving NixlAgentMetadata from
# all workers of the instance.
vllm_config = create_vllm_config()
# in case the test runs on non-GPU machine
vllm_config.kv_transfer_config.kv_buffer_device = "cpu"
scheduler = create_scheduler(vllm_config)
# Create two NixlConnector of role WORKER, one is the worker of
# the scheduler (prefill), the other is a worker of decode instance.
# Prefill connector will register KV cache to populate proper handshake
# metadata.
prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
prefill_connector.register_kv_caches(kv_caches)
# Simulate EngineCore initialization that would
# gather connector metadata from all workers, the scheduler connector
# expects metadata to be in dict[int, KVConnectorHandshakeMetadata],
# where the first key is the dp_rank, the second key is the tp_rank.
metadata = {0: prefill_connector.get_handshake_metadata()}
scheduler_connector = scheduler.get_kv_connector()
scheduler_connector.set_xfer_handshake_metadata(metadata)
# Simulate a request that finishes prefill, which returns
# corresponding NixlConnectorMetadata for decode instance.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_decode=True,
)
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished(
request, [0, 1, 2]
)
assert delay
# Decode connector will be able to create handshake with the prefill connector.
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
# Here we are testing the retrieval of NIXLAgentMetadata.
# Knowing the implementation detail, we override the add_remote_agent
# to validate the metadata received is the same as the one in prefill_connector.
with patch.object(
decode_connector.connector_worker, "add_remote_agent"
) as mock_add_remote_agent:
mock_add_remote_agent.return_type = "remote_agent"
decode_connector.connector_worker._nixl_handshake(
kv_connector_metadata["remote_host"],
kv_connector_metadata["remote_port"],
kv_connector_metadata["tp_size"],
kv_connector_metadata["remote_engine_id"],
)
received_metadata = mock_add_remote_agent.call_args.args
assert received_metadata[1] == 0 # remote_tp_rank
assert received_metadata[2] == 1 # remote_tp_size
assert metadata[0] == received_metadata[0]
# Need to shutdown the background thread to release NIXL side channel port
scheduler_connector.shutdown()
class FakeNixlConnectorWorker(NixlConnectorWorker):
REMOTE_ENGINE_ID = "remote_engine"
@ -313,6 +400,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
engine_id=self.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=0,
num_blocks=1,
block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name,
@ -559,6 +647,7 @@ class TestNixlHandshake:
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=0,
num_blocks=1,
block_lens=worker.block_len_per_layer,
attn_backend_name=worker.backend_name,
@ -611,6 +700,7 @@ class TestNixlHandshake:
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=0,
num_blocks=1,
# prefill TP=1, decode TP=2, remote block_lens is double to local
block_lens=[i * 2 for i in worker.block_len_per_layer],
@ -1005,6 +1095,8 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
_ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
# Request-0 times out and is cleared!
assert "0" not in req_to_blocks
# Need to shutdown the background thread to release NIXL side channel port
llm.llm_engine.engine_core.shutdown()
def test_register_kv_caches(dist_init):
@ -1177,13 +1269,15 @@ def test_shutdown_cleans_up_resources(dist_init):
"""Test that shutdown() properly cleans up all resources."""
vllm_config = create_vllm_config()
scheduler = NixlConnectorScheduler(
vllm_config, vllm_config.kv_transfer_config.engine_id
)
worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id)
nixl_wrapper = worker.nixl_wrapper
with (
patch.object(worker, "_handshake_initiation_executor") as mock_exec,
patch.object(worker, "_nixl_handshake_listener_t") as mock_listener,
patch.object(worker, "_nixl_handshake_listener_stop_event") as mock_event,
patch.object(scheduler, "_nixl_handshake_listener_t") as mock_listener,
patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer,
patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist,
patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent,
@ -1204,8 +1298,12 @@ def test_shutdown_cleans_up_resources(dist_init):
worker.shutdown()
mock_exec.shutdown.assert_called_with(wait=False)
mock_event.set.assert_called_once()
mock_listener.join.assert_called_once_with(timeout=1.0)
# Same sequence on scheduler.shutdown()
scheduler.shutdown()
scheduler.shutdown()
scheduler.shutdown()
mock_listener.join.assert_called_once()
mock_rel_xfer.assert_called_once_with(123)
assert mock_rel_dlist.call_count == 2

View File

@ -122,6 +122,15 @@ class KVConnectorRole(enum.Enum):
WORKER = 1
class KVConnectorHandshakeMetadata(ABC): # noqa: B024
"""
Metadata used for out of band connector handshake between
P/D workers. This needs to serializeable.
"""
pass
class KVConnectorMetadata(ABC): # noqa: B024
"""
Abstract Metadata used to communicate between the
@ -320,6 +329,18 @@ class KVConnectorBase_V1(ABC):
"""
return None
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
"""
Get the KVConnector handshake metadata for this connector.
This metadata is used for out-of-band connector handshake
between P/D workers.
Returns:
KVConnectorHandshakeMetadata: the handshake metadata.
None if no handshake metadata is available.
"""
return None
# ==============================
# Scheduler-side methods
# ==============================
@ -477,6 +498,17 @@ class KVConnectorBase_V1(ABC):
"""
return None
def set_xfer_handshake_metadata(
self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
"""
Set the KV connector handshake metadata for this connector.
Args:
metadata (KVConnectorHandshakeMetadata): the handshake metadata to set.
"""
return None
@classmethod
def build_prom_metrics(
cls,

View File

@ -27,6 +27,7 @@ from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
CopyBlocksOp,
KVConnectorBase_V1,
KVConnectorHandshakeMetadata,
KVConnectorMetadata,
KVConnectorRole,
)
@ -93,15 +94,12 @@ _NIXL_SUPPORTED_DEVICE = {
_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
class NixlAgentMetadata(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
# required for @cached_property.
dict=True,
):
@dataclass
class NixlAgentMetadata(KVConnectorHandshakeMetadata):
engine_id: str
agent_metadata: bytes
kv_caches_base_addr: list[int]
device_id: int
num_blocks: int
block_lens: list[int]
attn_backend_name: str
@ -223,6 +221,18 @@ class NixlConnector(KVConnectorBase_V1):
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
def set_xfer_handshake_metadata(
self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
"""
Set the KV connector handshake metadata for this connector.
Args:
metadata (dict): the handshake metadata to set.
"""
assert self.connector_scheduler is not None
self.connector_scheduler.set_xfer_handshake_metadata(metadata)
############################################################
# Worker Side Methods
############################################################
@ -299,6 +309,21 @@ class NixlConnector(KVConnectorBase_V1):
def shutdown(self):
if self.connector_worker is not None:
self.connector_worker.shutdown()
if self.connector_scheduler is not None:
self.connector_scheduler.shutdown()
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
"""
Get the KVConnector handshake metadata for this connector.
This metadata is used for out-of-band connector handshake
between P/D workers.
Returns:
KVConnectorHandshakeMetadata: the handshake metadata.
None if no handshake metadata is available.
"""
assert self.connector_worker is not None
return self.connector_worker.xfer_handshake_metadata
class NixlConnectorScheduler:
@ -312,12 +337,16 @@ class NixlConnectorScheduler:
self.side_channel_port = (
envs.VLLM_NIXL_SIDE_CHANNEL_PORT
+ vllm_config.parallel_config.data_parallel_rank
* vllm_config.parallel_config.tensor_parallel_size
)
assert vllm_config.kv_transfer_config is not None
self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
logger.info("Initializing NIXL Scheduler %s", engine_id)
# Background thread for handling new handshake requests.
self._nixl_handshake_listener_t: threading.Thread | None = None
self._encoded_xfer_handshake_metadata: dict[int, Any] = {}
self._stop_event = threading.Event()
# Requests that need to start recv/send.
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
@ -330,6 +359,89 @@ class NixlConnectorScheduler:
# remote prefill or aborted.
self._reqs_not_processed: set[ReqId] = set()
def shutdown(self):
self._stop_event.set()
if self._nixl_handshake_listener_t is not None:
self._nixl_handshake_listener_t.join()
self._nixl_handshake_listener_t = None
def set_xfer_handshake_metadata(
self, metadata: dict[int, KVConnectorHandshakeMetadata]
) -> None:
"""
Set the KV connector handshake metadata for this connector.
Args:
metadata (dict): the handshake metadata to set.
"""
encoded_data: dict[int, bytes] = {}
encoder = msgspec.msgpack.Encoder()
for tp_rank, rank_metadata in metadata.items():
if not isinstance(rank_metadata, NixlAgentMetadata):
raise ValueError(
"NixlConnectorScheduler expects NixlAgentMetadata for "
"handshake metadata."
)
encoded_data[tp_rank] = encoder.encode(rank_metadata)
logger.debug(
"Tp rank %d: encoded NixlAgentMetadata size: %s bytes",
tp_rank,
str(len(encoded_data[tp_rank])),
)
self._encoded_xfer_handshake_metadata = encoded_data
# Only start the listener when we have metadata to serve.
if self._nixl_handshake_listener_t is None:
ready_event = threading.Event()
self._nixl_handshake_listener_t = threading.Thread(
target=self._nixl_handshake_listener,
args=(
encoded_data,
ready_event,
self._stop_event,
self.side_channel_port,
),
daemon=True,
name="nixl_handshake_listener",
)
self._nixl_handshake_listener_t.start()
ready_event.wait() # Wait for listener ZMQ socket to be ready.
@staticmethod
def _nixl_handshake_listener(
encoded_data: dict[int, Any],
ready_event: threading.Event,
stop_event: threading.Event,
port: int,
):
"""Background thread for getting new NIXL handshakes."""
# NOTE(rob): this is a simple implementation. We will move
# to a better approach via HTTP endpoint soon.
# Listen for new requests for metadata.
host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
path = make_zmq_path("tcp", host, port)
logger.debug("Starting listening on path: %s", path)
with zmq_ctx(zmq.ROUTER, path) as sock:
sock.setsockopt(zmq.RCVTIMEO, 1000)
ready_event.set()
while True:
try:
identity, _, msg = sock.recv_multipart()
except zmq.Again:
if stop_event.is_set():
break
continue
# Decode the message which contains (GET_META_MSG, rank)
msg, target_tp_rank = msgspec.msgpack.decode(msg)
logger.debug(
"Received message for tp rank %s",
target_tp_rank,
)
if msg != GET_META_MSG:
logger.warning("Connection listener got unexpected message %s", msg)
sock.send_multipart((identity, b"", encoded_data[target_tp_rank]))
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int, bool]:
@ -537,8 +649,6 @@ class NixlConnectorScheduler:
class NixlConnectorWorker:
"""Implementation of Worker side methods"""
_POLL_TIMEOUT = 0.1 # Handshake thread polls for stop event every 100ms
@dataclass
class TpKVTopology:
"""
@ -651,16 +761,6 @@ class NixlConnectorWorker:
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)
# NIXL handshake port.
# NOTE(rob): Within a DP group, each DP rank gets its own
# base port (which is sent in the KVTransferParams).
# Each TP rank listens/queries on the base_port + tp_rank.
self.side_channel_port: int = (
envs.VLLM_NIXL_SIDE_CHANNEL_PORT
+ vllm_config.parallel_config.data_parallel_rank
* vllm_config.parallel_config.tensor_parallel_size
)
# Metadata.
self.engine_id: EngineId = engine_id
self.tp_rank = get_tensor_model_parallel_rank()
@ -706,6 +806,7 @@ class NixlConnectorWorker:
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
# rank will still only pull from a single remote TP worker.
self.kv_caches_base_addr: dict[EngineId, list[int]] = {}
self.device_id: int = 0
# Number of NIXL regions. Currently one region per cache
# (so 1 per layer for MLA, otherwise 2 per layer)
@ -736,9 +837,8 @@ class NixlConnectorWorker:
# requests that skipped transfer (handshake or transfer failures)
self._failed_recv_reqs: set[ReqId] = set()
# Background thread for handling new handshake requests.
self._nixl_handshake_listener_t: threading.Thread | None = None
self._nixl_handshake_listener_stop_event: threading.Event | None = None
# Handshake metadata of this worker for NIXL transfers.
self.xfer_handshake_metadata: NixlAgentMetadata | None = None
# Background thread for initializing new NIXL handshakes.
self._handshake_initiation_executor = ThreadPoolExecutor(
# NIXL is not guaranteed to be thread-safe, limit 1 worker.
@ -790,42 +890,6 @@ class NixlConnectorWorker:
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
)
@staticmethod
def _nixl_handshake_listener(
metadata: NixlAgentMetadata,
ready_event: threading.Event,
stop_event: threading.Event,
base_port: int,
tp_rank: int,
):
"""Background thread for getting new NIXL handshakes."""
# NOTE(rob): this is a simple implementation. We will move
# to a better approach via HTTP endpoint soon.
encoder = msgspec.msgpack.Encoder()
encoded_data = encoder.encode(metadata)
size_in_bytes = len(encoded_data)
logger.debug("Size of encoded NixlAgentMetadata: %s bytes", str(size_in_bytes))
# Listen for new requests for metadata.
host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
path = make_zmq_path("tcp", host, base_port + tp_rank)
logger.debug("Starting listening on path: %s", path)
with zmq_ctx(zmq.ROUTER, path) as sock:
ready_event.set()
poller = zmq.Poller()
poller.register(sock, zmq.POLLIN)
while not stop_event.is_set():
events = dict(
poller.poll(timeout=NixlConnectorWorker._POLL_TIMEOUT * 1000)
)
if sock not in events:
continue
identity, _, msg = sock.recv_multipart()
if msg != GET_META_MSG:
logger.warning("Connection listener got unexpected message %s", msg)
sock.send_multipart((identity, b"", encoded_data))
def _nixl_handshake(
self,
host: str,
@ -844,16 +908,17 @@ class NixlConnectorWorker:
# Handshake only with the remote TP rank that current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size)
path = make_zmq_path("tcp", host, port + p_remote_rank)
path = make_zmq_path("tcp", host, port)
logger.debug(
"Querying metadata on path: %s at remote rank %s", path, p_remote_rank
"Querying metadata on path: %s at remote tp rank %s", path, p_remote_rank
)
# Send query for the request.
with zmq_ctx(zmq.REQ, path) as sock:
msg = msgspec.msgpack.encode((GET_META_MSG, p_remote_rank))
# Set receive timeout to 5 seconds to avoid hanging on dead server
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
sock.send(GET_META_MSG)
sock.send(msg)
metadata_bytes = sock.recv()
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
metadata = decoder.decode(metadata_bytes)
@ -1042,6 +1107,10 @@ class NixlConnectorWorker:
assert tensor_size_bytes == curr_tensor_size_bytes, (
"All kv cache tensors must have the same size"
)
# Need to make sure the device ID is non-negative for NIXL,
# Torch uses -1 to indicate CPU tensors while NIXL uses explicit
# memory type.
self.device_id = max(cache.get_device(), 0)
caches_data.append(
(base_addr, curr_tensor_size_bytes, self.device_id, "")
)
@ -1139,10 +1208,11 @@ class NixlConnectorWorker:
assert len(self.block_window_per_layer) == self.num_layers
# After KV Caches registered, listen for new connections.
metadata = NixlAgentMetadata(
self.xfer_handshake_metadata = NixlAgentMetadata(
engine_id=self.engine_id,
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
device_id=self.device_id,
num_blocks=self.num_blocks,
block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name,
@ -1150,22 +1220,6 @@ class NixlConnectorWorker:
if not self.use_host_buffer
else self.host_buffer_kv_cache_layout,
)
ready_event, stop_event = threading.Event(), threading.Event()
self._nixl_handshake_listener_t = threading.Thread(
target=self._nixl_handshake_listener,
args=(
metadata,
ready_event,
stop_event,
self.side_channel_port,
self.tp_rank,
),
daemon=True,
name="nixl_handshake_listener",
)
self._nixl_handshake_listener_t.start()
self._nixl_handshake_listener_stop_event = stop_event
ready_event.wait() # Wait for listener ZMQ socket to be ready.
def add_remote_agent(
self,
@ -1267,7 +1321,7 @@ class NixlConnectorWorker:
# self.block_len == remote_block_len//tp_ratio bytes.
addr = base_addr + block_offset + rank_offset
# (addr, len, device id)
blocks_data.append((addr, kv_block_len, remote_tp_rank))
blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id))
if self._use_flashinfer:
# With FlashInfer index V separately to allow head splitting.
@ -1275,7 +1329,9 @@ class NixlConnectorWorker:
block_offset = block_id * nixl_agent_meta.block_lens[i]
addr = base_addr + block_offset + rank_offset
v_addr = addr + nixl_agent_meta.block_lens[i] // 2
blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
blocks_data.append(
(v_addr, kv_block_len, nixl_agent_meta.device_id)
)
logger.debug(
"Created %s blocks for dst engine %s with remote rank %s and local rank %s",
@ -1843,14 +1899,6 @@ class NixlConnectorWorker:
def shutdown(self):
"""Shutdown the connector worker."""
self._handshake_initiation_executor.shutdown(wait=False)
if self._nixl_handshake_listener_stop_event is not None:
self._nixl_handshake_listener_stop_event.set()
self._nixl_handshake_listener_stop_event = None
if self._nixl_handshake_listener_t is not None:
# Generous timeout to allow the thread to exit
self._nixl_handshake_listener_t.join(timeout=self._POLL_TIMEOUT * 10)
assert not self._nixl_handshake_listener_t.is_alive()
self._nixl_handshake_listener_t = None
for handles in self._recving_transfers.values():
for handle, _ in handles:
self.nixl_wrapper.release_xfer_handle(handle)

View File

@ -163,6 +163,27 @@ class EngineCore:
vllm_config, mm_registry
)
# If a KV connector is initialized for scheduler, we want to collect
# handshake metadata from all workers so the connector in the scheduler
# will have the full context
kv_connector = self.scheduler.get_kv_connector()
if kv_connector is not None:
# Collect and store KV connector xfer metadata from workers
# (after KV cache registration)
xfer_handshake_metadata = (
self.model_executor.get_kv_connector_handshake_metadata()
)
if xfer_handshake_metadata:
# xfer_handshake_metadata is list of dicts from workers
# Each dict already has structure {tp_rank: metadata}
# Merge all worker dicts into a single dict
content: dict[int, Any] = {}
for worker_dict in xfer_handshake_metadata:
if worker_dict is not None:
content.update(worker_dict)
kv_connector.set_xfer_handshake_metadata(content)
# Setup batch queue for pipeline parallelism.
# Batch queue for scheduled batches. This enables us to asynchronously
# schedule and execute batches, and is required by pipeline parallelism
@ -178,7 +199,7 @@ class EngineCore:
self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
if (
self.vllm_config.cache_config.enable_prefix_caching
or self.scheduler.get_kv_connector() is not None
or kv_connector is not None
):
caching_hash_fn = get_hash_fn_by_name(
vllm_config.cache_config.prefix_caching_hash_algo

View File

@ -9,6 +9,9 @@ from typing import TYPE_CHECKING, Literal, TypeVar, overload
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorHandshakeMetadata,
)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.tasks import SupportedTask
@ -177,6 +180,11 @@ class Executor(ABC):
):
raise NotImplementedError
def get_kv_connector_handshake_metadata(
self,
) -> list[dict[int, KVConnectorHandshakeMetadata]]:
return self.collective_rpc("get_kv_connector_handshake_metadata")
@overload
def execute_model(
self,

View File

@ -19,7 +19,11 @@ from vllm.distributed import (
init_distributed_environment,
set_custom_all_reduce,
)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.kv_transfer import (
ensure_kv_transfer_initialized,
get_kv_transfer_group,
has_kv_transfer_group,
)
from vllm.distributed.parallel_state import (
get_pp_group,
get_tp_group,
@ -348,6 +352,21 @@ class Worker(WorkerBase):
return int(self.available_kv_cache_memory_bytes)
def get_kv_connector_handshake_metadata(self) -> dict | None:
"""Get KV connector metadata from this worker if available."""
if not has_kv_transfer_group():
return None
connector = get_kv_transfer_group()
# Return None for connectors that don't need to exchange handshake
# metadata across workers.
if (metadata := connector.get_handshake_metadata()) is None:
return None
tp_rank = get_tp_group().rank_in_group
return {tp_rank: metadata}
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec()