[V1] Support MP Executor for multi node distributed inference (#23691)

Signed-off-by: Lu Fang <fanglu@fb.com>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Signed-off-by: Lucia Fang <fanglu@fb.com>
Signed-off-by: Lucia Fang <116399278+luccafong@users.noreply.github.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Committed by Lucia Fang on 2025-11-16 01:01:21 -08:00 via GitHub
parent a55b64635c
commit b316ac6589
10 changed files with 930 additions and 82 deletions
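
For orientation, a minimal sketch of how the new multi-node knobs fit together (illustrative only: the field names come from the EngineArgs/ParallelConfig changes below, while the model name and addresses are placeholders, and the exact launch flow may differ):

# Node 0 (leader) of a 2-node deployment with TP=4 using the MP executor.
# Node 1 would use the same arguments with node_rank=1.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-125m",           # placeholder model
    tensor_parallel_size=4,              # total TP across both nodes
    distributed_executor_backend="mp",
    nnodes=2,                            # new: number of nodes
    node_rank=0,                         # new: rank of this node
    master_addr="10.0.0.1",              # new: torch.distributed master address (placeholder)
    master_port=29501,                   # new: torch.distributed master port
)
vllm_config = engine_args.create_engine_config()
# With TP=4 split over 2 nodes, each node hosts world_size // nnodes_within_dp = 2 workers.
print(vllm_config.parallel_config.local_world_size)  # 2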


@ -0,0 +1,437 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Integration tests for MultiprocExecutor at the executor level.
This test directly tests the executor without going through the LLM interface,
focusing on executor initialization, RPC calls, and distributed execution.
"""
import multiprocessing
import os
from tests.utils import multi_gpu_test
from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import get_open_port
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
MODEL = "facebook/opt-125m"
def create_vllm_config(
tensor_parallel_size: int = 1,
pipeline_parallel_size: int = 1,
max_model_len: int = 256,
gpu_memory_utilization: float = 0.3,
distributed_executor_backend: str = "mp",
nnodes: int = 1,
node_rank: int = 0,
master_port: int = 0,
) -> VllmConfig:
"""Create a VllmConfig for testing using EngineArgs."""
engine_args = EngineArgs(
model=MODEL,
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
)
vllm_config = engine_args.create_engine_config()
# Override distributed node settings if needed
if nnodes > 1 or node_rank > 0:
vllm_config.parallel_config.nnodes = nnodes
vllm_config.parallel_config.node_rank = node_rank
vllm_config.parallel_config.master_port = master_port
if nnodes > 1:
vllm_config.parallel_config.disable_custom_all_reduce = True
return vllm_config
def create_test_scheduler_output(num_requests: int = 1) -> SchedulerOutput:
"""Create a minimal SchedulerOutput for testing."""
# This is a simplified version - in practice you'd need proper
# SchedulerOutput construction based on the actual vLLM v1 API
return SchedulerOutput(
scheduled_new_reqs=[],
scheduled_resumed_reqs=[],
scheduled_running_reqs=[],
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
)
def test_multiproc_executor_initialization():
"""Test that MultiprocExecutor can be initialized with proper config."""
vllm_config = create_vllm_config(
tensor_parallel_size=1,
pipeline_parallel_size=1,
)
# Create executor - this should initialize workers
executor = MultiprocExecutor(vllm_config=vllm_config)
# Verify executor properties
assert executor.world_size == 1, "World size should be 1 for single GPU"
assert executor.local_world_size == 1, "Local world size should be 1"
assert hasattr(executor, "workers"), "Executor should have workers"
assert len(executor.workers) == 1, "Should have 1 worker for single GPU"
# Clean up
executor.shutdown()
@multi_gpu_test(num_gpus=2)
def test_multiproc_executor_initialization_tensor_parallel():
"""Test MultiprocExecutor initialization with tensor parallelism."""
vllm_config = create_vllm_config(
tensor_parallel_size=2,
pipeline_parallel_size=1,
)
# Create executor
executor = MultiprocExecutor(vllm_config=vllm_config)
# Verify executor properties
assert executor.world_size == 2, "World size should be 2 for TP=2"
assert executor.local_world_size == 2, "Local world size should be 2"
assert len(executor.workers) == 2, "Should have 2 workers for TP=2"
# Verify output rank calculation
output_rank = executor._get_output_rank()
assert output_rank == 0, "Output rank should be 0 for TP=2, PP=1"
# Clean up
executor.shutdown()
@multi_gpu_test(num_gpus=2)
def test_multiproc_executor_collective_rpc():
"""Test collective RPC calls to all workers."""
vllm_config = create_vllm_config(
tensor_parallel_size=2,
pipeline_parallel_size=1,
)
# Create executor
executor = MultiprocExecutor(vllm_config=vllm_config)
try:
# Test check_health RPC - should work without errors
executor.check_health()
# Test that RPC works correctly
# Note: We're just testing that the RPC mechanism works,
# not testing actual model execution here
assert not executor.is_failed, "Executor should not be in failed state"
finally:
# Clean up
executor.shutdown()
def test_multiproc_executor_failure_callback():
"""Test failure callback registration and invocation."""
vllm_config = create_vllm_config(
tensor_parallel_size=1,
pipeline_parallel_size=1,
)
executor = MultiprocExecutor(vllm_config=vllm_config)
try:
# Test callback registration
callback_invoked = []
def test_callback():
callback_invoked.append(True)
# Register callback
executor.register_failure_callback(test_callback)
# Callback should not be invoked yet
assert len(callback_invoked) == 0, "Callback should not be invoked immediately"
# Simulate failure
executor.is_failed = True
# Register another callback - should be invoked immediately
executor.register_failure_callback(test_callback)
assert len(callback_invoked) == 1, (
"Callback should be invoked when executor is failed"
)
finally:
# Clean up
executor.shutdown()
@multi_gpu_test(num_gpus=2)
def test_multiproc_executor_worker_monitor():
"""Test that worker monitor is set up correctly."""
vllm_config = create_vllm_config(
tensor_parallel_size=2,
pipeline_parallel_size=1,
)
executor = MultiprocExecutor(vllm_config=vllm_config)
try:
# Verify all worker processes are alive
for worker in executor.workers:
assert worker.proc.is_alive(), f"Worker rank {worker.rank} should be alive"
# Verify executor is not in failed state
assert not executor.is_failed, "Executor should not be in failed state"
finally:
# Clean up
executor.shutdown()
# After shutdown, workers should be terminated
import time
time.sleep(0.5) # Give processes time to terminate
for worker in executor.workers:
assert not worker.proc.is_alive(), (
f"Worker rank {worker.rank} should terminate after shutdown"
)
@multi_gpu_test(num_gpus=2)
def test_multiproc_executor_get_response_message_queues():
"""Test message queue retrieval for different ranks."""
vllm_config = create_vllm_config(
tensor_parallel_size=2,
pipeline_parallel_size=1,
)
executor = MultiprocExecutor(vllm_config=vllm_config)
try:
# Get all message queues
all_queues = executor.get_response_mqs()
assert len(all_queues) == 2, "Should have 2 message queues for 2 workers"
# Get message queue for specific rank
rank0_queue = executor.get_response_mqs(unique_reply_rank=0)
assert len(rank0_queue) == 1, "Should have 1 message queue for rank 0"
rank1_queue = executor.get_response_mqs(unique_reply_rank=1)
assert len(rank1_queue) == 1, "Should have 1 message queue for rank 1"
finally:
# Clean up
executor.shutdown()
def test_multiproc_executor_shutdown_cleanup():
"""Test that shutdown properly cleans up resources."""
vllm_config = create_vllm_config(
tensor_parallel_size=1,
pipeline_parallel_size=1,
)
executor = MultiprocExecutor(vllm_config=vllm_config)
# Verify executor is set up
assert hasattr(executor, "workers"), "Executor should have workers"
assert len(executor.workers) > 0, "Should have at least one worker"
# Shutdown
executor.shutdown()
# Verify cleanup
import time
time.sleep(0.5) # Give processes time to terminate
for worker in executor.workers:
assert not worker.proc.is_alive(), "Worker processes should be terminated"
# Verify shutdown event is set
assert executor.shutdown_event.is_set(), "Shutdown event should be set"
# Multiple shutdowns should be safe (idempotent)
executor.shutdown()
executor.shutdown()
@multi_gpu_test(num_gpus=4)
def test_multiproc_executor_pipeline_parallel():
"""Test MultiprocExecutor with pipeline parallelism."""
vllm_config = create_vllm_config(
tensor_parallel_size=2,
pipeline_parallel_size=2,
)
executor = MultiprocExecutor(vllm_config=vllm_config)
try:
# Verify executor properties
assert executor.world_size == 4, "World size should be 4 for TP=2, PP=2"
assert len(executor.workers) == 4, "Should have 4 workers"
# Verify output rank calculation
# For TP=2, PP=2: output should be from the last PP stage (ranks 2-3)
# Specifically rank 2 (first rank of last PP stage)
output_rank = executor._get_output_rank()
assert output_rank == 2, "Output rank should be 2 (first rank of last PP stage)"
# Verify max_concurrent_batches for pipeline parallel
assert executor.max_concurrent_batches == 2, (
"Max concurrent batches should equal PP size"
)
finally:
# Clean up
executor.shutdown()
def test_multiproc_executor_properties():
"""Test various executor properties and configurations."""
vllm_config = create_vllm_config(
tensor_parallel_size=1,
pipeline_parallel_size=1,
)
executor = MultiprocExecutor(vllm_config=vllm_config)
try:
# Test supports_pp property
assert MultiprocExecutor.supports_pp is True, (
"MultiprocExecutor should support pipeline parallelism"
)
# Test world_size calculation
assert executor.world_size == (
executor.parallel_config.tensor_parallel_size
* executor.parallel_config.pipeline_parallel_size
), "World size should equal TP * PP"
# Test local_world_size calculation
assert executor.local_world_size == (
executor.parallel_config.world_size // executor.parallel_config.nnodes
), "Local world size should be world_size / nnodes"
finally:
# Clean up
executor.shutdown()
@multi_gpu_test(num_gpus=4)
def test_multiproc_executor_multi_node():
"""
Test MultiprocExecutor with multi-node configuration.
This simulates 2 nodes with TP=4:
- Node 0 (rank 0): Uses GPUs 0,1 (CUDA_VISIBLE_DEVICES=0,1) with TP=2
- Node 1 (rank 1): Uses GPUs 2,3 (CUDA_VISIBLE_DEVICES=2,3) with TP=2
Total world_size = 4, nnodes = 2
"""
port = get_open_port()
# symm_mem does not work when simulating multiple instances on a single node
os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"
def run_node(node_rank: int, result_queue: multiprocessing.Queue, port: int):
"""Run a single node's executor."""
executor = None
try:
# Set CUDA_VISIBLE_DEVICES for this node
if node_rank == 0:
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
else:
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
# Create config for this node
vllm_config = create_vllm_config(
tensor_parallel_size=4, # Total TP across all nodes
pipeline_parallel_size=1,
nnodes=2, # 2 nodes
node_rank=node_rank,
master_port=port, # same port
)
# Create executor for this node
executor = MultiprocExecutor(vllm_config=vllm_config)
# Verify node-specific properties
assert executor.world_size == 4, (
f"World size should be 4 on node {node_rank}"
)
assert executor.local_world_size == 2, (
f"Local world size should be 2 on node {node_rank}"
)
assert len(executor.workers) == 2, (
f"Should have 2 local workers on node {node_rank}"
)
# Verify worker ranks are correct for this node
expected_ranks = [node_rank * 2, node_rank * 2 + 1]
actual_ranks = sorted([w.rank for w in executor.workers])
assert actual_ranks == expected_ranks, (
f"Node {node_rank} should have workers "
f"with ranks {expected_ranks}, got {actual_ranks}"
)
# Verify all workers are alive
for worker in executor.workers:
assert worker.proc.is_alive(), (
f"Worker rank {worker.rank} should be alive on node {node_rank}"
)
# Put success result in queue BEFORE shutdown to avoid hanging
result_queue.put({"node": node_rank, "success": True})
import time
time.sleep(2)
executor.shutdown()
except Exception as e:
# Put failure result in queue
result_queue.put({"node": node_rank, "success": False, "error": str(e)})
raise e
finally:
if executor is not None:
executor.shutdown()
# Create a queue to collect results from both processes
result_queue: multiprocessing.Queue[dict[str, int | bool]] = multiprocessing.Queue()
# Start both node processes
processes = []
for node_rank in range(2):
p = multiprocessing.Process(
target=run_node,
args=(node_rank, result_queue, port),
name=f"Node{node_rank}",
)
p.start()
processes.append(p)
# Wait for both processes to complete
all_completed = True
for p in processes:
p.join(timeout=60)
if p.is_alive():
p.terminate()
p.join(timeout=20)
if p.is_alive():
p.kill()
p.join()
all_completed = False
# Check results from both nodes
results: list[dict[str, int | bool]] = []
while len(results) < 2:
try:
result = result_queue.get(timeout=1)
results.append(result)
except Exception:
pass
assert all_completed, "Not all processes completed successfully"
assert len(results) == 2, f"Expected 2 results, got {len(results)}"
assert results[0]["success"], f"Node 0 failed: {results[0]}"
assert results[1]["success"], f"Node 1 failed: {results[1]}"


@ -210,6 +210,18 @@ class ParallelConfig:
class is dynamically inherited by the worker class. This is used to inject
new attributes and methods to the worker class for use in collective_rpc
calls."""
master_addr: str = "127.0.0.1"
"""distributed master address for multi-node distributed
inference when distributed_executor_backend is mp."""
master_port: int = 29501
"""distributed master port for multi-node distributed
inference when distributed_executor_backend is mp."""
node_rank: int = 0
"""distributed node rank for multi-node distributed
inference when distributed_executor_backend is mp."""
nnodes: int = 1
"""num of nodes for multi-node distributed
inference when distributed_executor_backend is mp."""
world_size: int = Field(init=False)
"""world_size is TPxPP, it affects the number of workers we create."""
@ -387,6 +399,23 @@ class ParallelConfig:
and self.data_parallel_size > 1
)
@property
def node_rank_within_dp(self) -> int:
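"""Rank of this node within the group of nodes serving its data-parallel rank."""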
return self.node_rank % self.nnodes_within_dp
@property
def nnodes_within_dp(self) -> int:
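"""Number of nodes that a single data-parallel rank (one TP x PP group) spans."""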
if self.nnodes == 1:
return 1
data_parallel_node_size = (
self.data_parallel_size // self.data_parallel_size_local
)
return self.nnodes // data_parallel_node_size
@property
def local_world_size(self) -> int:
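"""Number of workers hosted on each node within a data-parallel rank."""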
return self.world_size // self.nnodes_within_dp
@staticmethod
def has_unfinished_dp(dp_group: ProcessGroup, has_unfinished: bool) -> bool:
tensor = torch.tensor([has_unfinished], dtype=torch.int32, device="cpu")
@ -528,6 +557,8 @@ class ParallelConfig:
ray_found = ray_utils.ray_is_available()
if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
backend = "uni"
elif current_platform.is_cuda() and self.nnodes > 1:
backend = "mp"
elif (
current_platform.is_cuda()
and cuda_device_count_stateless() < self.world_size
@ -565,6 +596,10 @@ class ParallelConfig:
"max_parallel_loading_workers is currently "
"not supported and will be ignored."
)
if self.distributed_executor_backend != "mp" and self.nnodes > 1:
raise ValueError(
"nnodes > 1 can only be set when distributed exectuor backend is mp."
)
@property
def use_ray(self) -> bool:
@ -607,6 +642,11 @@ class ParallelConfig:
"Disabled the custom all-reduce kernel because it is not "
"supported on current platform."
)
if self.nnodes > 1:
self.disable_custom_all_reduce = True
logger.debug(
"Disabled the custom all-reduce since we are running on multi-node."
)
if self.ray_workers_use_nsight and not self.use_ray:
raise ValueError(
"Unable to use nsight profiling unless workers run with Ray."


@ -8,7 +8,7 @@ from dataclasses import dataclass, field
from multiprocessing import shared_memory
from pickle import PickleBuffer
from threading import Event
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast
from unittest.mock import patch
import torch
@ -602,13 +602,87 @@ class MessageQueue:
return obj
return self.dequeue()
@staticmethod
def create_from_process_group_single_reader(
pg: ProcessGroup,
max_chunk_bytes,
max_chunks,
reader_rank: int = 0,
blocking: bool = False,
) -> tuple["MessageQueue", list[Handle]]:
"""
Creates a MessageQueue for a process group with a single reader.
This method is designed for scenarios where only one process (the reader)
will consume messages, and all other processes are writers. It sets up
the shared memory buffer and communication handles accordingly, and
gathers the handles from all processes to the reader.
Args:
pg (ProcessGroup): The torch distributed process group.
max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer.
max_chunks (int): Maximum number of chunks in the buffer.
reader_rank (int, optional): The global rank that will act as the reader.
Defaults to 0.
blocking (bool, optional): If True, blocks until all processes are ready.
Defaults to False.
Returns:
tuple[MessageQueue, list[Handle]]:
The MessageQueue instance for the calling process,
and a list of handles (only non-empty for the reader process).
"""
local_size = torch.cuda.device_count()
rank = dist.get_rank()
same_node = rank // local_size == reader_rank // local_size
buffer_io = MessageQueue(
n_reader=1,
n_local_reader=1 if same_node else 0,
max_chunk_bytes=max_chunk_bytes,
max_chunks=max_chunks,
)
handle = buffer_io.export_handle()
handles = [None] * dist.get_world_size(pg) if rank == reader_rank else None
dist.gather_object(handle, handles, dst=reader_rank, group=pg)
if blocking:
buffer_io.wait_until_ready()
return buffer_io, cast(list[Handle], handles or [])
@staticmethod
def create_from_process_group(
pg: ProcessGroup | StatelessProcessGroup,
max_chunk_bytes,
max_chunks,
writer_rank=0,
writer_rank: int = 0,
external_writer_handle=None,
blocking: bool = True,
) -> "MessageQueue":
"""
Creates a MessageQueue for a distributed process group with one writer and
multiple readers.
This method is designed for scenarios where one process (the writer) sends
messages, and all other processes (the readers) receive messages. It sets up
the shared memory buffer and socket communication handles accordingly, and
broadcasts the handle from the writer to all readers.
Args:
pg (ProcessGroup | StatelessProcessGroup): The torch distributed process
group.
max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer.
max_chunks (int): Maximum number of chunks in the buffer.
writer_rank (int, optional): The global rank that will act as the writer.
Defaults to 0.
external_writer_handle (Handle, optional): Handle from an external
MessageQueue. If provided, the writer's message queue is initialized
from this handle instead of creating a new one. Defaults to None.
blocking (bool, optional): If True, blocks until all processes are ready.
Defaults to True.
Returns:
MessageQueue: The MessageQueue instance for the calling process.
"""
if isinstance(pg, ProcessGroup):
group_rank = dist.get_rank(pg)
group_world_size = dist.get_world_size(pg)
@ -617,23 +691,26 @@ class MessageQueue:
group_rank = pg.rank
group_world_size = pg.world_size
global_ranks = list(range(pg.world_size))
from vllm.distributed.parallel_state import in_the_same_node_as
status = in_the_same_node_as(pg, source_rank=writer_rank)
same_node_ranks = [i for i, s in enumerate(status) if s]
n_reader = group_world_size - 1
n_local_reader = len(same_node_ranks) - 1
local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
buffer_io: MessageQueue
if group_rank == writer_rank:
buffer_io = MessageQueue(
n_reader=n_reader,
n_local_reader=n_local_reader,
local_reader_ranks=local_reader_ranks,
max_chunk_bytes=max_chunk_bytes,
max_chunks=max_chunks,
)
if external_writer_handle is not None:
buffer_io = MessageQueue.create_from_handle(
external_writer_handle, group_rank
)
else:
same_node_ranks = [i for i, s in enumerate(status) if s]
n_reader = group_world_size - 1
n_local_reader = len(same_node_ranks) - 1
local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
buffer_io = MessageQueue(
n_reader=n_reader,
n_local_reader=n_local_reader,
local_reader_ranks=local_reader_ranks,
max_chunk_bytes=max_chunk_bytes,
max_chunks=max_chunks,
)
handle = buffer_io.export_handle()
if isinstance(pg, ProcessGroup):
dist.broadcast_object_list(
@ -651,5 +728,6 @@ class MessageQueue:
else:
handle = pg.broadcast_obj(None, writer_rank)
buffer_io = MessageQueue.create_from_handle(handle, group_rank)
buffer_io.wait_until_ready()
if blocking:
buffer_io.wait_until_ready()
return buffer_io


@ -385,6 +385,33 @@ class GroupCoordinator:
torch.ops._C, "init_shm_manager"
)
def create_mq_broadcaster(
self, writer_rank=0, external_writer_handle=None, blocking=True
):
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
return MessageQueue.create_from_process_group(
self.cpu_group,
1 << 22,
6,
writer_rank=writer_rank,
external_writer_handle=external_writer_handle,
blocking=blocking,
)
def create_single_reader_mq_broadcasters(
self, reader_rank_in_group=0, blocking=False
):
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
return MessageQueue.create_from_process_group_single_reader(
self.cpu_group,
1 << 22,
6,
reader_rank=self.ranks[reader_rank_in_group],
blocking=blocking,
)
@property
def first_rank(self):
"""Return the global rank of the first process in the group"""
@ -997,6 +1024,7 @@ class GroupCoordinator:
_WORLD: GroupCoordinator | None = None
_INNER_DP_WORLD: GroupCoordinator | None = None
_NODE_COUNT: int | None = None
@ -1005,6 +1033,11 @@ def get_world_group() -> GroupCoordinator:
return _WORLD
def get_inner_dp_world_group() -> GroupCoordinator:
assert _INNER_DP_WORLD is not None, "inner dp world group is not initialized"
return _INNER_DP_WORLD
def init_world_group(
ranks: list[int], local_rank: int, backend: str
) -> GroupCoordinator:
@ -1023,12 +1056,13 @@ def init_model_parallel_group(
backend: str,
use_message_queue_broadcaster: bool = False,
group_name: str | None = None,
use_device_communicator: bool = True,
) -> GroupCoordinator:
return GroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_device_communicator=True,
use_device_communicator=use_device_communicator,
use_message_queue_broadcaster=use_message_queue_broadcaster,
group_name=group_name,
)
@ -1143,7 +1177,14 @@ def init_distributed_environment(
from vllm.config import get_current_vllm_config
config = get_current_vllm_config()
if (
if config is not None and config.parallel_config.nnodes > 1:
parallel_config = config.parallel_config
ip = parallel_config.master_addr
rank = parallel_config.data_parallel_rank * world_size + rank
world_size = parallel_config.world_size_across_dp
port = parallel_config.master_port
distributed_init_method = get_distributed_init_method(ip, port)
elif (
config is not None
and config.parallel_config.data_parallel_size > 1
and config.parallel_config.distributed_executor_backend != "external_launcher"
@ -1164,6 +1205,14 @@ def init_distributed_environment(
distributed_init_method,
)
if not torch.distributed.is_initialized():
logger.info(
"world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",
world_size,
rank,
local_rank,
distributed_init_method,
backend,
)
assert distributed_init_method is not None, (
"distributed_init_method must be provided when initializing "
"distributed environment"
@ -1192,16 +1241,36 @@ def init_distributed_environment(
# local rank not set, this usually happens in single-node
# setting, where we can use rank as local rank
local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
global _WORLD, _NODE_COUNT
global _WORLD, _NODE_COUNT, _INNER_DP_WORLD
if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size()))
_WORLD = init_world_group(ranks, local_rank, backend)
_NODE_COUNT = _node_count(_WORLD.cpu_group)
if config.parallel_config.nnodes > 1:
_NODE_COUNT = config.parallel_config.nnodes
else:
_NODE_COUNT = _node_count(_WORLD.cpu_group)
logger.debug("Detected %d nodes in the distributed environment", _NODE_COUNT)
else:
assert _WORLD.world_size == torch.distributed.get_world_size(), (
"world group already initialized with a different world size"
)
if config.parallel_config.nnodes_within_dp > 1:
if parallel_config.data_parallel_size > 1:
world_size_inner_dp = parallel_config.world_size
group_ranks = [
[dp_rank * world_size_inner_dp + i for i in range(world_size_inner_dp)]
for dp_rank in range(parallel_config.data_parallel_size)
]
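# e.g. with data_parallel_size=2 and world_size (TP * PP) = 4, this yields
# group_ranks == [[0, 1, 2, 3], [4, 5, 6, 7]]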
_INNER_DP_WORLD = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
backend,
use_message_queue_broadcaster=True,
group_name="inner_dp_world",
use_device_communicator=False,
)
else:
_INNER_DP_WORLD = _WORLD
def initialize_model_parallel(


@ -384,6 +384,10 @@ class EngineArgs:
) = ParallelConfig.distributed_executor_backend
# number of P/D disaggregation (or other disaggregation) workers
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
master_addr: str = ParallelConfig.master_addr
master_port: int = ParallelConfig.master_port
nnodes: int = ParallelConfig.nnodes
node_rank: int = ParallelConfig.node_rank
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
@ -394,6 +398,7 @@ class EngineArgs:
data_parallel_address: str | None = None
data_parallel_rpc_port: int | None = None
data_parallel_hybrid_lb: bool = False
data_parallel_external_lb: bool = False
data_parallel_backend: str = ParallelConfig.data_parallel_backend
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
all2all_backend: str | None = ParallelConfig.all2all_backend
@ -749,6 +754,10 @@ class EngineArgs:
"-pp",
**parallel_kwargs["pipeline_parallel_size"],
)
parallel_group.add_argument("--master-addr", **parallel_kwargs["master_addr"])
parallel_group.add_argument("--master-port", **parallel_kwargs["master_port"])
parallel_group.add_argument("--nnodes", "-n", **parallel_kwargs["nnodes"])
parallel_group.add_argument("--node-rank", "-r", **parallel_kwargs["node_rank"])
parallel_group.add_argument(
"--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]
)
@ -803,7 +812,14 @@ class EngineArgs:
help='Backend for data parallel, either "mp" or "ray".',
)
parallel_group.add_argument(
"--data-parallel-hybrid-lb", **parallel_kwargs["data_parallel_hybrid_lb"]
"--data-parallel-hybrid-lb",
"-dph",
**parallel_kwargs["data_parallel_hybrid_lb"],
)
parallel_group.add_argument(
"--data-parallel-external-lb",
"-dpe",
**parallel_kwargs["data_parallel_external_lb"],
)
parallel_group.add_argument(
"--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]
@ -1428,12 +1444,56 @@ class EngineArgs:
assert not headless or not self.data_parallel_hybrid_lb, (
"data_parallel_hybrid_lb is not applicable in headless mode"
)
data_parallel_external_lb = self.data_parallel_rank is not None
assert not (self.data_parallel_hybrid_lb and self.data_parallel_external_lb), (
"data_parallel_hybrid_lb and data_parallel_external_lb cannot both be True."
)
assert self.data_parallel_backend == "mp" or self.nnodes == 1, (
"nnodes > 1 is only supported with data_parallel_backend=mp"
)
inferred_data_parallel_rank = 0
if self.nnodes > 1:
world_size = (
self.data_parallel_size
* self.pipeline_parallel_size
* self.tensor_parallel_size
)
world_size_within_dp = (
self.pipeline_parallel_size * self.tensor_parallel_size
)
local_world_size = world_size // self.nnodes
assert world_size % self.nnodes == 0, (
f"world_size={world_size} must be divisible by nnodes={self.nnodes}."
)
assert self.node_rank < self.nnodes, (
f"node_rank={self.node_rank} must be less than nnodes={self.nnodes}."
)
inferred_data_parallel_rank = (
self.node_rank * local_world_size
) // world_size_within_dp
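# e.g. nnodes=2, DP=2, TP=2, PP=1: world_size=4, world_size_within_dp=2,
# local_world_size=2, so node_rank 0 infers DP rank 0 and node_rank 1
# infers DP rank 1.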
if self.data_parallel_size > 1 and self.data_parallel_external_lb:
self.data_parallel_rank = inferred_data_parallel_rank
logger.info(
"Inferred data_parallel_rank %d from node_rank %d for external lb",
self.data_parallel_rank,
self.node_rank,
)
elif self.data_parallel_size_local is None:
# Infer data parallel size local for internal dplb:
self.data_parallel_size_local = max(
local_world_size // world_size_within_dp, 1
)
data_parallel_external_lb = (
self.data_parallel_external_lb or self.data_parallel_rank is not None
)
# Local DP rank = 1, use pure-external LB.
if data_parallel_external_lb:
assert self.data_parallel_rank is not None, (
"data_parallel_rank or node_rank must be spefified if "
"data_parallel_external_lb is enable."
)
assert self.data_parallel_size_local in (1, None), (
"data_parallel_size_local must be 1 when data_parallel_rank is set"
"data_parallel_size_local must be 1 or None when data_parallel_rank "
"is set"
)
data_parallel_size_local = 1
# Use full external lb if we have local_size of 1.
@ -1447,6 +1507,11 @@ class EngineArgs:
if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
# Use full external lb if we have local_size of 1.
logger.warning(
"data_parallel_hybrid_lb is not eligible when "
"data_parallel_size_local = 1, autoswitch to "
"data_parallel_external_lb."
)
data_parallel_external_lb = True
self.data_parallel_hybrid_lb = False
@ -1454,7 +1519,15 @@ class EngineArgs:
# Disable hybrid LB mode if set for a single node
self.data_parallel_hybrid_lb = False
self.data_parallel_rank = self.data_parallel_start_rank or 0
self.data_parallel_rank = (
self.data_parallel_start_rank or inferred_data_parallel_rank
)
if self.nnodes > 1:
logger.info(
"Inferred data_parallel_rank %d from node_rank %d",
self.data_parallel_rank,
self.node_rank,
)
else:
assert not self.data_parallel_hybrid_lb, (
"data_parallel_size_local must be set to use data_parallel_hybrid_lb."
@ -1484,7 +1557,9 @@ class EngineArgs:
"data_parallel_backend can only be ray or mp, got %s",
self.data_parallel_backend,
)
data_parallel_address = ParallelConfig.data_parallel_master_ip
data_parallel_address = (
self.master_addr or ParallelConfig.data_parallel_master_ip
)
else:
data_parallel_address = self.data_parallel_address
@ -1517,6 +1592,10 @@ class EngineArgs:
data_parallel_rank=self.data_parallel_rank or 0,
data_parallel_external_lb=data_parallel_external_lb,
data_parallel_size_local=data_parallel_size_local,
master_addr=self.master_addr,
master_port=self.master_port,
nnodes=self.nnodes,
node_rank=self.node_rank,
data_parallel_master_ip=data_parallel_address,
data_parallel_rpc_port=data_parallel_rpc_port,
data_parallel_backend=self.data_parallel_backend,


@ -24,6 +24,7 @@ from vllm.utils.system_utils import decorate_logs, set_process_title
from vllm.v1.engine.core import EngineCoreProc
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
from vllm.v1.executor import Executor
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure
@ -97,18 +98,40 @@ def run_headless(args: argparse.Namespace):
if local_engine_count <= 0:
raise ValueError("data_parallel_size_local must be > 0 in headless mode")
host = parallel_config.data_parallel_master_ip
port = engine_args.data_parallel_rpc_port # add to config too
handshake_address = get_tcp_uri(host, port)
shutdown_requested = False
# Catch SIGTERM and SIGINT to allow graceful shutdown.
def signal_handler(signum, frame):
nonlocal shutdown_requested
logger.debug("Received %d signal.", signum)
raise SystemExit
if not shutdown_requested:
shutdown_requested = True
raise SystemExit
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
if parallel_config.node_rank_within_dp > 0:
from vllm.version import __version__ as VLLM_VERSION
# Run headless workers (for multi-node PP/TP).
host = parallel_config.master_addr
head_node_address = f"{host}:{parallel_config.master_port}"
logger.info(
"Launching vLLM (v%s) headless multiproc executor, "
"with head node address %s for torch.distributed process group.",
VLLM_VERSION,
head_node_address,
)
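# Follower nodes within a DP rank only host workers: run the worker
# monitor inline (blocking) rather than in a background thread; they
# never drive collective_rpc themselves.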
executor = MultiprocExecutor(vllm_config, monitor_workers=False)
executor.start_worker_monitor(inline=True)
return
host = parallel_config.data_parallel_master_ip
port = parallel_config.data_parallel_rpc_port
handshake_address = get_tcp_uri(host, port)
logger.info(
"Launching %d data parallel engine(s) in headless mode, "
"with head node address %s.",


@ -183,15 +183,19 @@ def set_device_control_env_var(
for engine subprocess.
"""
world_size = vllm_config.parallel_config.world_size
local_world_size = vllm_config.parallel_config.local_world_size
evar = current_platform.device_control_env_var
value = get_device_indices(evar, local_dp_rank, world_size)
value = get_device_indices(evar, local_dp_rank, world_size, local_world_size)
with patch.dict(os.environ, values=((evar, value),)):
yield
def get_device_indices(
device_control_env_var: str, local_dp_rank: int, world_size: int
device_control_env_var: str,
local_dp_rank: int,
world_size: int,
local_world_size: int | None = None,
):
"""
Returns a comma-separated string of device indices for the specified
@ -200,10 +204,15 @@ def get_device_indices(
For example, if world_size=2 and local_dp_rank=1, and there are 4 devices,
this will select devices 2 and 3 for local_dp_rank=1.
"""
if local_world_size is None:
local_world_size = world_size
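# e.g. world_size=4, local_world_size=2, local_dp_rank=1 -> device indices
# 4 and 5 (this node's slice of the devices assigned to local DP rank 1).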
try:
value = ",".join(
str(current_platform.device_id_to_physical_device_id(i))
for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * world_size)
for i in range(
local_dp_rank * world_size,
local_dp_rank * world_size + local_world_size,
)
)
except IndexError as e:
raise Exception(


@ -10,7 +10,7 @@ import time
import traceback
import weakref
from collections import deque
from collections.abc import Callable
from collections.abc import Callable, Sequence
from concurrent.futures import Future, InvalidStateError
from contextlib import suppress
from dataclasses import dataclass
@ -34,6 +34,7 @@ from vllm.distributed.parallel_state import (
get_dcp_group,
get_dp_group,
get_ep_group,
get_inner_dp_world_group,
get_pp_group,
get_tp_group,
)
@ -90,6 +91,10 @@ class FutureWrapper(Future):
class MultiprocExecutor(Executor):
supports_pp: bool = True
def __init__(self, vllm_config: VllmConfig, monitor_workers: bool = True):
self.monitor_workers = monitor_workers
super().__init__(vllm_config)
def _init_executor(self) -> None:
# Call self.shutdown at exit to clean up
# and ensure workers will be terminated.
@ -99,6 +104,12 @@ class MultiprocExecutor(Executor):
self.failure_callback: FailureCallback | None = None
self.world_size = self.parallel_config.world_size
assert self.world_size % self.parallel_config.nnodes_within_dp == 0, (
f"global world_size ({self.parallel_config.world_size}) must be "
f"divisible by nnodes_within_dp "
f"({self.parallel_config.nnodes_within_dp}). "
)
self.local_world_size = self.parallel_config.local_world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
pp_parallel_size = self.parallel_config.pipeline_parallel_size
assert self.world_size == tensor_parallel_size * pp_parallel_size, (
@ -116,27 +127,37 @@ class MultiprocExecutor(Executor):
distributed_init_method = get_distributed_init_method(
get_loopback_ip(), get_open_port()
)
self.rpc_broadcast_mq: MessageQueue | None = None
scheduler_output_handle: Handle | None = None
# Initialize worker and set up message queues for SchedulerOutputs
# and ModelRunnerOutputs
max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
self.rpc_broadcast_mq = MessageQueue(
self.world_size, self.world_size, max_chunk_bytes=max_chunk_bytes
)
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
if self.parallel_config.node_rank_within_dp == 0:
# For the leader node within each DP rank:
# each DP rank has its own leader MultiprocExecutor.
max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
self.rpc_broadcast_mq = MessageQueue(
self.world_size,
self.local_world_size,
max_chunk_bytes=max_chunk_bytes,
connect_ip=self.parallel_config.master_addr,
)
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
# Create workers
context = get_mp_context()
shared_worker_lock = context.Lock()
unready_workers: list[UnreadyWorkerProcHandle] = []
success = False
try:
for rank in range(self.world_size):
global_start_rank = (
self.local_world_size * self.parallel_config.node_rank_within_dp
)
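# e.g. with local_world_size=2: the node with node_rank_within_dp=0 creates
# workers with global ranks 0-1, and the next node creates ranks 2-3.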
for local_rank in range(self.local_world_size):
global_rank = global_start_rank + local_rank
unready_workers.append(
WorkerProc.make_worker_process(
vllm_config=self.vllm_config,
local_rank=rank,
rank=rank,
local_rank=local_rank,
rank=global_rank,
distributed_init_method=distributed_init_method,
input_shm_handle=scheduler_output_handle,
shared_worker_lock=shared_worker_lock,
@ -145,15 +166,38 @@ class MultiprocExecutor(Executor):
# Workers must be created before wait_for_ready to avoid
# deadlock, since worker.init_device() does a device sync.
# Wait for all local workers to be ready.
self.workers = WorkerProc.wait_for_ready(unready_workers)
# Start background thread to monitor worker health if not in headless mode.
if self.monitor_workers:
self.start_worker_monitor()
self.response_mqs = []
# Only the leader node has remote response MQs
if self.parallel_config.node_rank_within_dp == 0:
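# response_mqs[rank] is the queue that global worker `rank` writes its
# outputs to: local workers expose their queue directly, while remote
# workers are reached through the peer handles gathered by local worker 0.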
for rank in range(self.world_size):
if rank < self.local_world_size:
local_message_queue = self.workers[rank].worker_response_mq
assert local_message_queue is not None
self.response_mqs.append(local_message_queue)
else:
remote_message_queue = self.workers[0].peer_worker_response_mqs[
rank
]
assert remote_message_queue is not None
self.response_mqs.append(remote_message_queue)
# Ensure message queues are ready. Will deadlock if re-ordered
# Must be kept consistent with the WorkerProc.
self.rpc_broadcast_mq.wait_until_ready()
for w in self.workers:
w.worker_response_mq.wait_until_ready()
self.start_worker_monitor()
# Wait for all input mqs to be ready.
if self.rpc_broadcast_mq is not None:
self.rpc_broadcast_mq.wait_until_ready()
# Wait for all remote response mqs to be ready.
for response_mq in self.response_mqs:
response_mq.wait_until_ready()
success = True
finally:
if not success:
@ -168,7 +212,7 @@ class MultiprocExecutor(Executor):
self.output_rank = self._get_output_rank()
def start_worker_monitor(self):
def start_worker_monitor(self, inline=False) -> None:
workers = self.workers
self_ref = weakref.ref(self)
@ -192,9 +236,13 @@ class MultiprocExecutor(Executor):
_self.failure_callback = None
callback()
Thread(
target=monitor_workers, daemon=True, name="MultiprocWorkerMonitor"
).start()
if not inline:
Thread(
target=monitor_workers, daemon=True, name="MultiprocWorkerMonitor"
).start()
return
monitor_workers()
def register_failure_callback(self, callback: FailureCallback):
if self.is_failed:
@ -247,7 +295,9 @@ class MultiprocExecutor(Executor):
) -> Any | list[Any] | Future[Any | list[Any]]:
"""Returns single result if unique_reply_rank and/or kv_output_aggregator
is provided, otherwise list."""
assert self.rpc_broadcast_mq is not None, (
"collective_rpc should not be called on follower node"
)
if self.is_failed:
raise RuntimeError("Executor failed.")
@ -269,20 +319,20 @@ class MultiprocExecutor(Executor):
send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL)
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, output_rank))
workers = (
(self.workers[output_rank],) if output_rank is not None else self.workers
)
response_mqs: Sequence[MessageQueue] = self.response_mqs
if output_rank is not None:
response_mqs = (response_mqs[output_rank],)
shutdown_event = self.shutdown_event
def get_response():
responses = []
for w in workers:
for mq in response_mqs:
dequeue_timeout = (
None if deadline is None else (deadline - time.monotonic())
)
try:
status, result = w.worker_response_mq.dequeue(
status, result = mq.dequeue(
timeout=dequeue_timeout, cancel=shutdown_event
)
except TimeoutError as e:
@ -391,17 +441,26 @@ class UnreadyWorkerProcHandle:
class WorkerProcHandle:
proc: BaseProcess
rank: int
worker_response_mq: MessageQueue # The worker process writes to this MQ
# The worker process writes to this MQ in single-node mode
worker_response_mq: MessageQueue | None
# Only non-empty on the driver node; the peer worker process i
# writes to MQ `peer_worker_response_mqs[i]`
peer_worker_response_mqs: list[MessageQueue | None]
death_writer: Connection | None = None
@classmethod
def from_unready_handle(
cls, unready_handle: UnreadyWorkerProcHandle, worker_response_mq: MessageQueue
cls,
unready_handle: UnreadyWorkerProcHandle,
worker_response_mq: MessageQueue | None,
peer_worker_response_mqs: list[MessageQueue | None],
) -> "WorkerProcHandle":
return cls(
proc=unready_handle.proc,
rank=unready_handle.rank,
worker_response_mq=worker_response_mq,
peer_worker_response_mqs=peer_worker_response_mqs,
death_writer=unready_handle.death_writer,
)
@ -411,6 +470,38 @@ class WorkerProc:
READY_STR = "READY"
def _init_message_queues(
self, input_shm_handle: Handle, vllm_config: VllmConfig
) -> None:
if vllm_config.parallel_config.nnodes_within_dp == 1:
# Initialize MessageQueue for receiving SchedulerOutput
self.rpc_broadcast_mq = MessageQueue.create_from_handle(
input_shm_handle, self.worker.rank
)
# Initializes a message queue for sending the model output
self.worker_response_mq: MessageQueue = MessageQueue(1, 1)
self.peer_response_handles = []
else:
# Initialize remote MessageQueue for receiving SchedulerOutput across nodes
self.rpc_broadcast_mq = get_inner_dp_world_group().create_mq_broadcaster(
external_writer_handle=input_shm_handle,
# The executor process supplies external_writer_handle, and the actual
# writer sends its ready signal outside of create_mq_broadcaster, after
# this setup completes, so we make this call non-blocking. The handshake
# is triggered later, when worker.rpc_broadcast_mq.wait_until_ready()
# is called.
blocking=False,
)
# Initialize the remote message queue for sending model output to the
# driver worker, and expose peer_response_handles (one handle per rank)
# for the driver worker.
self.worker_response_mq, self.peer_response_handles = (
get_inner_dp_world_group().create_single_reader_mq_broadcasters(
reader_rank_in_group=0
)
)
def __init__(
self,
vllm_config: VllmConfig,
@ -421,13 +512,15 @@ class WorkerProc:
shared_worker_lock: LockType,
):
self.rank = rank
wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
wrapper = WorkerWrapperBase(
vllm_config=vllm_config, rpc_rank=local_rank, global_rank=rank
)
# TODO: move `init_worker` to executor level as a collective rpc call
all_kwargs: list[dict] = [
{} for _ in range(vllm_config.parallel_config.world_size)
]
is_driver_worker = rank % vllm_config.parallel_config.tensor_parallel_size == 0
all_kwargs[rank] = {
all_kwargs[local_rank] = {
"vllm_config": vllm_config,
"local_rank": local_rank,
"rank": rank,
@ -438,14 +531,6 @@ class WorkerProc:
wrapper.init_worker(all_kwargs)
self.worker = wrapper
# Initialize MessageQueue for receiving SchedulerOutput
self.rpc_broadcast_mq = MessageQueue.create_from_handle(
input_shm_handle, self.worker.rank
)
# Initializes a message queue for sending the model output
self.worker_response_mq = MessageQueue(1, 1)
scheduler_config = vllm_config.scheduler_config
self.use_async_scheduling = scheduler_config.async_scheduling
if self.use_async_scheduling:
@ -466,6 +551,7 @@ class WorkerProc:
)
# Load model
self._init_message_queues(input_shm_handle, vllm_config)
self.worker.load_model()
# Enable environment variable cache (e.g. assume no more
@ -512,6 +598,27 @@ class WorkerProc:
# death_reader in child will get EOFError
return UnreadyWorkerProcHandle(proc, rank, reader, death_writer)
@staticmethod
def wait_for_response_handle_ready(
handles: dict[str, Any], proc_handle: UnreadyWorkerProcHandle
) -> WorkerProcHandle:
response_handle = handles["handle"]
worker_response_mq: MessageQueue | None = None
if len(response_handle.local_reader_ranks) > 0:
worker_response_mq = MessageQueue.create_from_handle(response_handle, 0)
peer_response_handles = handles["peer_response_handles"]
peer_worker_response_mqs = [
MessageQueue.create_from_handle(handle, -1)
if handle.remote_subscribe_addr is not None
else None
for handle in peer_response_handles
]
return WorkerProcHandle.from_unready_handle(
proc_handle,
worker_response_mq,
peer_worker_response_mqs=peer_worker_response_mqs,
)
@staticmethod
def wait_for_ready(
unready_proc_handles: list[UnreadyWorkerProcHandle],
@ -537,16 +644,10 @@ class WorkerProc:
if response["status"] != "READY":
raise e
# Extract the message queue handle.
worker_response_mq = MessageQueue.create_from_handle(
response["handle"], 0
idx = unready_proc_handle.rank % len(ready_proc_handles)
ready_proc_handles[idx] = WorkerProc.wait_for_response_handle_ready(
response, unready_proc_handle
)
ready_proc_handles[unready_proc_handle.rank] = (
WorkerProcHandle.from_unready_handle(
unready_proc_handle, worker_response_mq
)
)
except EOFError:
e.__suppress_context__ = True
raise e from None
@ -618,12 +719,14 @@ class WorkerProc:
{
"status": WorkerProc.READY_STR,
"handle": worker.worker_response_mq.export_handle(),
"peer_response_handles": worker.peer_response_handles,
}
)
# Ensure message queues are ready. Will deadlock if re-ordered.
# Must be kept consistent with the Executor
worker.rpc_broadcast_mq.wait_until_ready()
if worker.rpc_broadcast_mq is not None:
worker.rpc_broadcast_mq.wait_until_ready()
worker.worker_response_mq.wait_until_ready()
ready_writer.close()
ready_writer = None


@ -189,6 +189,7 @@ class Worker(WorkerBase):
and self.parallel_config.distributed_executor_backend
not in ["ray", "external_launcher"]
and self.vllm_config.parallel_config.data_parallel_backend != "ray"
and self.vllm_config.parallel_config.nnodes_within_dp == 1
):
# Use local DP rank if available, otherwise use global DP rank.
dp_local_rank = self.parallel_config.data_parallel_rank_local
@ -205,7 +206,14 @@ class Worker(WorkerBase):
assert self.local_rank < torch.cuda.device_count(), (
f"DP adjusted local rank {self.local_rank} is out of bounds. "
)
visible_device_count = (
torch.cuda.device_count() if torch.cuda.is_available() else 0
)
assert self.parallel_config.local_world_size <= visible_device_count, (
f"local_world_size ({self.parallel_config.local_world_size}) must be "
f"less than or equal to the number of visible devices "
f"({visible_device_count})."
)
self.device = torch.device(f"cuda:{self.local_rank}")
current_platform.set_device(self.device)


@ -180,6 +180,7 @@ class WorkerWrapperBase:
self,
vllm_config: VllmConfig,
rpc_rank: int = 0,
global_rank: int | None = None,
) -> None:
"""
Initialize the worker wrapper with the given vllm_config and rpc_rank.
@ -192,6 +193,7 @@ class WorkerWrapperBase:
group.
"""
self.rpc_rank = rpc_rank
self.global_rank = self.rpc_rank if global_rank is None else global_rank
self.worker: WorkerBase | None = None
# do not store this `vllm_config`, `init_worker` will set the final
@ -312,7 +314,7 @@ class WorkerWrapperBase:
assert self.worker is not None
def initialize_from_config(self, kv_cache_configs: list[Any]) -> None:
kv_cache_config = kv_cache_configs[self.rpc_rank]
kv_cache_config = kv_cache_configs[self.global_rank]
with set_current_vllm_config(self.vllm_config):
self.worker.initialize_from_config(kv_cache_config) # type: ignore