[Bugfix][Nixl] Fix DP Metadata Handshake (#19008)
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
parent d6fd3a33b8
commit b9f61e1387
@@ -19,7 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
-    get_tp_group, get_world_group)
+    get_tp_group)
 from vllm.logger import init_logger
 from vllm.utils import make_zmq_path, make_zmq_socket, round_down
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -172,6 +172,11 @@ class NixlConnectorScheduler:
         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
         self.engine_id = engine_id
+        self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
+        self.side_channel_port = (
+            envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
+            vllm_config.parallel_config.data_parallel_rank_local *
+            vllm_config.parallel_config.tensor_parallel_size)
         logger.info("Initializing NIXL Scheduler %s", engine_id)

         # Requests that need to start recv.
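For context, the scheduler-side change above gives every DP rank its own side-channel base port. A minimal standalone sketch of that arithmetic, using made-up values in place of envs.VLLM_NIXL_SIDE_CHANNEL_PORT and the parallel config:

# Sketch of the DP-aware base-port computation introduced above.
# BASE_PORT, dp_rank_local and tp_size are illustrative assumptions.
BASE_PORT = 5557          # assumed side-channel base port
dp_rank_local = 1         # this engine's local data-parallel rank
tp_size = 4               # tensor-parallel degree

# Each DP rank owns a contiguous block of tp_size ports.
side_channel_port = BASE_PORT + dp_rank_local * tp_size
print(side_channel_port)  # 5561 in this example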
@@ -310,8 +315,8 @@ class NixlConnectorScheduler:
             do_remote_decode=False,
             remote_block_ids=computed_block_ids,
             remote_engine_id=self.engine_id,
-            remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST,
-            remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT,
+            remote_host=self.side_channel_host,
+            remote_port=self.side_channel_port,
         )

@@ -330,11 +335,19 @@ class NixlConnectorWorker:
         # Map of engine_id -> agent_name.
         self._remote_agents: dict[str, str] = {}

+        # NIXL handshake port.
+        # NOTE(rob): Within a DP group, each DP rank gets its own
+        # base port (which is sent in the KVTransferParams).
+        # Each TP rank listens/queries on the base_port + tp_rank.
+        self.side_channel_port = (
+            envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
+            vllm_config.parallel_config.data_parallel_rank_local *
+            vllm_config.parallel_config.tensor_parallel_size)
+
         # Metadata.
         self.engine_id = engine_id
-        self.rank = get_tensor_model_parallel_rank()
+        self.tp_rank = get_tensor_model_parallel_rank()
         self.world_size = get_tensor_model_parallel_world_size()
-        self.world_rank = get_world_group().rank_in_group
         self.tp_group = get_tp_group()

         # KV Caches and nixl tracking data.
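The NOTE in the hunk above describes the resulting port layout: each DP rank gets a base port, and each TP rank within it listens and queries on base_port + tp_rank. A small sketch enumerating that mapping for an assumed 2-way DP x 2-way TP deployment (port and sizes are illustrative, not vllm defaults):

# Illustrative enumeration of the handshake-port layout described above.
BASE_PORT = 5557
DP_SIZE = 2
TP_SIZE = 2

for dp_rank in range(DP_SIZE):
    base_port = BASE_PORT + dp_rank * TP_SIZE   # per-DP-rank base port
    for tp_rank in range(TP_SIZE):
        # Each TP rank listens (and is queried) on base_port + tp_rank.
        print(f"dp_rank={dp_rank} tp_rank={tp_rank} port={base_port + tp_rank}")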
@@ -383,16 +396,11 @@ class NixlConnectorWorker:

     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
-                                 ready_event: threading.Event,
-                                 world_rank: int):
+                                 ready_event: threading.Event, base_port: int,
+                                 tp_rank: int):
         """Background thread for getting new NIXL handshakes."""
         # NOTE(rob): this is a simple implementation. We will move
-        # to a better approach like an ETCD server in the future.
-
-        # NOTE(rob): to support heterogeneous TP, we will have to
-        # move this into the scheduler rather than worker, since
-        # each rank needs the metadata of all other ranks (whereas
-        # in this setup, each rank only gets one other rank's meta.
+        # to a better approach via HTTP endpoint soon.

         encoder = msgspec.msgpack.Encoder()
         encoded_data = encoder.encode(metadata)
@@ -402,11 +410,7 @@ class NixlConnectorWorker:

         # Listen for new requests for metadata.
         host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
-        # NOTE(rob): we need each rank to have a unique port. This
-        # hack to keeps us moving. We will switch when moving to etcd
-        # or where we have a single ZMQ socket in the scheduler.
-        port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + world_rank
-        path = make_zmq_path("tcp", host, port)
+        path = make_zmq_path("tcp", host, base_port + tp_rank)
         logger.debug("Starting listening on path: %s", path)
         with zmq_ctx(zmq.ROUTER, path) as sock:
             ready_event.set()
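The listener change above binds one ZMQ ROUTER socket per TP rank at base_port + tp_rank and serves the msgpack-encoded agent metadata. A self-contained sketch of that request/reply pattern using plain pyzmq and msgspec; vllm's zmq_ctx/make_zmq_path helpers and NixlAgentMetadata are replaced by simple stand-ins here, so this is illustrative only:

# Hedged sketch of the side-channel listener pattern: a ROUTER socket that
# replies to every query with pre-encoded metadata. Stand-ins, not vllm code.
import threading

import msgspec
import zmq


def handshake_listener(metadata: dict, ready_event: threading.Event,
                       host: str, base_port: int, tp_rank: int) -> None:
    encoded = msgspec.msgpack.Encoder().encode(metadata)
    ctx = zmq.Context()
    sock = ctx.socket(zmq.ROUTER)
    sock.bind(f"tcp://{host}:{base_port + tp_rank}")
    ready_event.set()  # signal the caller that the port is bound
    while True:
        # A REQ peer arrives as [identity, empty delimiter, payload].
        identity, _, _ = sock.recv_multipart()
        # Reply must echo identity + delimiter for the REQ socket to accept it.
        sock.send_multipart((identity, b"", encoded))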
@@ -421,10 +425,10 @@ class NixlConnectorWorker:
         """Do a NIXL handshake with a remote instance."""

         start_time = time.perf_counter()
-        # NOTE(rob): we need each rank to have a unique port. This is
-        # a hack to keep us moving. We will switch when moving to etcd
-        # or where we have a single ZMQ socket in the scheduler.
-        path = make_zmq_path("tcp", host, port + self.world_rank)
+        # NOTE(rob): we need each tp_rank to have a unique port.
+        # This is a hack to keep us moving. We will switch when
+        # we switch to HTTP-based NIXL metadata exchange.
+        path = make_zmq_path("tcp", host, port + self.tp_rank)
         logger.debug("Querying metadata on path: %s", path)
         with zmq_ctx(zmq.REQ, path) as sock:
             # Send query for the request.
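The querying side mirrors the listener: a REQ socket connects to the remote base port plus this worker's tp_rank and decodes the msgpack reply. A matching sketch, again with plain pyzmq/msgspec stand-ins; the query payload below is an assumption for illustration, not necessarily the constant vllm sends:

# Hedged sketch of the handshake query side.
import msgspec
import zmq

GET_META_MSG = b"get_meta_msg"  # assumed query payload


def query_remote_metadata(host: str, port: int, tp_rank: int) -> dict:
    ctx = zmq.Context()
    sock = ctx.socket(zmq.REQ)
    sock.connect(f"tcp://{host}:{port + tp_rank}")
    sock.send(GET_META_MSG)      # ask the remote TP rank for its metadata
    reply = sock.recv()          # msgpack-encoded metadata bytes
    return msgspec.msgpack.Decoder().decode(reply)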
@@ -532,7 +536,7 @@ class NixlConnectorWorker:
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
             target=self._nixl_handshake_listener,
-            args=(metadata, ready_event, self.world_rank),
+            args=(metadata, ready_event, self.side_channel_port, self.tp_rank),
             daemon=True,
             name="nixl_handshake_listener")
         self._nixl_handshake_listener_t.start()
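The thread launch above now forwards the DP-adjusted side_channel_port and the tp_rank to the listener. A minimal sketch of that launch pattern, reusing the handshake_listener stand-in from the earlier sketch; the payload, host, port, and the final wait on ready_event are assumptions for illustration:

# Start the listener as a daemon thread and wait until its socket is bound.
import threading

metadata = {"engine_id": "engine-0"}   # illustrative payload
side_channel_port = 5559               # assumed base port for this DP rank
tp_rank = 0
ready_event = threading.Event()
listener_t = threading.Thread(
    target=handshake_listener,
    args=(metadata, ready_event, "127.0.0.1", side_channel_port, tp_rank),
    daemon=True,
    name="nixl_handshake_listener")
listener_t.start()
ready_event.wait()  # assumed: block until the listener has bound its port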
@@ -556,9 +560,9 @@ class NixlConnectorWorker:
             block_offset = block_id * self.block_len
             # (addr, len, device id)
             blocks_data.append(
-                (base_addr + block_offset, self.block_len, self.rank))
-        logger.debug("Created %s blocks for src engine %s and rank %s",
-                     len(blocks_data), self.engine_id, self.rank)
+                (base_addr + block_offset, self.block_len, self.tp_rank))
+        logger.debug("Created %s blocks for src engine %s and tp_rank %s",
+                     len(blocks_data), self.engine_id, self.tp_rank)

         # Register with NIXL.
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
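After this change the device id recorded in each block descriptor is the TP rank. A tiny sketch of how the (address, length, device id) tuples above are laid out; the base address and block length are assumed values, not taken from vllm:

# Illustrative construction of the per-block transfer descriptors.
base_addr = 0x7F0000000000   # assumed start of a registered KV-cache region
block_len = 2 * 1024 * 1024  # assumed bytes per KV-cache block
tp_rank = 0                  # device id is the TP rank after this change

blocks_data = []
for block_id in range(4):
    block_offset = block_id * block_len
    blocks_data.append((base_addr + block_offset, block_len, tp_rank))
# blocks_data is then handed to nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")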
@@ -573,9 +577,9 @@ class NixlConnectorWorker:
             block_offset = block_id * self.block_len
             # (addr, len, device id)
             blocks_data.append(
-                (base_addr + block_offset, self.block_len, self.rank))
-        logger.debug("Created %s blocks for dst engine %s and rank %s",
-                     len(blocks_data), engine_id, self.rank)
+                (base_addr + block_offset, self.block_len, self.tp_rank))
+        logger.debug("Created %s blocks for dst engine %s and tp_rank %s",
+                     len(blocks_data), engine_id, self.tp_rank)

         # Register with NIXL.
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
@@ -600,14 +604,14 @@ class NixlConnectorWorker:
         if len(done_sending) > 0 or len(done_recving) > 0:
             logger.debug(
                 "Rank %s, get_finished: %s requests done sending "
-                "and %s requests done recving", self.rank, len(done_sending),
-                len(done_recving))
+                "and %s requests done recving", self.tp_rank,
+                len(done_sending), len(done_recving))

         if self.world_size == 1:
             return done_sending, done_recving

         # Rank 0: get finished from all other ranks.
-        if self.rank == 0:
+        if self.tp_rank == 0:
             for req_id in done_sending:
                 self._done_sending_count[req_id] += 1
             for req_id in done_recving:
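The hunk above switches the rank-0 aggregation of finished requests to tp_rank. The counting pattern it relies on (TP rank 0 tallies per-request completions and only reports a request once every rank has finished) can be sketched as follows; the release-when-count-equals-world_size condition is an assumption based on the counters visible here, not code shown in this diff:

# Hedged sketch of the rank-0 aggregation pattern implied by the counters
# above; the names mirror _done_sending_count and world_size from the diff.
from collections import defaultdict

world_size = 2
_done_sending_count: defaultdict[str, int] = defaultdict(int)


def tally(done_sending_from_one_rank: set[str]) -> set[str]:
    """Count per-request completions; return requests finished on all ranks."""
    all_done = set()
    for req_id in done_sending_from_one_rank:
        _done_sending_count[req_id] += 1
        if _done_sending_count[req_id] == world_size:  # assumed threshold
            all_done.add(req_id)
            del _done_sending_count[req_id]
    return all_done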