[Bugfix] Missing NIXL metadata for handshake initialization if instance spans multi-node (#26338)
Signed-off-by: Guan Luo <gluo@nvidia.com>
Signed-off-by: GuanLuo <41310872+GuanLuo@users.noreply.github.com>
Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>

Parent: 7e06c40e63
Commit: d6517be3cd
@@ -81,7 +81,7 @@ python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
   - Default: 5600
   - **Required for both prefiller and decoder instances**
   - Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine
-  - For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank (e.g., with `--tensor-parallel-size=4` and base_port=5600, tp_rank 0..3 use ports 5600, 5601, 5602, 5603 on that node).
+  - For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank (e.g., with `--data-parallel-size=2` and base_port=5600, dp_rank 0..1 use ports 5600, 5601 on that node).
   - Used for the initial NIXL handshake between the prefiller and the decoder

 - `VLLM_NIXL_SIDE_CHANNEL_HOST`: Host for side channel communication
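Aside (not part of the commit): a minimal sketch illustrating the port numbering described in the bullet above, using the same example values as the documentation.

```python
# Illustrative sketch only: side-channel ports per the documented formula
# (base_port + dp_rank), with base_port=5600 and --data-parallel-size=2.
base_port = 5600
data_parallel_size = 2

ports = [base_port + dp_rank for dp_rank in range(data_parallel_size)]
assert ports == [5600, 5601]
```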
@@ -27,6 +27,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
+    NixlAgentMetadata,
     NixlConnector,
     NixlConnectorMetadata,
     NixlConnectorScheduler,
     NixlConnectorWorker,
     NixlKVConnectorStats,
 )
@@ -283,6 +284,92 @@ def test_prompt_less_than_block_size():
     assert len(scheduler_output.scheduled_new_reqs) == 0


+@patch(
+    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+    FakeNixlWrapper,
+)
+def test_kv_transfer_handshake(dist_init):
+    """Unit test for basic NixlConnector interface functionality."""
+
+    # Test setup: we create a scheduler that contains a NixlConnector
+    # of role SCHEDULER, and expect it to serve NixlAgentMetadata from
+    # all workers of the instance.
+    vllm_config = create_vllm_config()
+    # in case the test runs on a non-GPU machine
+    vllm_config.kv_transfer_config.kv_buffer_device = "cpu"
+    scheduler = create_scheduler(vllm_config)
+
+    # Create two NixlConnectors of role WORKER: one is the worker of
+    # the scheduler (prefill), the other is a worker of the decode instance.
+
+    # The prefill connector registers its KV cache to populate proper
+    # handshake metadata.
+    prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+    kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
+        num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
+    )
+    shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
+    unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
+    kv_caches = {
+        "layer0": shared_tensor,
+        "layer1": unique_tensor,
+        "layer2": shared_tensor,
+    }
+    prefill_connector.register_kv_caches(kv_caches)
+
+    # Simulate EngineCore initialization, which gathers connector metadata
+    # from all workers; the scheduler connector expects metadata to be a
+    # dict[int, KVConnectorHandshakeMetadata] keyed by tp_rank.
+    metadata = {0: prefill_connector.get_handshake_metadata()}
+    scheduler_connector = scheduler.get_kv_connector()
+    scheduler_connector.set_xfer_handshake_metadata(metadata)
+
+    # Simulate a request that finishes prefill, which returns the
+    # corresponding NixlConnectorMetadata for the decode instance.
+    BLOCK_SIZE = vllm_config.cache_config.block_size
+    NUM_EXTERNAL_FULL_BLOCKS = 2
+    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
+
+    request = create_request(
+        request_id=1,
+        block_size=BLOCK_SIZE,
+        num_tokens=NUM_TOKENS,
+        do_remote_decode=True,
+    )
+    request.status = RequestStatus.FINISHED_LENGTH_CAPPED
+    delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished(
+        request, [0, 1, 2]
+    )
+    assert delay
+
+    # The decode connector should be able to perform a handshake with the
+    # prefill connector.
+    decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+
+    # Here we are testing the retrieval of NixlAgentMetadata.
+    # Knowing the implementation details, we override add_remote_agent
+    # to validate that the metadata received is the same as the one in
+    # prefill_connector.
+    with patch.object(
+        decode_connector.connector_worker, "add_remote_agent"
+    ) as mock_add_remote_agent:
+        mock_add_remote_agent.return_value = "remote_agent"
+
+        decode_connector.connector_worker._nixl_handshake(
+            kv_connector_metadata["remote_host"],
+            kv_connector_metadata["remote_port"],
+            kv_connector_metadata["tp_size"],
+            kv_connector_metadata["remote_engine_id"],
+        )
+
+        received_metadata = mock_add_remote_agent.call_args.args
+        assert received_metadata[1] == 0  # remote_tp_rank
+        assert received_metadata[2] == 1  # remote_tp_size
+        assert metadata[0] == received_metadata[0]
+
+    # Need to shutdown the background thread to release NIXL side channel port
+    scheduler_connector.shutdown()


 class FakeNixlConnectorWorker(NixlConnectorWorker):
     REMOTE_ENGINE_ID = "remote_engine"

@@ -313,6 +400,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
                 engine_id=self.REMOTE_ENGINE_ID,
                 agent_metadata=FakeNixlWrapper.AGENT_METADATA,
                 kv_caches_base_addr=[0],
+                device_id=0,
                 num_blocks=1,
                 block_lens=self.block_len_per_layer,
                 attn_backend_name=self.backend_name,
@@ -559,6 +647,7 @@ class TestNixlHandshake:
             engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
             agent_metadata=FakeNixlWrapper.AGENT_METADATA,
             kv_caches_base_addr=[0],
+            device_id=0,
             num_blocks=1,
             block_lens=worker.block_len_per_layer,
             attn_backend_name=worker.backend_name,
@@ -611,6 +700,7 @@ class TestNixlHandshake:
             engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
             agent_metadata=FakeNixlWrapper.AGENT_METADATA,
             kv_caches_base_addr=[0],
+            device_id=0,
             num_blocks=1,
             # prefill TP=1, decode TP=2, so remote block_lens are double the local ones
             block_lens=[i * 2 for i in worker.block_len_per_layer],
@@ -1005,6 +1095,8 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
     _ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
     # Request-0 times out and is cleared!
     assert "0" not in req_to_blocks
+    # Need to shutdown the background thread to release NIXL side channel port
+    llm.llm_engine.engine_core.shutdown()


 def test_register_kv_caches(dist_init):
@@ -1177,13 +1269,15 @@ def test_shutdown_cleans_up_resources(dist_init):
     """Test that shutdown() properly cleans up all resources."""
     vllm_config = create_vllm_config()

+    scheduler = NixlConnectorScheduler(
+        vllm_config, vllm_config.kv_transfer_config.engine_id
+    )
     worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id)
     nixl_wrapper = worker.nixl_wrapper

     with (
         patch.object(worker, "_handshake_initiation_executor") as mock_exec,
-        patch.object(worker, "_nixl_handshake_listener_t") as mock_listener,
-        patch.object(worker, "_nixl_handshake_listener_stop_event") as mock_event,
+        patch.object(scheduler, "_nixl_handshake_listener_t") as mock_listener,
         patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer,
         patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist,
         patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent,

@@ -1204,8 +1298,12 @@ def test_shutdown_cleans_up_resources(dist_init):
     worker.shutdown()

     mock_exec.shutdown.assert_called_with(wait=False)
-    mock_event.set.assert_called_once()
-    mock_listener.join.assert_called_once_with(timeout=1.0)
+
+    # Same sequence on scheduler.shutdown()
+    scheduler.shutdown()
+    scheduler.shutdown()
+    scheduler.shutdown()
+    mock_listener.join.assert_called_once()

     mock_rel_xfer.assert_called_once_with(123)
     assert mock_rel_dlist.call_count == 2
@@ -122,6 +122,15 @@ class KVConnectorRole(enum.Enum):
     WORKER = 1


+class KVConnectorHandshakeMetadata(ABC):  # noqa: B024
+    """
+    Metadata used for out-of-band connector handshake between
+    P/D workers. This needs to be serializable.
+    """
+
+    pass
+
+
 class KVConnectorMetadata(ABC):  # noqa: B024
     """
     Abstract Metadata used to communicate between the

@@ -320,6 +329,18 @@ class KVConnectorBase_V1(ABC):
         """
         return None

+    def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
+        """
+        Get the KVConnector handshake metadata for this connector.
+        This metadata is used for out-of-band connector handshake
+        between P/D workers.
+
+        Returns:
+            KVConnectorHandshakeMetadata: the handshake metadata.
+            None if no handshake metadata is available.
+        """
+        return None
+
     # ==============================
     # Scheduler-side methods
     # ==============================

@@ -477,6 +498,17 @@ class KVConnectorBase_V1(ABC):
         """
         return None

+    def set_xfer_handshake_metadata(
+        self, metadata: dict[int, KVConnectorHandshakeMetadata]
+    ) -> None:
+        """
+        Set the KV connector handshake metadata for this connector.
+
+        Args:
+            metadata (KVConnectorHandshakeMetadata): the handshake metadata to set.
+        """
+        return None
+
     @classmethod
     def build_prom_metrics(
         cls,
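Aside (not part of the commit; all names are hypothetical): a sketch of how a connector implementation might wire up the two new base-class hooks, with the worker-role side returning its handshake metadata and the scheduler-role side storing the per-TP-rank dict collected by the engine. Only the new methods are shown; the rest of the KVConnectorBase_V1 interface is omitted.

```python
# Hypothetical sketch of a connector using the new handshake-metadata hooks.
from dataclasses import dataclass

from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1,
    KVConnectorHandshakeMetadata,
)


@dataclass
class MyHandshakeMetadata(KVConnectorHandshakeMetadata):
    engine_id: str
    listen_port: int


class MyConnector(KVConnectorBase_V1):
    # Worker role: expose whatever a remote peer needs to connect back.
    def get_handshake_metadata(self) -> MyHandshakeMetadata | None:
        return MyHandshakeMetadata(engine_id="engine-0", listen_port=5600)

    # Scheduler role: keep the metadata collected from all TP ranks so it
    # can later be served to remote peers over a side channel.
    def set_xfer_handshake_metadata(
        self, metadata: dict[int, KVConnectorHandshakeMetadata]
    ) -> None:
        self._per_rank_metadata = metadata
```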
@@ -27,6 +27,7 @@ from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     CopyBlocksOp,
     KVConnectorBase_V1,
+    KVConnectorHandshakeMetadata,
     KVConnectorMetadata,
     KVConnectorRole,
 )
@@ -93,15 +94,12 @@ _NIXL_SUPPORTED_DEVICE = {
 _NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())


-class NixlAgentMetadata(
-    msgspec.Struct,
-    omit_defaults=True,  # type: ignore[call-arg]
-    # required for @cached_property.
-    dict=True,
-):
+@dataclass
+class NixlAgentMetadata(KVConnectorHandshakeMetadata):
     engine_id: str
     agent_metadata: bytes
     kv_caches_base_addr: list[int]
+    device_id: int
     num_blocks: int
     block_lens: list[int]
     attn_backend_name: str
@@ -223,6 +221,18 @@ class NixlConnector(KVConnectorBase_V1):
         assert self.connector_scheduler is not None
         return self.connector_scheduler.request_finished(request, block_ids)

+    def set_xfer_handshake_metadata(
+        self, metadata: dict[int, KVConnectorHandshakeMetadata]
+    ) -> None:
+        """
+        Set the KV connector handshake metadata for this connector.
+
+        Args:
+            metadata (dict): the handshake metadata to set.
+        """
+        assert self.connector_scheduler is not None
+        self.connector_scheduler.set_xfer_handshake_metadata(metadata)
+
     ############################################################
     # Worker Side Methods
     ############################################################
@@ -299,6 +309,21 @@ class NixlConnector(KVConnectorBase_V1):
     def shutdown(self):
         if self.connector_worker is not None:
             self.connector_worker.shutdown()
+        if self.connector_scheduler is not None:
+            self.connector_scheduler.shutdown()
+
+    def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
+        """
+        Get the KVConnector handshake metadata for this connector.
+        This metadata is used for out-of-band connector handshake
+        between P/D workers.
+
+        Returns:
+            KVConnectorHandshakeMetadata: the handshake metadata.
+            None if no handshake metadata is available.
+        """
+        assert self.connector_worker is not None
+        return self.connector_worker.xfer_handshake_metadata


 class NixlConnectorScheduler:
@@ -312,12 +337,16 @@ class NixlConnectorScheduler:
         self.side_channel_port = (
             envs.VLLM_NIXL_SIDE_CHANNEL_PORT
             + vllm_config.parallel_config.data_parallel_rank
             * vllm_config.parallel_config.tensor_parallel_size
         )
         assert vllm_config.kv_transfer_config is not None
         self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
         logger.info("Initializing NIXL Scheduler %s", engine_id)

+        # Background thread for handling new handshake requests.
+        self._nixl_handshake_listener_t: threading.Thread | None = None
+        self._encoded_xfer_handshake_metadata: dict[int, Any] = {}
+        self._stop_event = threading.Event()
+
         # Requests that need to start recv/send.
         # New requests are added by update_state_after_alloc in
         # the scheduler. Used to make metadata passed to Worker.
@@ -330,6 +359,89 @@ class NixlConnectorScheduler:
         # remote prefill or aborted.
         self._reqs_not_processed: set[ReqId] = set()

+    def shutdown(self):
+        self._stop_event.set()
+        if self._nixl_handshake_listener_t is not None:
+            self._nixl_handshake_listener_t.join()
+            self._nixl_handshake_listener_t = None
+
+    def set_xfer_handshake_metadata(
+        self, metadata: dict[int, KVConnectorHandshakeMetadata]
+    ) -> None:
+        """
+        Set the KV connector handshake metadata for this connector.
+
+        Args:
+            metadata (dict): the handshake metadata to set.
+        """
+        encoded_data: dict[int, bytes] = {}
+        encoder = msgspec.msgpack.Encoder()
+        for tp_rank, rank_metadata in metadata.items():
+            if not isinstance(rank_metadata, NixlAgentMetadata):
+                raise ValueError(
+                    "NixlConnectorScheduler expects NixlAgentMetadata for "
+                    "handshake metadata."
+                )
+            encoded_data[tp_rank] = encoder.encode(rank_metadata)
+            logger.debug(
+                "Tp rank %d: encoded NixlAgentMetadata size: %s bytes",
+                tp_rank,
+                str(len(encoded_data[tp_rank])),
+            )
+        self._encoded_xfer_handshake_metadata = encoded_data
+
+        # Only start the listener when we have metadata to serve.
+        if self._nixl_handshake_listener_t is None:
+            ready_event = threading.Event()
+            self._nixl_handshake_listener_t = threading.Thread(
+                target=self._nixl_handshake_listener,
+                args=(
+                    encoded_data,
+                    ready_event,
+                    self._stop_event,
+                    self.side_channel_port,
+                ),
+                daemon=True,
+                name="nixl_handshake_listener",
+            )
+            self._nixl_handshake_listener_t.start()
+            ready_event.wait()  # Wait for listener ZMQ socket to be ready.
+
+    @staticmethod
+    def _nixl_handshake_listener(
+        encoded_data: dict[int, Any],
+        ready_event: threading.Event,
+        stop_event: threading.Event,
+        port: int,
+    ):
+        """Background thread for getting new NIXL handshakes."""
+        # NOTE(rob): this is a simple implementation. We will move
+        # to a better approach via HTTP endpoint soon.
+
+        # Listen for new requests for metadata.
+        host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
+        path = make_zmq_path("tcp", host, port)
+        logger.debug("Starting listening on path: %s", path)
+        with zmq_ctx(zmq.ROUTER, path) as sock:
+            sock.setsockopt(zmq.RCVTIMEO, 1000)
+            ready_event.set()
+            while True:
+                try:
+                    identity, _, msg = sock.recv_multipart()
+                except zmq.Again:
+                    if stop_event.is_set():
+                        break
+                    continue
+                # Decode the message which contains (GET_META_MSG, rank)
+                msg, target_tp_rank = msgspec.msgpack.decode(msg)
+                logger.debug(
+                    "Received message for tp rank %s",
+                    target_tp_rank,
+                )
+                if msg != GET_META_MSG:
+                    logger.warning("Connection listener got unexpected message %s", msg)
+                sock.send_multipart((identity, b"", encoded_data[target_tp_rank]))
+
     def get_num_new_matched_tokens(
         self, request: "Request", num_computed_tokens: int
     ) -> tuple[int, bool]:
@@ -537,8 +649,6 @@ class NixlConnectorScheduler:
 class NixlConnectorWorker:
     """Implementation of Worker side methods"""

-    _POLL_TIMEOUT = 0.1  # Handshake thread polls for stop event every 100ms
-
     @dataclass
     class TpKVTopology:
         """
@@ -651,16 +761,6 @@ class NixlConnectorWorker:
         # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
         self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)

-        # NIXL handshake port.
-        # NOTE(rob): Within a DP group, each DP rank gets its own
-        # base port (which is sent in the KVTransferParams).
-        # Each TP rank listens/queries on the base_port + tp_rank.
-        self.side_channel_port: int = (
-            envs.VLLM_NIXL_SIDE_CHANNEL_PORT
-            + vllm_config.parallel_config.data_parallel_rank
-            * vllm_config.parallel_config.tensor_parallel_size
-        )
-
         # Metadata.
         self.engine_id: EngineId = engine_id
         self.tp_rank = get_tensor_model_parallel_rank()
@@ -706,6 +806,7 @@ class NixlConnectorWorker:
         # Map of engine_id -> kv_caches_base_addr. For TP case, each local
         # rank will still only pull from a single remote TP worker.
         self.kv_caches_base_addr: dict[EngineId, list[int]] = {}
+        self.device_id: int = 0

         # Number of NIXL regions. Currently one region per cache
         # (so 1 per layer for MLA, otherwise 2 per layer)
@@ -736,9 +837,8 @@ class NixlConnectorWorker:
         # requests that skipped transfer (handshake or transfer failures)
         self._failed_recv_reqs: set[ReqId] = set()

-        # Background thread for handling new handshake requests.
-        self._nixl_handshake_listener_t: threading.Thread | None = None
-        self._nixl_handshake_listener_stop_event: threading.Event | None = None
+        # Handshake metadata of this worker for NIXL transfers.
+        self.xfer_handshake_metadata: NixlAgentMetadata | None = None
         # Background thread for initializing new NIXL handshakes.
         self._handshake_initiation_executor = ThreadPoolExecutor(
             # NIXL is not guaranteed to be thread-safe, limit 1 worker.
@@ -790,42 +890,6 @@ class NixlConnectorWorker:
             total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
         )

-    @staticmethod
-    def _nixl_handshake_listener(
-        metadata: NixlAgentMetadata,
-        ready_event: threading.Event,
-        stop_event: threading.Event,
-        base_port: int,
-        tp_rank: int,
-    ):
-        """Background thread for getting new NIXL handshakes."""
-        # NOTE(rob): this is a simple implementation. We will move
-        # to a better approach via HTTP endpoint soon.
-
-        encoder = msgspec.msgpack.Encoder()
-        encoded_data = encoder.encode(metadata)
-        size_in_bytes = len(encoded_data)
-        logger.debug("Size of encoded NixlAgentMetadata: %s bytes", str(size_in_bytes))
-
-        # Listen for new requests for metadata.
-        host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
-        path = make_zmq_path("tcp", host, base_port + tp_rank)
-        logger.debug("Starting listening on path: %s", path)
-        with zmq_ctx(zmq.ROUTER, path) as sock:
-            ready_event.set()
-            poller = zmq.Poller()
-            poller.register(sock, zmq.POLLIN)
-            while not stop_event.is_set():
-                events = dict(
-                    poller.poll(timeout=NixlConnectorWorker._POLL_TIMEOUT * 1000)
-                )
-                if sock not in events:
-                    continue
-                identity, _, msg = sock.recv_multipart()
-                if msg != GET_META_MSG:
-                    logger.warning("Connection listener got unexpected message %s", msg)
-                sock.send_multipart((identity, b"", encoded_data))
-
     def _nixl_handshake(
         self,
         host: str,
@@ -844,16 +908,17 @@ class NixlConnectorWorker:
         # Handshake only with the remote TP rank that current local rank will
         # pull from. With homogeneous TP it happens to be the same rank_i.
         p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size)
-        path = make_zmq_path("tcp", host, port + p_remote_rank)
+        path = make_zmq_path("tcp", host, port)
         logger.debug(
-            "Querying metadata on path: %s at remote rank %s", path, p_remote_rank
+            "Querying metadata on path: %s at remote tp rank %s", path, p_remote_rank
         )

         # Send query for the request.
         with zmq_ctx(zmq.REQ, path) as sock:
+            msg = msgspec.msgpack.encode((GET_META_MSG, p_remote_rank))
             # Set receive timeout to 5 seconds to avoid hanging on dead server
             sock.setsockopt(zmq.RCVTIMEO, 5000)  # milliseconds
-            sock.send(GET_META_MSG)
+            sock.send(msg)
             metadata_bytes = sock.recv()
             decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
             metadata = decoder.decode(metadata_bytes)
@@ -1042,6 +1107,10 @@ class NixlConnectorWorker:
             assert tensor_size_bytes == curr_tensor_size_bytes, (
                 "All kv cache tensors must have the same size"
             )
+            # Need to make sure the device ID is non-negative for NIXL,
+            # Torch uses -1 to indicate CPU tensors while NIXL uses explicit
+            # memory type.
+            self.device_id = max(cache.get_device(), 0)
             caches_data.append(
                 (base_addr, curr_tensor_size_bytes, self.device_id, "")
             )
@@ -1139,10 +1208,11 @@ class NixlConnectorWorker:
         assert len(self.block_window_per_layer) == self.num_layers

         # After KV Caches registered, listen for new connections.
-        metadata = NixlAgentMetadata(
+        self.xfer_handshake_metadata = NixlAgentMetadata(
             engine_id=self.engine_id,
             agent_metadata=self.nixl_wrapper.get_agent_metadata(),
             kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
+            device_id=self.device_id,
             num_blocks=self.num_blocks,
             block_lens=self.block_len_per_layer,
             attn_backend_name=self.backend_name,
@@ -1150,22 +1220,6 @@ class NixlConnectorWorker:
             if not self.use_host_buffer
             else self.host_buffer_kv_cache_layout,
         )
-        ready_event, stop_event = threading.Event(), threading.Event()
-        self._nixl_handshake_listener_t = threading.Thread(
-            target=self._nixl_handshake_listener,
-            args=(
-                metadata,
-                ready_event,
-                stop_event,
-                self.side_channel_port,
-                self.tp_rank,
-            ),
-            daemon=True,
-            name="nixl_handshake_listener",
-        )
-        self._nixl_handshake_listener_t.start()
-        self._nixl_handshake_listener_stop_event = stop_event
-        ready_event.wait()  # Wait for listener ZMQ socket to be ready.

     def add_remote_agent(
         self,
@@ -1267,7 +1321,7 @@ class NixlConnectorWorker:
                 # self.block_len == remote_block_len//tp_ratio bytes.
                 addr = base_addr + block_offset + rank_offset
                 # (addr, len, device id)
-                blocks_data.append((addr, kv_block_len, remote_tp_rank))
+                blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id))

             if self._use_flashinfer:
                 # With FlashInfer index V separately to allow head splitting.
@@ -1275,7 +1329,9 @@ class NixlConnectorWorker:
                     block_offset = block_id * nixl_agent_meta.block_lens[i]
                     addr = base_addr + block_offset + rank_offset
                     v_addr = addr + nixl_agent_meta.block_lens[i] // 2
-                    blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
+                    blocks_data.append(
+                        (v_addr, kv_block_len, nixl_agent_meta.device_id)
+                    )

         logger.debug(
             "Created %s blocks for dst engine %s with remote rank %s and local rank %s",
@@ -1843,14 +1899,6 @@ class NixlConnectorWorker:
     def shutdown(self):
        """Shutdown the connector worker."""
        self._handshake_initiation_executor.shutdown(wait=False)
-        if self._nixl_handshake_listener_stop_event is not None:
-            self._nixl_handshake_listener_stop_event.set()
-            self._nixl_handshake_listener_stop_event = None
-        if self._nixl_handshake_listener_t is not None:
-            # Generous timeout to allow the thread to exit
-            self._nixl_handshake_listener_t.join(timeout=self._POLL_TIMEOUT * 10)
-            assert not self._nixl_handshake_listener_t.is_alive()
-            self._nixl_handshake_listener_t = None
        for handles in self._recving_transfers.values():
            for handle, _ in handles:
                self.nixl_wrapper.release_xfer_handle(handle)
@@ -163,6 +163,27 @@ class EngineCore:
             vllm_config, mm_registry
         )

+        # If a KV connector is initialized for the scheduler, collect
+        # handshake metadata from all workers so the connector in the
+        # scheduler has the full context.
+        kv_connector = self.scheduler.get_kv_connector()
+        if kv_connector is not None:
+            # Collect and store KV connector xfer metadata from workers
+            # (after KV cache registration)
+            xfer_handshake_metadata = (
+                self.model_executor.get_kv_connector_handshake_metadata()
+            )
+
+            if xfer_handshake_metadata:
+                # xfer_handshake_metadata is a list of dicts from workers.
+                # Each dict already has the structure {tp_rank: metadata}.
+                # Merge all worker dicts into a single dict.
+                content: dict[int, Any] = {}
+                for worker_dict in xfer_handshake_metadata:
+                    if worker_dict is not None:
+                        content.update(worker_dict)
+                kv_connector.set_xfer_handshake_metadata(content)
+
         # Setup batch queue for pipeline parallelism.
         # Batch queue for scheduled batches. This enables us to asynchronously
         # schedule and execute batches, and is required by pipeline parallelism

@@ -178,7 +199,7 @@ class EngineCore:
         self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
         if (
             self.vllm_config.cache_config.enable_prefix_caching
-            or self.scheduler.get_kv_connector() is not None
+            or kv_connector is not None
         ):
             caching_hash_fn = get_hash_fn_by_name(
                 vllm_config.cache_config.prefix_caching_hash_algo
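Aside (not part of the commit; values are illustrative): the executor RPC above returns one dict per worker, each keyed by that worker's tp_rank, and EngineCore merges them into a single mapping that the scheduler-side connector can serve during the NIXL handshake. The shape of that merge, with plain strings standing in for NixlAgentMetadata:

```python
# Illustrative only: per-worker results as collective_rpc would return them,
# then merged into one {tp_rank: metadata} mapping as EngineCore does above.
worker_results = [
    {0: "metadata-for-tp-rank-0"},  # from worker with tp_rank 0
    {1: "metadata-for-tp-rank-1"},  # from worker with tp_rank 1
    None,                           # a worker without a KV transfer group
]

merged: dict[int, str] = {}
for worker_dict in worker_results:
    if worker_dict is not None:
        merged.update(worker_dict)

assert merged == {0: "metadata-for-tp-rank-0", 1: "metadata-for-tp-rank-1"}
```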
@@ -9,6 +9,9 @@ from typing import TYPE_CHECKING, Literal, TypeVar, overload

 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
+from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+    KVConnectorHandshakeMetadata,
+)
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.tasks import SupportedTask

@@ -177,6 +180,11 @@ class Executor(ABC):
     ):
         raise NotImplementedError

+    def get_kv_connector_handshake_metadata(
+        self,
+    ) -> list[dict[int, KVConnectorHandshakeMetadata]]:
+        return self.collective_rpc("get_kv_connector_handshake_metadata")
+
     @overload
     def execute_model(
         self,
@@ -19,7 +19,11 @@ from vllm.distributed import (
     init_distributed_environment,
     set_custom_all_reduce,
 )
-from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
+from vllm.distributed.kv_transfer import (
+    ensure_kv_transfer_initialized,
+    get_kv_transfer_group,
+    has_kv_transfer_group,
+)
 from vllm.distributed.parallel_state import (
     get_pp_group,
     get_tp_group,

@@ -348,6 +352,21 @@ class Worker(WorkerBase):

         return int(self.available_kv_cache_memory_bytes)

+    def get_kv_connector_handshake_metadata(self) -> dict | None:
+        """Get KV connector metadata from this worker if available."""
+
+        if not has_kv_transfer_group():
+            return None
+
+        connector = get_kv_transfer_group()
+        # Return None for connectors that don't need to exchange handshake
+        # metadata across workers.
+        if (metadata := connector.get_handshake_metadata()) is None:
+            return None
+
+        tp_rank = get_tp_group().rank_in_group
+        return {tp_rank: metadata}
+
     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         return self.model_runner.get_kv_cache_spec()