mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-19 16:27:02 +08:00
[Bugfix][Nixl] Fix DP Metadata Handshake (#19008)
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
This commit is contained in:
parent
d6fd3a33b8
commit
b9f61e1387
@ -19,7 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
|||||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
|
||||||
get_tp_group, get_world_group)
|
get_tp_group)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
|
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
@ -172,6 +172,11 @@ class NixlConnectorScheduler:
|
|||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
self.block_size = vllm_config.cache_config.block_size
|
self.block_size = vllm_config.cache_config.block_size
|
||||||
self.engine_id = engine_id
|
self.engine_id = engine_id
|
||||||
|
self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
|
||||||
|
self.side_channel_port = (
|
||||||
|
envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
|
||||||
|
vllm_config.parallel_config.data_parallel_rank_local *
|
||||||
|
vllm_config.parallel_config.tensor_parallel_size)
|
||||||
logger.info("Initializing NIXL Scheduler %s", engine_id)
|
logger.info("Initializing NIXL Scheduler %s", engine_id)
|
||||||
|
|
||||||
# Requests that need to start recv.
|
# Requests that need to start recv.
|
||||||
@ -310,8 +315,8 @@ class NixlConnectorScheduler:
|
|||||||
do_remote_decode=False,
|
do_remote_decode=False,
|
||||||
remote_block_ids=computed_block_ids,
|
remote_block_ids=computed_block_ids,
|
||||||
remote_engine_id=self.engine_id,
|
remote_engine_id=self.engine_id,
|
||||||
remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST,
|
remote_host=self.side_channel_host,
|
||||||
remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT,
|
remote_port=self.side_channel_port,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -330,11 +335,19 @@ class NixlConnectorWorker:
|
|||||||
# Map of engine_id -> agent_name.
|
# Map of engine_id -> agent_name.
|
||||||
self._remote_agents: dict[str, str] = {}
|
self._remote_agents: dict[str, str] = {}
|
||||||
|
|
||||||
|
# NIXL handshake port.
|
||||||
|
# NOTE(rob): Within a DP group, each DP rank gets its own
|
||||||
|
# base port (which is sent in the KVTransferParams).
|
||||||
|
# Each TP rank listens/queries on the base_port + tp_rank.
|
||||||
|
self.side_channel_port = (
|
||||||
|
envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
|
||||||
|
vllm_config.parallel_config.data_parallel_rank_local *
|
||||||
|
vllm_config.parallel_config.tensor_parallel_size)
|
||||||
|
|
||||||
# Metadata.
|
# Metadata.
|
||||||
self.engine_id = engine_id
|
self.engine_id = engine_id
|
||||||
self.rank = get_tensor_model_parallel_rank()
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
self.world_size = get_tensor_model_parallel_world_size()
|
self.world_size = get_tensor_model_parallel_world_size()
|
||||||
self.world_rank = get_world_group().rank_in_group
|
|
||||||
self.tp_group = get_tp_group()
|
self.tp_group = get_tp_group()
|
||||||
|
|
||||||
# KV Caches and nixl tracking data.
|
# KV Caches and nixl tracking data.
|
||||||
@ -383,16 +396,11 @@ class NixlConnectorWorker:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
|
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
|
||||||
ready_event: threading.Event,
|
ready_event: threading.Event, base_port: int,
|
||||||
world_rank: int):
|
tp_rank: int):
|
||||||
"""Background thread for getting new NIXL handshakes."""
|
"""Background thread for getting new NIXL handshakes."""
|
||||||
# NOTE(rob): this is a simple implementation. We will move
|
# NOTE(rob): this is a simple implementation. We will move
|
||||||
# to a better approach like an ETCD server in the future.
|
# to a better approach via an HTTP endpoint soon.
|
||||||
|
|
||||||
# NOTE(rob): to support heterogeneous TP, we will have to
|
|
||||||
# move this into the scheduler rather than worker, since
|
|
||||||
# each rank needs the metadata of all other ranks (whereas
|
|
||||||
# in this setup, each rank only gets one other rank's meta).
|
|
||||||
|
|
||||||
encoder = msgspec.msgpack.Encoder()
|
encoder = msgspec.msgpack.Encoder()
|
||||||
encoded_data = encoder.encode(metadata)
|
encoded_data = encoder.encode(metadata)
|
||||||
@ -402,11 +410,7 @@ class NixlConnectorWorker:
|
|||||||
|
|
||||||
# Listen for new requests for metadata.
|
# Listen for new requests for metadata.
|
||||||
host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
|
host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
|
||||||
# NOTE(rob): we need each rank to have a unique port. This
|
path = make_zmq_path("tcp", host, base_port + tp_rank)
|
||||||
# hack keeps us moving. We will switch when moving to etcd
|
|
||||||
# or where we have a single ZMQ socket in the scheduler.
|
|
||||||
port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + world_rank
|
|
||||||
path = make_zmq_path("tcp", host, port)
|
|
||||||
logger.debug("Starting listening on path: %s", path)
|
logger.debug("Starting listening on path: %s", path)
|
||||||
with zmq_ctx(zmq.ROUTER, path) as sock:
|
with zmq_ctx(zmq.ROUTER, path) as sock:
|
||||||
ready_event.set()
|
ready_event.set()
|
||||||
@ -421,10 +425,10 @@ class NixlConnectorWorker:
|
|||||||
"""Do a NIXL handshake with a remote instance."""
|
"""Do a NIXL handshake with a remote instance."""
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
# NOTE(rob): we need each rank to have a unique port. This is
|
# NOTE(rob): we need each tp_rank to have a unique port.
|
||||||
# a hack to keep us moving. We will switch when moving to etcd
|
# This is a hack to keep us moving. We will switch when
|
||||||
# or where we have a single ZMQ socket in the scheduler.
|
# we move to HTTP-based NIXL metadata exchange.
|
||||||
path = make_zmq_path("tcp", host, port + self.world_rank)
|
path = make_zmq_path("tcp", host, port + self.tp_rank)
|
||||||
logger.debug("Querying metadata on path: %s", path)
|
logger.debug("Querying metadata on path: %s", path)
|
||||||
with zmq_ctx(zmq.REQ, path) as sock:
|
with zmq_ctx(zmq.REQ, path) as sock:
|
||||||
# Send query for the request.
|
# Send query for the request.
|
||||||
@ -532,7 +536,7 @@ class NixlConnectorWorker:
|
|||||||
ready_event = threading.Event()
|
ready_event = threading.Event()
|
||||||
self._nixl_handshake_listener_t = threading.Thread(
|
self._nixl_handshake_listener_t = threading.Thread(
|
||||||
target=self._nixl_handshake_listener,
|
target=self._nixl_handshake_listener,
|
||||||
args=(metadata, ready_event, self.world_rank),
|
args=(metadata, ready_event, self.side_channel_port, self.tp_rank),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
name="nixl_handshake_listener")
|
name="nixl_handshake_listener")
|
||||||
self._nixl_handshake_listener_t.start()
|
self._nixl_handshake_listener_t.start()
|
||||||
@ -556,9 +560,9 @@ class NixlConnectorWorker:
|
|||||||
block_offset = block_id * self.block_len
|
block_offset = block_id * self.block_len
|
||||||
# (addr, len, device id)
|
# (addr, len, device id)
|
||||||
blocks_data.append(
|
blocks_data.append(
|
||||||
(base_addr + block_offset, self.block_len, self.rank))
|
(base_addr + block_offset, self.block_len, self.tp_rank))
|
||||||
logger.debug("Created %s blocks for src engine %s and rank %s",
|
logger.debug("Created %s blocks for src engine %s and tp_rank %s",
|
||||||
len(blocks_data), self.engine_id, self.rank)
|
len(blocks_data), self.engine_id, self.tp_rank)
|
||||||
|
|
||||||
# Register with NIXL.
|
# Register with NIXL.
|
||||||
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
|
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
|
||||||
@ -573,9 +577,9 @@ class NixlConnectorWorker:
|
|||||||
block_offset = block_id * self.block_len
|
block_offset = block_id * self.block_len
|
||||||
# (addr, len, device id)
|
# (addr, len, device id)
|
||||||
blocks_data.append(
|
blocks_data.append(
|
||||||
(base_addr + block_offset, self.block_len, self.rank))
|
(base_addr + block_offset, self.block_len, self.tp_rank))
|
||||||
logger.debug("Created %s blocks for dst engine %s and rank %s",
|
logger.debug("Created %s blocks for dst engine %s and tp_rank %s",
|
||||||
len(blocks_data), engine_id, self.rank)
|
len(blocks_data), engine_id, self.tp_rank)
|
||||||
|
|
||||||
# Register with NIXL.
|
# Register with NIXL.
|
||||||
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
|
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
|
||||||
@ -600,14 +604,14 @@ class NixlConnectorWorker:
|
|||||||
if len(done_sending) > 0 or len(done_recving) > 0:
|
if len(done_sending) > 0 or len(done_recving) > 0:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Rank %s, get_finished: %s requests done sending "
|
"Rank %s, get_finished: %s requests done sending "
|
||||||
"and %s requests done recving", self.rank, len(done_sending),
|
"and %s requests done recving", self.tp_rank,
|
||||||
len(done_recving))
|
len(done_sending), len(done_recving))
|
||||||
|
|
||||||
if self.world_size == 1:
|
if self.world_size == 1:
|
||||||
return done_sending, done_recving
|
return done_sending, done_recving
|
||||||
|
|
||||||
# Rank 0: get finished from all other ranks.
|
# Rank 0: get finished from all other ranks.
|
||||||
if self.rank == 0:
|
if self.tp_rank == 0:
|
||||||
for req_id in done_sending:
|
for req_id in done_sending:
|
||||||
self._done_sending_count[req_id] += 1
|
self._done_sending_count[req_id] += 1
|
||||||
for req_id in done_recving:
|
for req_id in done_recving:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user