mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 07:24:57 +08:00
[P/D] NixlConnector DP fixes (#18903)
Signed-off-by: Will Eaton <weaton@redhat.com>
This commit is contained in:
parent
d1d61f3351
commit
64eaf5fe05
@ -70,7 +70,8 @@ class KVConnectorFactory:
|
|||||||
connector_module = importlib.import_module(connector_module_path)
|
connector_module = importlib.import_module(connector_module_path)
|
||||||
connector_cls = getattr(connector_module, connector_name)
|
connector_cls = getattr(connector_module, connector_name)
|
||||||
assert issubclass(connector_cls, KVConnectorBase_V1)
|
assert issubclass(connector_cls, KVConnectorBase_V1)
|
||||||
logger.info("Creating v1 connector with name: %s", connector_name)
|
logger.info("Creating v1 connector with name: %s and engine_id: %s",
|
||||||
|
connector_name, kv_transfer_config.engine_id)
|
||||||
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
|
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
|
||||||
# Scheduler connector:
|
# Scheduler connector:
|
||||||
# - Co-locate with scheduler process
|
# - Co-locate with scheduler process
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
|||||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
|
||||||
get_tp_group)
|
get_tp_group, get_world_group)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
|
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
@ -334,6 +334,7 @@ class NixlConnectorWorker:
|
|||||||
self.engine_id = engine_id
|
self.engine_id = engine_id
|
||||||
self.rank = get_tensor_model_parallel_rank()
|
self.rank = get_tensor_model_parallel_rank()
|
||||||
self.world_size = get_tensor_model_parallel_world_size()
|
self.world_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.world_rank = get_world_group().rank_in_group
|
||||||
self.tp_group = get_tp_group()
|
self.tp_group = get_tp_group()
|
||||||
|
|
||||||
# KV Caches and nixl tracking data.
|
# KV Caches and nixl tracking data.
|
||||||
@ -382,7 +383,8 @@ class NixlConnectorWorker:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
|
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
|
||||||
ready_event: threading.Event, rank: int):
|
ready_event: threading.Event,
|
||||||
|
world_rank: int):
|
||||||
"""Background thread for getting new NIXL handshakes."""
|
"""Background thread for getting new NIXL handshakes."""
|
||||||
# NOTE(rob): this is a simple implementation. We will move
|
# NOTE(rob): this is a simple implementation. We will move
|
||||||
# to a better approach like an ETCD server in the future.
|
# to a better approach like an ETCD server in the future.
|
||||||
@ -403,7 +405,7 @@ class NixlConnectorWorker:
|
|||||||
# NOTE(rob): we need each rank to have a unique port. This
|
# NOTE(rob): we need each rank to have a unique port. This
|
||||||
# hack keeps us moving. We will switch when moving to etcd
|
# hack keeps us moving. We will switch when moving to etcd
|
||||||
# or where we have a single ZMQ socket in the scheduler.
|
# or where we have a single ZMQ socket in the scheduler.
|
||||||
port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank
|
port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + world_rank
|
||||||
path = make_zmq_path("tcp", host, port)
|
path = make_zmq_path("tcp", host, port)
|
||||||
logger.debug("Starting listening on path: %s", path)
|
logger.debug("Starting listening on path: %s", path)
|
||||||
with zmq_ctx(zmq.ROUTER, path) as sock:
|
with zmq_ctx(zmq.ROUTER, path) as sock:
|
||||||
@ -422,7 +424,7 @@ class NixlConnectorWorker:
|
|||||||
# NOTE(rob): we need each rank to have a unique port. This is
|
# NOTE(rob): we need each rank to have a unique port. This is
|
||||||
# a hack to keep us moving. We will switch when moving to etcd
|
# a hack to keep us moving. We will switch when moving to etcd
|
||||||
# or where we have a single ZMQ socket in the scheduler.
|
# or where we have a single ZMQ socket in the scheduler.
|
||||||
path = make_zmq_path("tcp", host, port + self.rank)
|
path = make_zmq_path("tcp", host, port + self.world_rank)
|
||||||
logger.debug("Querying metadata on path: %s", path)
|
logger.debug("Querying metadata on path: %s", path)
|
||||||
with zmq_ctx(zmq.REQ, path) as sock:
|
with zmq_ctx(zmq.REQ, path) as sock:
|
||||||
# Send query for the request.
|
# Send query for the request.
|
||||||
@ -529,7 +531,7 @@ class NixlConnectorWorker:
|
|||||||
ready_event = threading.Event()
|
ready_event = threading.Event()
|
||||||
self._nixl_handshake_listener_t = threading.Thread(
|
self._nixl_handshake_listener_t = threading.Thread(
|
||||||
target=self._nixl_handshake_listener,
|
target=self._nixl_handshake_listener,
|
||||||
args=(metadata, ready_event, self.rank),
|
args=(metadata, ready_event, self.world_rank),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
name="nixl_handshake_listener")
|
name="nixl_handshake_listener")
|
||||||
self._nixl_handshake_listener_t.start()
|
self._nixl_handshake_listener_t.start()
|
||||||
|
|||||||
@ -707,6 +707,15 @@ class DPEngineCoreProc(EngineCoreProc):
|
|||||||
assert dp_size > 1
|
assert dp_size > 1
|
||||||
assert 0 <= local_dp_rank <= dp_rank < dp_size
|
assert 0 <= local_dp_rank <= dp_rank < dp_size
|
||||||
|
|
||||||
|
if vllm_config.kv_transfer_config is not None:
|
||||||
|
# modify the engine_id and append the local_dp_rank to it to ensure
|
||||||
|
# that the kv_transfer_config is unique for each DP rank.
|
||||||
|
vllm_config.kv_transfer_config.engine_id = (
|
||||||
|
f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
|
||||||
|
)
|
||||||
|
logger.debug("Setting kv_transfer_config.engine_id to %s",
|
||||||
|
vllm_config.kv_transfer_config.engine_id)
|
||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
device_control_env_var = current_platform.device_control_env_var
|
device_control_env_var = current_platform.device_control_env_var
|
||||||
world_size = vllm_config.parallel_config.world_size
|
world_size = vllm_config.parallel_config.world_size
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user