mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 00:06:06 +08:00
[P/D] NixlConnector DP fixes (#18903)
Signed-off-by: Will Eaton <weaton@redhat.com>
This commit is contained in:
parent
d1d61f3351
commit
64eaf5fe05
@ -70,7 +70,8 @@ class KVConnectorFactory:
|
||||
connector_module = importlib.import_module(connector_module_path)
|
||||
connector_cls = getattr(connector_module, connector_name)
|
||||
assert issubclass(connector_cls, KVConnectorBase_V1)
|
||||
logger.info("Creating v1 connector with name: %s", connector_name)
|
||||
logger.info("Creating v1 connector with name: %s and engine_id: %s",
|
||||
connector_name, kv_transfer_config.engine_id)
|
||||
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
|
||||
# Scheduler connector:
|
||||
# - Co-locate with scheduler process
|
||||
|
||||
@ -19,7 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
|
||||
get_tp_group)
|
||||
get_tp_group, get_world_group)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
@ -334,6 +334,7 @@ class NixlConnectorWorker:
|
||||
self.engine_id = engine_id
|
||||
self.rank = get_tensor_model_parallel_rank()
|
||||
self.world_size = get_tensor_model_parallel_world_size()
|
||||
self.world_rank = get_world_group().rank_in_group
|
||||
self.tp_group = get_tp_group()
|
||||
|
||||
# KV Caches and nixl tracking data.
|
||||
@ -382,7 +383,8 @@ class NixlConnectorWorker:
|
||||
|
||||
@staticmethod
|
||||
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
|
||||
ready_event: threading.Event, rank: int):
|
||||
ready_event: threading.Event,
|
||||
world_rank: int):
|
||||
"""Background thread for getting new NIXL handshakes."""
|
||||
# NOTE(rob): this is a simple implementation. We will move
|
||||
# to a better approach like an ETCD server in the future.
|
||||
@ -403,7 +405,7 @@ class NixlConnectorWorker:
|
||||
# NOTE(rob): we need each rank to have a unique port. This
|
||||
# hack to keeps us moving. We will switch when moving to etcd
|
||||
# or where we have a single ZMQ socket in the scheduler.
|
||||
port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank
|
||||
port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + world_rank
|
||||
path = make_zmq_path("tcp", host, port)
|
||||
logger.debug("Starting listening on path: %s", path)
|
||||
with zmq_ctx(zmq.ROUTER, path) as sock:
|
||||
@ -422,7 +424,7 @@ class NixlConnectorWorker:
|
||||
# NOTE(rob): we need each rank to have a unique port. This is
|
||||
# a hack to keep us moving. We will switch when moving to etcd
|
||||
# or where we have a single ZMQ socket in the scheduler.
|
||||
path = make_zmq_path("tcp", host, port + self.rank)
|
||||
path = make_zmq_path("tcp", host, port + self.world_rank)
|
||||
logger.debug("Querying metadata on path: %s", path)
|
||||
with zmq_ctx(zmq.REQ, path) as sock:
|
||||
# Send query for the request.
|
||||
@ -529,7 +531,7 @@ class NixlConnectorWorker:
|
||||
ready_event = threading.Event()
|
||||
self._nixl_handshake_listener_t = threading.Thread(
|
||||
target=self._nixl_handshake_listener,
|
||||
args=(metadata, ready_event, self.rank),
|
||||
args=(metadata, ready_event, self.world_rank),
|
||||
daemon=True,
|
||||
name="nixl_handshake_listener")
|
||||
self._nixl_handshake_listener_t.start()
|
||||
|
||||
@ -707,6 +707,15 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
assert dp_size > 1
|
||||
assert 0 <= local_dp_rank <= dp_rank < dp_size
|
||||
|
||||
if vllm_config.kv_transfer_config is not None:
|
||||
# modify the engine_id and append the local_dp_rank to it to ensure
|
||||
# that the kv_transfer_config is unique for each DP rank.
|
||||
vllm_config.kv_transfer_config.engine_id = (
|
||||
f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
|
||||
)
|
||||
logger.debug("Setting kv_transfer_config.engine_id to %s",
|
||||
vllm_config.kv_transfer_config.engine_id)
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
device_control_env_var = current_platform.device_control_env_var
|
||||
world_size = vllm_config.parallel_config.world_size
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user