mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 09:06:02 +08:00
[V1] Support MP Executor for multi node distributed inference (#23691)
Signed-off-by: Lu Fang <fanglu@fb.com> Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Signed-off-by: Lucia Fang <fanglu@fb.com> Signed-off-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Signed-off-by: Nick Hill <nhill@redhat.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
a55b64635c
commit
b316ac6589
437
tests/distributed/test_multiproc_executor.py
Normal file
437
tests/distributed/test_multiproc_executor.py
Normal file
@ -0,0 +1,437 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""
|
||||
Integration tests for MultiprocExecutor at the executor level.
|
||||
This test directly tests the executor without going through the LLM interface,
|
||||
focusing on executor initialization, RPC calls, and distributed execution.
|
||||
"""
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
|
||||
from tests.utils import multi_gpu_test
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.utils import get_open_port
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
|
||||
|
||||
MODEL = "facebook/opt-125m"
|
||||
|
||||
|
||||
def create_vllm_config(
|
||||
tensor_parallel_size: int = 1,
|
||||
pipeline_parallel_size: int = 1,
|
||||
max_model_len: int = 256,
|
||||
gpu_memory_utilization: float = 0.3,
|
||||
distributed_executor_backend: str = "mp",
|
||||
nnodes: int = 1,
|
||||
node_rank: int = 0,
|
||||
master_port: int = 0,
|
||||
) -> VllmConfig:
|
||||
"""Create a VllmConfig for testing using EngineArgs."""
|
||||
engine_args = EngineArgs(
|
||||
model=MODEL,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
pipeline_parallel_size=pipeline_parallel_size,
|
||||
max_model_len=max_model_len,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True,
|
||||
)
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
|
||||
# Override distributed node settings if needed
|
||||
if nnodes > 1 or node_rank > 0:
|
||||
vllm_config.parallel_config.nnodes = nnodes
|
||||
vllm_config.parallel_config.node_rank = node_rank
|
||||
vllm_config.parallel_config.master_port = master_port
|
||||
if nnodes > 1:
|
||||
vllm_config.parallel_config.disable_custom_all_reduce = True
|
||||
|
||||
return vllm_config
|
||||
|
||||
|
||||
def create_test_scheduler_output(num_requests: int = 1) -> SchedulerOutput:
|
||||
"""Create a minimal SchedulerOutput for testing."""
|
||||
# This is a simplified version - in practice you'd need proper
|
||||
# SchedulerOutput construction based on the actual vLLM v1 API
|
||||
return SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_resumed_reqs=[],
|
||||
scheduled_running_reqs=[],
|
||||
num_scheduled_tokens={},
|
||||
total_num_scheduled_tokens=0,
|
||||
)
|
||||
|
||||
|
||||
def test_multiproc_executor_initialization():
|
||||
"""Test that MultiprocExecutor can be initialized with proper config."""
|
||||
vllm_config = create_vllm_config(
|
||||
tensor_parallel_size=1,
|
||||
pipeline_parallel_size=1,
|
||||
)
|
||||
|
||||
# Create executor - this should initialize workers
|
||||
executor = MultiprocExecutor(vllm_config=vllm_config)
|
||||
|
||||
# Verify executor properties
|
||||
assert executor.world_size == 1, "World size should be 1 for single GPU"
|
||||
assert executor.local_world_size == 1, "Local world size should be 1"
|
||||
assert hasattr(executor, "workers"), "Executor should have workers"
|
||||
assert len(executor.workers) == 1, "Should have 1 worker for single GPU"
|
||||
|
||||
# Clean up
|
||||
executor.shutdown()
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
def test_multiproc_executor_initialization_tensor_parallel():
|
||||
"""Test MultiprocExecutor initialization with tensor parallelism."""
|
||||
vllm_config = create_vllm_config(
|
||||
tensor_parallel_size=2,
|
||||
pipeline_parallel_size=1,
|
||||
)
|
||||
|
||||
# Create executor
|
||||
executor = MultiprocExecutor(vllm_config=vllm_config)
|
||||
|
||||
# Verify executor properties
|
||||
assert executor.world_size == 2, "World size should be 2 for TP=2"
|
||||
assert executor.local_world_size == 2, "Local world size should be 2"
|
||||
assert len(executor.workers) == 2, "Should have 2 workers for TP=2"
|
||||
|
||||
# Verify output rank calculation
|
||||
output_rank = executor._get_output_rank()
|
||||
assert output_rank == 0, "Output rank should be 0 for TP=2, PP=1"
|
||||
|
||||
# Clean up
|
||||
executor.shutdown()
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
def test_multiproc_executor_collective_rpc():
|
||||
"""Test collective RPC calls to all workers."""
|
||||
vllm_config = create_vllm_config(
|
||||
tensor_parallel_size=2,
|
||||
pipeline_parallel_size=1,
|
||||
)
|
||||
|
||||
# Create executor
|
||||
executor = MultiprocExecutor(vllm_config=vllm_config)
|
||||
|
||||
try:
|
||||
# Test check_health RPC - should work without errors
|
||||
executor.check_health()
|
||||
|
||||
# Test that RPC works correctly
|
||||
# Note: We're just testing that the RPC mechanism works,
|
||||
# not testing actual model execution here
|
||||
assert not executor.is_failed, "Executor should not be in failed state"
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
executor.shutdown()
|
||||
|
||||
|
||||
def test_multiproc_executor_failure_callback():
|
||||
"""Test failure callback registration and invocation."""
|
||||
vllm_config = create_vllm_config(
|
||||
tensor_parallel_size=1,
|
||||
pipeline_parallel_size=1,
|
||||
)
|
||||
|
||||
executor = MultiprocExecutor(vllm_config=vllm_config)
|
||||
|
||||
try:
|
||||
# Test callback registration
|
||||
callback_invoked = []
|
||||
|
||||
def test_callback():
|
||||
callback_invoked.append(True)
|
||||
|
||||
# Register callback
|
||||
executor.register_failure_callback(test_callback)
|
||||
|
||||
# Callback should not be invoked yet
|
||||
assert len(callback_invoked) == 0, "Callback should not be invoked immediately"
|
||||
|
||||
# Simulate failure
|
||||
executor.is_failed = True
|
||||
|
||||
# Register another callback - should be invoked immediately
|
||||
executor.register_failure_callback(test_callback)
|
||||
assert len(callback_invoked) == 1, (
|
||||
"Callback should be invoked when executor is failed"
|
||||
)
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
executor.shutdown()
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
def test_multiproc_executor_worker_monitor():
|
||||
"""Test that worker monitor is set up correctly."""
|
||||
vllm_config = create_vllm_config(
|
||||
tensor_parallel_size=2,
|
||||
pipeline_parallel_size=1,
|
||||
)
|
||||
|
||||
executor = MultiprocExecutor(vllm_config=vllm_config)
|
||||
|
||||
try:
|
||||
# Verify all worker processes are alive
|
||||
for worker in executor.workers:
|
||||
assert worker.proc.is_alive(), f"Worker rank {worker.rank} should be alive"
|
||||
|
||||
# Verify executor is not in failed state
|
||||
assert not executor.is_failed, "Executor should not be in failed state"
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
executor.shutdown()
|
||||
|
||||
# After shutdown, workers should be terminated
|
||||
import time
|
||||
|
||||
time.sleep(0.5) # Give processes time to terminate
|
||||
for worker in executor.workers:
|
||||
assert not worker.proc.is_alive(), (
|
||||
f"Worker rank {worker.rank} should terminate after shutdown"
|
||||
)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
def test_multiproc_executor_get_response_message_queues():
|
||||
"""Test message queue retrieval for different ranks."""
|
||||
vllm_config = create_vllm_config(
|
||||
tensor_parallel_size=2,
|
||||
pipeline_parallel_size=1,
|
||||
)
|
||||
|
||||
executor = MultiprocExecutor(vllm_config=vllm_config)
|
||||
|
||||
try:
|
||||
# Get all message queues
|
||||
all_queues = executor.get_response_mqs()
|
||||
assert len(all_queues) == 2, "Should have 2 message queues for 2 workers"
|
||||
|
||||
# Get message queue for specific rank
|
||||
rank0_queue = executor.get_response_mqs(unique_reply_rank=0)
|
||||
assert len(rank0_queue) == 1, "Should have 1 message queue for rank 0"
|
||||
|
||||
rank1_queue = executor.get_response_mqs(unique_reply_rank=1)
|
||||
assert len(rank1_queue) == 1, "Should have 1 message queue for rank 1"
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
executor.shutdown()
|
||||
|
||||
|
||||
def test_multiproc_executor_shutdown_cleanup():
|
||||
"""Test that shutdown properly cleans up resources."""
|
||||
vllm_config = create_vllm_config(
|
||||
tensor_parallel_size=1,
|
||||
pipeline_parallel_size=1,
|
||||
)
|
||||
|
||||
executor = MultiprocExecutor(vllm_config=vllm_config)
|
||||
|
||||
# Verify executor is set up
|
||||
assert hasattr(executor, "workers"), "Executor should have workers"
|
||||
assert len(executor.workers) > 0, "Should have at least one worker"
|
||||
|
||||
# Shutdown
|
||||
executor.shutdown()
|
||||
|
||||
# Verify cleanup
|
||||
import time
|
||||
|
||||
time.sleep(0.5) # Give processes time to terminate
|
||||
|
||||
for worker in executor.workers:
|
||||
assert not worker.proc.is_alive(), "Worker processes should be terminated"
|
||||
|
||||
# Verify shutdown event is set
|
||||
assert executor.shutdown_event.is_set(), "Shutdown event should be set"
|
||||
|
||||
# Multiple shutdowns should be safe (idempotent)
|
||||
executor.shutdown()
|
||||
executor.shutdown()
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=4)
|
||||
def test_multiproc_executor_pipeline_parallel():
|
||||
"""Test MultiprocExecutor with pipeline parallelism."""
|
||||
vllm_config = create_vllm_config(
|
||||
tensor_parallel_size=2,
|
||||
pipeline_parallel_size=2,
|
||||
)
|
||||
|
||||
executor = MultiprocExecutor(vllm_config=vllm_config)
|
||||
|
||||
try:
|
||||
# Verify executor properties
|
||||
assert executor.world_size == 4, "World size should be 4 for TP=2, PP=2"
|
||||
assert len(executor.workers) == 4, "Should have 4 workers"
|
||||
|
||||
# Verify output rank calculation
|
||||
# For TP=2, PP=2: output should be from the last PP stage (ranks 2-3)
|
||||
# Specifically rank 2 (first rank of last PP stage)
|
||||
output_rank = executor._get_output_rank()
|
||||
assert output_rank == 2, "Output rank should be 2 (first rank of last PP stage)"
|
||||
|
||||
# Verify max_concurrent_batches for pipeline parallel
|
||||
assert executor.max_concurrent_batches == 2, (
|
||||
"Max concurrent batches should equal PP size"
|
||||
)
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
executor.shutdown()
|
||||
|
||||
|
||||
def test_multiproc_executor_properties():
|
||||
"""Test various executor properties and configurations."""
|
||||
vllm_config = create_vllm_config(
|
||||
tensor_parallel_size=1,
|
||||
pipeline_parallel_size=1,
|
||||
)
|
||||
|
||||
executor = MultiprocExecutor(vllm_config=vllm_config)
|
||||
|
||||
try:
|
||||
# Test supports_pp property
|
||||
assert MultiprocExecutor.supports_pp is True, (
|
||||
"MultiprocExecutor should support pipeline parallelism"
|
||||
)
|
||||
|
||||
# Test world_size calculation
|
||||
assert executor.world_size == (
|
||||
executor.parallel_config.tensor_parallel_size
|
||||
* executor.parallel_config.pipeline_parallel_size
|
||||
), "World size should equal TP * PP"
|
||||
|
||||
# Test local_world_size calculation
|
||||
assert executor.local_world_size == (
|
||||
executor.parallel_config.world_size // executor.parallel_config.nnodes
|
||||
), "Local world size should be world_size / nnodes"
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
executor.shutdown()
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=4)
|
||||
def test_multiproc_executor_multi_node():
|
||||
"""
|
||||
Test MultiprocExecutor with multi-node configuration.
|
||||
This simulates 2 nodes with TP=4:
|
||||
- Node 0 (rank 0): Uses GPUs 0,1 (CUDA_VISIBLE_DEVICES=0,1) with TP=2
|
||||
- Node 1 (rank 1): Uses GPUs 2,3 (CUDA_VISIBLE_DEVICES=2,3) with TP=2
|
||||
Total world_size = 4, nnodes = 2
|
||||
"""
|
||||
port = get_open_port()
|
||||
# symm_mem does not work for simulating multi instance in single node
|
||||
os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"
|
||||
|
||||
def run_node(node_rank: int, result_queue: multiprocessing.Queue, port: int):
|
||||
"""Run a single node's executor."""
|
||||
executor = None
|
||||
try:
|
||||
# Set CUDA_VISIBLE_DEVICES for this node
|
||||
if node_rank == 0:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
|
||||
else:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
|
||||
|
||||
# Create config for this node
|
||||
vllm_config = create_vllm_config(
|
||||
tensor_parallel_size=4, # Total TP across all nodes
|
||||
pipeline_parallel_size=1,
|
||||
nnodes=2, # 2 nodes
|
||||
node_rank=node_rank,
|
||||
master_port=port, # same port
|
||||
)
|
||||
|
||||
# Create executor for this node
|
||||
executor = MultiprocExecutor(vllm_config=vllm_config)
|
||||
|
||||
# Verify node-specific properties
|
||||
assert executor.world_size == 4, (
|
||||
f"World size should be 4 on node {node_rank}"
|
||||
)
|
||||
assert executor.local_world_size == 2, (
|
||||
f"Local world size should be 2 on node {node_rank}"
|
||||
)
|
||||
assert len(executor.workers) == 2, (
|
||||
f"Should have 2 local workers on node {node_rank}"
|
||||
)
|
||||
|
||||
# Verify worker ranks are correct for this node
|
||||
expected_ranks = [node_rank * 2, node_rank * 2 + 1]
|
||||
actual_ranks = sorted([w.rank for w in executor.workers])
|
||||
assert actual_ranks == expected_ranks, (
|
||||
f"Node {node_rank} should have workers "
|
||||
f"with ranks {expected_ranks}, got {actual_ranks}"
|
||||
)
|
||||
# Verify all workers are alive
|
||||
for worker in executor.workers:
|
||||
assert worker.proc.is_alive(), (
|
||||
f"Worker rank {worker.rank} should be alive on node {node_rank}"
|
||||
)
|
||||
# executor.gen
|
||||
# Put success result in queue BEFORE shutdown to avoid hanging
|
||||
result_queue.put({"node": node_rank, "success": True})
|
||||
import time
|
||||
|
||||
time.sleep(2)
|
||||
executor.shutdown()
|
||||
except Exception as e:
|
||||
# Put failure result in queue
|
||||
result_queue.put({"node": node_rank, "success": False, "error": str(e)})
|
||||
raise e
|
||||
finally:
|
||||
if executor is not None:
|
||||
executor.shutdown()
|
||||
|
||||
# Create a queue to collect results from both processes
|
||||
result_queue: multiprocessing.Queue[dict[str, int | bool]] = multiprocessing.Queue()
|
||||
|
||||
# Start both node processes
|
||||
processes = []
|
||||
for node_rank in range(2):
|
||||
p = multiprocessing.Process(
|
||||
target=run_node,
|
||||
args=(node_rank, result_queue, port),
|
||||
name=f"Node{node_rank}",
|
||||
)
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
# Wait for both processes to complete
|
||||
all_completed = True
|
||||
for p in processes:
|
||||
p.join(timeout=60)
|
||||
if p.is_alive():
|
||||
p.terminate()
|
||||
p.join(timeout=20)
|
||||
if p.is_alive():
|
||||
p.kill()
|
||||
p.join()
|
||||
all_completed = False
|
||||
|
||||
# Check results from both nodes
|
||||
results: list[dict[str, int | bool]] = []
|
||||
while len(results) < 2:
|
||||
try:
|
||||
result = result_queue.get(timeout=1)
|
||||
results.append(result)
|
||||
except Exception:
|
||||
pass
|
||||
assert all_completed, "Not all processes completed successfully"
|
||||
assert len(results) == 2, f"Expected 2 results, got {len(results)}"
|
||||
assert results[0]["success"], f"Node 0 failed: {results[0]}"
|
||||
assert results[1]["success"], f"Node 1 failed: {results[1]}"
|
||||
@ -210,6 +210,18 @@ class ParallelConfig:
|
||||
class is dynamically inherited by the worker class. This is used to inject
|
||||
new attributes and methods to the worker class for use in collective_rpc
|
||||
calls."""
|
||||
master_addr: str = "127.0.0.1"
|
||||
"""distributed master address for multi-node distributed
|
||||
inference when distributed_executor_backend is mp."""
|
||||
master_port: int = 29501
|
||||
"""distributed master port for multi-node distributed
|
||||
inference when distributed_executor_backend is mp."""
|
||||
node_rank: int = 0
|
||||
"""distributed node rank for multi-node distributed
|
||||
inference when distributed_executor_backend is mp."""
|
||||
nnodes: int = 1
|
||||
"""num of nodes for multi-node distributed
|
||||
inference when distributed_executor_backend is mp."""
|
||||
|
||||
world_size: int = Field(init=False)
|
||||
"""world_size is TPxPP, it affects the number of workers we create."""
|
||||
@ -387,6 +399,23 @@ class ParallelConfig:
|
||||
and self.data_parallel_size > 1
|
||||
)
|
||||
|
||||
@property
|
||||
def node_rank_within_dp(self) -> int:
|
||||
return self.node_rank % self.nnodes_within_dp
|
||||
|
||||
@property
|
||||
def nnodes_within_dp(self) -> int:
|
||||
if self.nnodes == 1:
|
||||
return 1
|
||||
data_parallel_node_size = (
|
||||
self.data_parallel_size // self.data_parallel_size_local
|
||||
)
|
||||
return self.nnodes // data_parallel_node_size
|
||||
|
||||
@property
|
||||
def local_world_size(self) -> int:
|
||||
return self.world_size // self.nnodes_within_dp
|
||||
|
||||
@staticmethod
|
||||
def has_unfinished_dp(dp_group: ProcessGroup, has_unfinished: bool) -> bool:
|
||||
tensor = torch.tensor([has_unfinished], dtype=torch.int32, device="cpu")
|
||||
@ -528,6 +557,8 @@ class ParallelConfig:
|
||||
ray_found = ray_utils.ray_is_available()
|
||||
if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
|
||||
backend = "uni"
|
||||
elif current_platform.is_cuda() and self.nnodes > 1:
|
||||
backend = "mp"
|
||||
elif (
|
||||
current_platform.is_cuda()
|
||||
and cuda_device_count_stateless() < self.world_size
|
||||
@ -565,6 +596,10 @@ class ParallelConfig:
|
||||
"max_parallel_loading_workers is currently "
|
||||
"not supported and will be ignored."
|
||||
)
|
||||
if self.distributed_executor_backend != "mp" and self.nnodes > 1:
|
||||
raise ValueError(
|
||||
"nnodes > 1 can only be set when distributed exectuor backend is mp."
|
||||
)
|
||||
|
||||
@property
|
||||
def use_ray(self) -> bool:
|
||||
@ -607,6 +642,11 @@ class ParallelConfig:
|
||||
"Disabled the custom all-reduce kernel because it is not "
|
||||
"supported on current platform."
|
||||
)
|
||||
if self.nnodes > 1:
|
||||
self.disable_custom_all_reduce = True
|
||||
logger.debug(
|
||||
"Disabled the custom all-reduce since we are running on multi-node."
|
||||
)
|
||||
if self.ray_workers_use_nsight and not self.use_ray:
|
||||
raise ValueError(
|
||||
"Unable to use nsight profiling unless workers run with Ray."
|
||||
|
||||
@ -8,7 +8,7 @@ from dataclasses import dataclass, field
|
||||
from multiprocessing import shared_memory
|
||||
from pickle import PickleBuffer
|
||||
from threading import Event
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
@ -602,13 +602,87 @@ class MessageQueue:
|
||||
return obj
|
||||
return self.dequeue()
|
||||
|
||||
@staticmethod
|
||||
def create_from_process_group_single_reader(
|
||||
pg: ProcessGroup,
|
||||
max_chunk_bytes,
|
||||
max_chunks,
|
||||
reader_rank: int = 0,
|
||||
blocking: bool = False,
|
||||
) -> tuple["MessageQueue", list[Handle]]:
|
||||
"""
|
||||
Creates a MessageQueue for a process group with a single reader.
|
||||
|
||||
This method is designed for scenarios where only one process (the reader)
|
||||
will consume messages, and all other processes are writers. It sets up
|
||||
the shared memory buffer and communication handles accordingly, and
|
||||
gathers the handles from all processes to the reader.
|
||||
|
||||
Args:
|
||||
pg (ProcessGroup): The torch distributed process group.
|
||||
max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer.
|
||||
max_chunks (int): Maximum number of chunks in the buffer.
|
||||
reader_rank (int, optional): The global rank that will act as the reader.
|
||||
Defaults to 0.
|
||||
blocking (bool, optional): If True, blocks until all processes are ready.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
tuple[MessageQueue, list[Handle]]:
|
||||
The MessageQueue instance for the calling process,
|
||||
and a list of handles (only non-empty for the reader process).
|
||||
"""
|
||||
local_size = torch.cuda.device_count()
|
||||
rank = dist.get_rank()
|
||||
same_node = rank // local_size == reader_rank // local_size
|
||||
buffer_io = MessageQueue(
|
||||
n_reader=1,
|
||||
n_local_reader=1 if same_node else 0,
|
||||
max_chunk_bytes=max_chunk_bytes,
|
||||
max_chunks=max_chunks,
|
||||
)
|
||||
handle = buffer_io.export_handle()
|
||||
handles = [None] * dist.get_world_size(pg) if rank == reader_rank else None
|
||||
dist.gather_object(handle, handles, dst=reader_rank, group=pg)
|
||||
if blocking:
|
||||
buffer_io.wait_until_ready()
|
||||
return buffer_io, cast(list[Handle], handles or [])
|
||||
|
||||
@staticmethod
|
||||
def create_from_process_group(
|
||||
pg: ProcessGroup | StatelessProcessGroup,
|
||||
max_chunk_bytes,
|
||||
max_chunks,
|
||||
writer_rank=0,
|
||||
writer_rank: int = 0,
|
||||
external_writer_handle=None,
|
||||
blocking: bool = True,
|
||||
) -> "MessageQueue":
|
||||
"""
|
||||
Creates a MessageQueue for a distributed process group with one writer and
|
||||
multiple readers.
|
||||
|
||||
This method is designed for scenarios where one process (the writer) sends
|
||||
messages, and all other processes (the readers) receive messages. It sets up
|
||||
the shared memory buffer and socket communication handles accordingly, and
|
||||
broadcasts the handle from the writer to all readers.
|
||||
|
||||
Args:
|
||||
pg (ProcessGroup | StatelessProcessGroup): The torch distributed process
|
||||
group.
|
||||
max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer.
|
||||
max_chunks (int): Maximum number of chunks in the buffer.
|
||||
writer_rank (int, optional): The global rank that will act as the writer.
|
||||
Defaults to 0.
|
||||
external_writer_handle (Handle, optional): Used when there is a handle
|
||||
from an external Message Queue. If provided, use this handle to init
|
||||
PG writer message queue instead of creating a new one. Defaults to None.
|
||||
blocking (bool, optional): If True, blocks until all processes are ready.
|
||||
Defaults to True.
|
||||
|
||||
Returns:
|
||||
MessageQueue: The MessageQueue instance for the calling process.
|
||||
|
||||
"""
|
||||
if isinstance(pg, ProcessGroup):
|
||||
group_rank = dist.get_rank(pg)
|
||||
group_world_size = dist.get_world_size(pg)
|
||||
@ -617,23 +691,26 @@ class MessageQueue:
|
||||
group_rank = pg.rank
|
||||
group_world_size = pg.world_size
|
||||
global_ranks = list(range(pg.world_size))
|
||||
|
||||
from vllm.distributed.parallel_state import in_the_same_node_as
|
||||
|
||||
status = in_the_same_node_as(pg, source_rank=writer_rank)
|
||||
same_node_ranks = [i for i, s in enumerate(status) if s]
|
||||
n_reader = group_world_size - 1
|
||||
n_local_reader = len(same_node_ranks) - 1
|
||||
local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
|
||||
buffer_io: MessageQueue
|
||||
if group_rank == writer_rank:
|
||||
buffer_io = MessageQueue(
|
||||
n_reader=n_reader,
|
||||
n_local_reader=n_local_reader,
|
||||
local_reader_ranks=local_reader_ranks,
|
||||
max_chunk_bytes=max_chunk_bytes,
|
||||
max_chunks=max_chunks,
|
||||
)
|
||||
if external_writer_handle is not None:
|
||||
buffer_io = MessageQueue.create_from_handle(
|
||||
external_writer_handle, group_rank
|
||||
)
|
||||
else:
|
||||
same_node_ranks = [i for i, s in enumerate(status) if s]
|
||||
n_reader = group_world_size - 1
|
||||
n_local_reader = len(same_node_ranks) - 1
|
||||
local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
|
||||
buffer_io = MessageQueue(
|
||||
n_reader=n_reader,
|
||||
n_local_reader=n_local_reader,
|
||||
local_reader_ranks=local_reader_ranks,
|
||||
max_chunk_bytes=max_chunk_bytes,
|
||||
max_chunks=max_chunks,
|
||||
)
|
||||
handle = buffer_io.export_handle()
|
||||
if isinstance(pg, ProcessGroup):
|
||||
dist.broadcast_object_list(
|
||||
@ -651,5 +728,6 @@ class MessageQueue:
|
||||
else:
|
||||
handle = pg.broadcast_obj(None, writer_rank)
|
||||
buffer_io = MessageQueue.create_from_handle(handle, group_rank)
|
||||
buffer_io.wait_until_ready()
|
||||
if blocking:
|
||||
buffer_io.wait_until_ready()
|
||||
return buffer_io
|
||||
|
||||
@ -385,6 +385,33 @@ class GroupCoordinator:
|
||||
torch.ops._C, "init_shm_manager"
|
||||
)
|
||||
|
||||
def create_mq_broadcaster(
|
||||
self, writer_rank=0, external_writer_handle=None, blocking=True
|
||||
):
|
||||
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
|
||||
|
||||
return MessageQueue.create_from_process_group(
|
||||
self.cpu_group,
|
||||
1 << 22,
|
||||
6,
|
||||
writer_rank=writer_rank,
|
||||
external_writer_handle=external_writer_handle,
|
||||
blocking=blocking,
|
||||
)
|
||||
|
||||
def create_single_reader_mq_broadcasters(
|
||||
self, reader_rank_in_group=0, blocking=False
|
||||
):
|
||||
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
|
||||
|
||||
return MessageQueue.create_from_process_group_single_reader(
|
||||
self.cpu_group,
|
||||
1 << 22,
|
||||
6,
|
||||
reader_rank=self.ranks[reader_rank_in_group],
|
||||
blocking=blocking,
|
||||
)
|
||||
|
||||
@property
|
||||
def first_rank(self):
|
||||
"""Return the global rank of the first process in the group"""
|
||||
@ -997,6 +1024,7 @@ class GroupCoordinator:
|
||||
|
||||
|
||||
_WORLD: GroupCoordinator | None = None
|
||||
_INNER_DP_WORLD: GroupCoordinator | None = None
|
||||
_NODE_COUNT: int | None = None
|
||||
|
||||
|
||||
@ -1005,6 +1033,11 @@ def get_world_group() -> GroupCoordinator:
|
||||
return _WORLD
|
||||
|
||||
|
||||
def get_inner_dp_world_group() -> GroupCoordinator:
|
||||
assert _INNER_DP_WORLD is not None, "inner dp world group is not initialized"
|
||||
return _INNER_DP_WORLD
|
||||
|
||||
|
||||
def init_world_group(
|
||||
ranks: list[int], local_rank: int, backend: str
|
||||
) -> GroupCoordinator:
|
||||
@ -1023,12 +1056,13 @@ def init_model_parallel_group(
|
||||
backend: str,
|
||||
use_message_queue_broadcaster: bool = False,
|
||||
group_name: str | None = None,
|
||||
use_device_communicator: bool = True,
|
||||
) -> GroupCoordinator:
|
||||
return GroupCoordinator(
|
||||
group_ranks=group_ranks,
|
||||
local_rank=local_rank,
|
||||
torch_distributed_backend=backend,
|
||||
use_device_communicator=True,
|
||||
use_device_communicator=use_device_communicator,
|
||||
use_message_queue_broadcaster=use_message_queue_broadcaster,
|
||||
group_name=group_name,
|
||||
)
|
||||
@ -1143,7 +1177,14 @@ def init_distributed_environment(
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
config = get_current_vllm_config()
|
||||
if (
|
||||
if config is not None and config.parallel_config.nnodes > 1:
|
||||
parallel_config = config.parallel_config
|
||||
ip = parallel_config.master_addr
|
||||
rank = parallel_config.data_parallel_rank * world_size + rank
|
||||
world_size = parallel_config.world_size_across_dp
|
||||
port = parallel_config.master_port
|
||||
distributed_init_method = get_distributed_init_method(ip, port)
|
||||
elif (
|
||||
config is not None
|
||||
and config.parallel_config.data_parallel_size > 1
|
||||
and config.parallel_config.distributed_executor_backend != "external_launcher"
|
||||
@ -1164,6 +1205,14 @@ def init_distributed_environment(
|
||||
distributed_init_method,
|
||||
)
|
||||
if not torch.distributed.is_initialized():
|
||||
logger.info(
|
||||
"world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",
|
||||
world_size,
|
||||
rank,
|
||||
local_rank,
|
||||
distributed_init_method,
|
||||
backend,
|
||||
)
|
||||
assert distributed_init_method is not None, (
|
||||
"distributed_init_method must be provided when initializing "
|
||||
"distributed environment"
|
||||
@ -1192,16 +1241,36 @@ def init_distributed_environment(
|
||||
# local rank not set, this usually happens in single-node
|
||||
# setting, where we can use rank as local rank
|
||||
local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
|
||||
global _WORLD, _NODE_COUNT
|
||||
global _WORLD, _NODE_COUNT, _INNER_DP_WORLD
|
||||
if _WORLD is None:
|
||||
ranks = list(range(torch.distributed.get_world_size()))
|
||||
_WORLD = init_world_group(ranks, local_rank, backend)
|
||||
_NODE_COUNT = _node_count(_WORLD.cpu_group)
|
||||
if config.parallel_config.nnodes > 1:
|
||||
_NODE_COUNT = config.parallel_config.nnodes
|
||||
else:
|
||||
_NODE_COUNT = _node_count(_WORLD.cpu_group)
|
||||
logger.debug("Detected %d nodes in the distributed environment", _NODE_COUNT)
|
||||
else:
|
||||
assert _WORLD.world_size == torch.distributed.get_world_size(), (
|
||||
"world group already initialized with a different world size"
|
||||
)
|
||||
if config.parallel_config.nnodes_within_dp > 1:
|
||||
if parallel_config.data_parallel_size > 1:
|
||||
world_size_inner_dp = parallel_config.world_size
|
||||
group_ranks = [
|
||||
[dp_rank * world_size_inner_dp + i for i in range(world_size_inner_dp)]
|
||||
for dp_rank in range(parallel_config.data_parallel_size)
|
||||
]
|
||||
_INNER_DP_WORLD = init_model_parallel_group(
|
||||
group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
use_message_queue_broadcaster=True,
|
||||
group_name="inner_dp_world",
|
||||
use_device_communicator=False,
|
||||
)
|
||||
else:
|
||||
_INNER_DP_WORLD = _WORLD
|
||||
|
||||
|
||||
def initialize_model_parallel(
|
||||
|
||||
@ -384,6 +384,10 @@ class EngineArgs:
|
||||
) = ParallelConfig.distributed_executor_backend
|
||||
# number of P/D disaggregation (or other disaggregation) workers
|
||||
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
|
||||
master_addr: str = ParallelConfig.master_addr
|
||||
master_port: int = ParallelConfig.master_port
|
||||
nnodes: int = ParallelConfig.nnodes
|
||||
node_rank: int = ParallelConfig.node_rank
|
||||
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
|
||||
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
|
||||
dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
|
||||
@ -394,6 +398,7 @@ class EngineArgs:
|
||||
data_parallel_address: str | None = None
|
||||
data_parallel_rpc_port: int | None = None
|
||||
data_parallel_hybrid_lb: bool = False
|
||||
data_parallel_external_lb: bool = False
|
||||
data_parallel_backend: str = ParallelConfig.data_parallel_backend
|
||||
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
|
||||
all2all_backend: str | None = ParallelConfig.all2all_backend
|
||||
@ -749,6 +754,10 @@ class EngineArgs:
|
||||
"-pp",
|
||||
**parallel_kwargs["pipeline_parallel_size"],
|
||||
)
|
||||
parallel_group.add_argument("--master-addr", **parallel_kwargs["master_addr"])
|
||||
parallel_group.add_argument("--master-port", **parallel_kwargs["master_port"])
|
||||
parallel_group.add_argument("--nnodes", "-n", **parallel_kwargs["nnodes"])
|
||||
parallel_group.add_argument("--node-rank", "-r", **parallel_kwargs["node_rank"])
|
||||
parallel_group.add_argument(
|
||||
"--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]
|
||||
)
|
||||
@ -803,7 +812,14 @@ class EngineArgs:
|
||||
help='Backend for data parallel, either "mp" or "ray".',
|
||||
)
|
||||
parallel_group.add_argument(
|
||||
"--data-parallel-hybrid-lb", **parallel_kwargs["data_parallel_hybrid_lb"]
|
||||
"--data-parallel-hybrid-lb",
|
||||
"-dph",
|
||||
**parallel_kwargs["data_parallel_hybrid_lb"],
|
||||
)
|
||||
parallel_group.add_argument(
|
||||
"--data-parallel-external-lb",
|
||||
"-dpe",
|
||||
**parallel_kwargs["data_parallel_external_lb"],
|
||||
)
|
||||
parallel_group.add_argument(
|
||||
"--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]
|
||||
@ -1428,12 +1444,56 @@ class EngineArgs:
|
||||
assert not headless or not self.data_parallel_hybrid_lb, (
|
||||
"data_parallel_hybrid_lb is not applicable in headless mode"
|
||||
)
|
||||
|
||||
data_parallel_external_lb = self.data_parallel_rank is not None
|
||||
assert not (self.data_parallel_hybrid_lb and self.data_parallel_external_lb), (
|
||||
"data_parallel_hybrid_lb and data_parallel_external_lb cannot both be True."
|
||||
)
|
||||
assert self.data_parallel_backend == "mp" or self.nnodes == 1, (
|
||||
"nnodes > 1 is only supported with data_parallel_backend=mp"
|
||||
)
|
||||
inferred_data_parallel_rank = 0
|
||||
if self.nnodes > 1:
|
||||
world_size = (
|
||||
self.data_parallel_size
|
||||
* self.pipeline_parallel_size
|
||||
* self.tensor_parallel_size
|
||||
)
|
||||
world_size_within_dp = (
|
||||
self.pipeline_parallel_size * self.tensor_parallel_size
|
||||
)
|
||||
local_world_size = world_size // self.nnodes
|
||||
assert world_size % self.nnodes == 0, (
|
||||
f"world_size={world_size} must be divisible by nnodes={self.nnodes}."
|
||||
)
|
||||
assert self.node_rank < self.nnodes, (
|
||||
f"node_rank={self.node_rank} must be less than nnodes={self.nnodes}."
|
||||
)
|
||||
inferred_data_parallel_rank = (
|
||||
self.node_rank * local_world_size
|
||||
) // world_size_within_dp
|
||||
if self.data_parallel_size > 1 and self.data_parallel_external_lb:
|
||||
self.data_parallel_rank = inferred_data_parallel_rank
|
||||
logger.info(
|
||||
"Inferred data_parallel_rank %d from node_rank %d for external lb",
|
||||
self.data_parallel_rank,
|
||||
self.node_rank,
|
||||
)
|
||||
elif self.data_parallel_size_local is None:
|
||||
# Infer data parallel size local for internal dplb:
|
||||
self.data_parallel_size_local = max(
|
||||
local_world_size // world_size_within_dp, 1
|
||||
)
|
||||
data_parallel_external_lb = (
|
||||
self.data_parallel_external_lb or self.data_parallel_rank is not None
|
||||
)
|
||||
# Local DP rank = 1, use pure-external LB.
|
||||
if data_parallel_external_lb:
|
||||
assert self.data_parallel_rank is not None, (
|
||||
"data_parallel_rank or node_rank must be spefified if "
|
||||
"data_parallel_external_lb is enable."
|
||||
)
|
||||
assert self.data_parallel_size_local in (1, None), (
|
||||
"data_parallel_size_local must be 1 when data_parallel_rank is set"
|
||||
"data_parallel_size_local must be 1 or None when data_parallel_rank "
|
||||
"is set"
|
||||
)
|
||||
data_parallel_size_local = 1
|
||||
# Use full external lb if we have local_size of 1.
|
||||
@ -1447,6 +1507,11 @@ class EngineArgs:
|
||||
|
||||
if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
|
||||
# Use full external lb if we have local_size of 1.
|
||||
logger.warning(
|
||||
"data_parallel_hybrid_lb is not eligible when "
|
||||
"data_parallel_size_local = 1, autoswitch to "
|
||||
"data_parallel_external_lb."
|
||||
)
|
||||
data_parallel_external_lb = True
|
||||
self.data_parallel_hybrid_lb = False
|
||||
|
||||
@ -1454,7 +1519,15 @@ class EngineArgs:
|
||||
# Disable hybrid LB mode if set for a single node
|
||||
self.data_parallel_hybrid_lb = False
|
||||
|
||||
self.data_parallel_rank = self.data_parallel_start_rank or 0
|
||||
self.data_parallel_rank = (
|
||||
self.data_parallel_start_rank or inferred_data_parallel_rank
|
||||
)
|
||||
if self.nnodes > 1:
|
||||
logger.info(
|
||||
"Inferred data_parallel_rank %d from node_rank %d",
|
||||
self.data_parallel_rank,
|
||||
self.node_rank,
|
||||
)
|
||||
else:
|
||||
assert not self.data_parallel_hybrid_lb, (
|
||||
"data_parallel_size_local must be set to use data_parallel_hybrid_lb."
|
||||
@ -1484,7 +1557,9 @@ class EngineArgs:
|
||||
"data_parallel_backend can only be ray or mp, got %s",
|
||||
self.data_parallel_backend,
|
||||
)
|
||||
data_parallel_address = ParallelConfig.data_parallel_master_ip
|
||||
data_parallel_address = (
|
||||
self.master_addr or ParallelConfig.data_parallel_master_ip
|
||||
)
|
||||
else:
|
||||
data_parallel_address = self.data_parallel_address
|
||||
|
||||
@ -1517,6 +1592,10 @@ class EngineArgs:
|
||||
data_parallel_rank=self.data_parallel_rank or 0,
|
||||
data_parallel_external_lb=data_parallel_external_lb,
|
||||
data_parallel_size_local=data_parallel_size_local,
|
||||
master_addr=self.master_addr,
|
||||
master_port=self.master_port,
|
||||
nnodes=self.nnodes,
|
||||
node_rank=self.node_rank,
|
||||
data_parallel_master_ip=data_parallel_address,
|
||||
data_parallel_rpc_port=data_parallel_rpc_port,
|
||||
data_parallel_backend=self.data_parallel_backend,
|
||||
|
||||
@ -24,6 +24,7 @@ from vllm.utils.system_utils import decorate_logs, set_process_title
|
||||
from vllm.v1.engine.core import EngineCoreProc
|
||||
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
|
||||
from vllm.v1.executor import Executor
|
||||
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
|
||||
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
|
||||
from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure
|
||||
|
||||
@ -97,18 +98,40 @@ def run_headless(args: argparse.Namespace):
|
||||
if local_engine_count <= 0:
|
||||
raise ValueError("data_parallel_size_local must be > 0 in headless mode")
|
||||
|
||||
host = parallel_config.data_parallel_master_ip
|
||||
port = engine_args.data_parallel_rpc_port # add to config too
|
||||
handshake_address = get_tcp_uri(host, port)
|
||||
shutdown_requested = False
|
||||
|
||||
# Catch SIGTERM and SIGINT to allow graceful shutdown.
|
||||
def signal_handler(signum, frame):
|
||||
nonlocal shutdown_requested
|
||||
logger.debug("Received %d signal.", signum)
|
||||
raise SystemExit
|
||||
if not shutdown_requested:
|
||||
shutdown_requested = True
|
||||
raise SystemExit
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
if parallel_config.node_rank_within_dp > 0:
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
# Run headless workers (for multi-node PP/TP).
|
||||
host = parallel_config.master_addr
|
||||
head_node_address = f"{host}:{parallel_config.master_port}"
|
||||
logger.info(
|
||||
"Launching vLLM (v%s) headless multiproc executor, "
|
||||
"with head node address %s for torch.distributed process group.",
|
||||
VLLM_VERSION,
|
||||
head_node_address,
|
||||
)
|
||||
|
||||
executor = MultiprocExecutor(vllm_config, monitor_workers=False)
|
||||
executor.start_worker_monitor(inline=True)
|
||||
return
|
||||
|
||||
host = parallel_config.data_parallel_master_ip
|
||||
port = parallel_config.data_parallel_rpc_port
|
||||
handshake_address = get_tcp_uri(host, port)
|
||||
|
||||
logger.info(
|
||||
"Launching %d data parallel engine(s) in headless mode, "
|
||||
"with head node address %s.",
|
||||
|
||||
@ -183,15 +183,19 @@ def set_device_control_env_var(
|
||||
for engine subprocess.
|
||||
"""
|
||||
world_size = vllm_config.parallel_config.world_size
|
||||
local_world_size = vllm_config.parallel_config.local_world_size
|
||||
evar = current_platform.device_control_env_var
|
||||
|
||||
value = get_device_indices(evar, local_dp_rank, world_size)
|
||||
value = get_device_indices(evar, local_dp_rank, world_size, local_world_size)
|
||||
with patch.dict(os.environ, values=((evar, value),)):
|
||||
yield
|
||||
|
||||
|
||||
def get_device_indices(
|
||||
device_control_env_var: str, local_dp_rank: int, world_size: int
|
||||
device_control_env_var: str,
|
||||
local_dp_rank: int,
|
||||
world_size: int,
|
||||
local_world_size: int | None = None,
|
||||
):
|
||||
"""
|
||||
Returns a comma-separated string of device indices for the specified
|
||||
@ -200,10 +204,15 @@ def get_device_indices(
|
||||
For example, if world_size=2 and local_dp_rank=1, and there are 4 devices,
|
||||
this will select devices 2 and 3 for local_dp_rank=1.
|
||||
"""
|
||||
if local_world_size is None:
|
||||
local_world_size = world_size
|
||||
try:
|
||||
value = ",".join(
|
||||
str(current_platform.device_id_to_physical_device_id(i))
|
||||
for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * world_size)
|
||||
for i in range(
|
||||
local_dp_rank * world_size,
|
||||
local_dp_rank * world_size + local_world_size,
|
||||
)
|
||||
)
|
||||
except IndexError as e:
|
||||
raise Exception(
|
||||
|
||||
@ -10,7 +10,7 @@ import time
|
||||
import traceback
|
||||
import weakref
|
||||
from collections import deque
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Sequence
|
||||
from concurrent.futures import Future, InvalidStateError
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
@ -34,6 +34,7 @@ from vllm.distributed.parallel_state import (
|
||||
get_dcp_group,
|
||||
get_dp_group,
|
||||
get_ep_group,
|
||||
get_inner_dp_world_group,
|
||||
get_pp_group,
|
||||
get_tp_group,
|
||||
)
|
||||
@ -90,6 +91,10 @@ class FutureWrapper(Future):
|
||||
class MultiprocExecutor(Executor):
|
||||
supports_pp: bool = True
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, monitor_workers: bool = True):
|
||||
self.monitor_workers = monitor_workers
|
||||
super().__init__(vllm_config)
|
||||
|
||||
def _init_executor(self) -> None:
|
||||
# Call self.shutdown at exit to clean up
|
||||
# and ensure workers will be terminated.
|
||||
@ -99,6 +104,12 @@ class MultiprocExecutor(Executor):
|
||||
self.failure_callback: FailureCallback | None = None
|
||||
|
||||
self.world_size = self.parallel_config.world_size
|
||||
assert self.world_size % self.parallel_config.nnodes_within_dp == 0, (
|
||||
f"global world_size ({self.parallel_config.world_size}) must be "
|
||||
f"divisible by nnodes_within_dp "
|
||||
f"({self.parallel_config.nnodes_within_dp}). "
|
||||
)
|
||||
self.local_world_size = self.parallel_config.local_world_size
|
||||
tensor_parallel_size = self.parallel_config.tensor_parallel_size
|
||||
pp_parallel_size = self.parallel_config.pipeline_parallel_size
|
||||
assert self.world_size == tensor_parallel_size * pp_parallel_size, (
|
||||
@ -116,27 +127,37 @@ class MultiprocExecutor(Executor):
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_loopback_ip(), get_open_port()
|
||||
)
|
||||
|
||||
self.rpc_broadcast_mq: MessageQueue | None = None
|
||||
scheduler_output_handle: Handle | None = None
|
||||
# Initialize worker and set up message queues for SchedulerOutputs
|
||||
# and ModelRunnerOutputs
|
||||
max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
|
||||
self.rpc_broadcast_mq = MessageQueue(
|
||||
self.world_size, self.world_size, max_chunk_bytes=max_chunk_bytes
|
||||
)
|
||||
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
|
||||
|
||||
if self.parallel_config.node_rank_within_dp == 0:
|
||||
# For leader node within each dp rank,
|
||||
# each dp will have its own leader multiproc executor.
|
||||
max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
|
||||
self.rpc_broadcast_mq = MessageQueue(
|
||||
self.world_size,
|
||||
self.local_world_size,
|
||||
max_chunk_bytes=max_chunk_bytes,
|
||||
connect_ip=self.parallel_config.master_addr,
|
||||
)
|
||||
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
|
||||
# Create workers
|
||||
context = get_mp_context()
|
||||
shared_worker_lock = context.Lock()
|
||||
unready_workers: list[UnreadyWorkerProcHandle] = []
|
||||
success = False
|
||||
try:
|
||||
for rank in range(self.world_size):
|
||||
global_start_rank = (
|
||||
self.local_world_size * self.parallel_config.node_rank_within_dp
|
||||
)
|
||||
for local_rank in range(self.local_world_size):
|
||||
global_rank = global_start_rank + local_rank
|
||||
unready_workers.append(
|
||||
WorkerProc.make_worker_process(
|
||||
vllm_config=self.vllm_config,
|
||||
local_rank=rank,
|
||||
rank=rank,
|
||||
local_rank=local_rank,
|
||||
rank=global_rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
input_shm_handle=scheduler_output_handle,
|
||||
shared_worker_lock=shared_worker_lock,
|
||||
@ -145,15 +166,38 @@ class MultiprocExecutor(Executor):
|
||||
|
||||
# Workers must be created before wait_for_ready to avoid
|
||||
# deadlock, since worker.init_device() does a device sync.
|
||||
|
||||
# Wait for all local workers to be ready.
|
||||
self.workers = WorkerProc.wait_for_ready(unready_workers)
|
||||
|
||||
# Start background thread to monitor worker health if not in headless mode.
|
||||
if self.monitor_workers:
|
||||
self.start_worker_monitor()
|
||||
|
||||
self.response_mqs = []
|
||||
# Only leader node have remote response mqs
|
||||
if self.parallel_config.node_rank_within_dp == 0:
|
||||
for rank in range(self.world_size):
|
||||
if rank < self.local_world_size:
|
||||
local_message_queue = self.workers[rank].worker_response_mq
|
||||
assert local_message_queue is not None
|
||||
self.response_mqs.append(local_message_queue)
|
||||
else:
|
||||
remote_message_queue = self.workers[0].peer_worker_response_mqs[
|
||||
rank
|
||||
]
|
||||
assert remote_message_queue is not None
|
||||
self.response_mqs.append(remote_message_queue)
|
||||
|
||||
# Ensure message queues are ready. Will deadlock if re-ordered
|
||||
# Must be kept consistent with the WorkerProc.
|
||||
self.rpc_broadcast_mq.wait_until_ready()
|
||||
for w in self.workers:
|
||||
w.worker_response_mq.wait_until_ready()
|
||||
|
||||
self.start_worker_monitor()
|
||||
# Wait for all input mqs to be ready.
|
||||
if self.rpc_broadcast_mq is not None:
|
||||
self.rpc_broadcast_mq.wait_until_ready()
|
||||
# Wait for all remote response mqs to be ready.
|
||||
for response_mq in self.response_mqs:
|
||||
response_mq.wait_until_ready()
|
||||
success = True
|
||||
finally:
|
||||
if not success:
|
||||
@ -168,7 +212,7 @@ class MultiprocExecutor(Executor):
|
||||
|
||||
self.output_rank = self._get_output_rank()
|
||||
|
||||
def start_worker_monitor(self):
|
||||
def start_worker_monitor(self, inline=False) -> None:
|
||||
workers = self.workers
|
||||
self_ref = weakref.ref(self)
|
||||
|
||||
@ -192,9 +236,13 @@ class MultiprocExecutor(Executor):
|
||||
_self.failure_callback = None
|
||||
callback()
|
||||
|
||||
Thread(
|
||||
target=monitor_workers, daemon=True, name="MultiprocWorkerMonitor"
|
||||
).start()
|
||||
if not inline:
|
||||
Thread(
|
||||
target=monitor_workers, daemon=True, name="MultiprocWorkerMonitor"
|
||||
).start()
|
||||
return
|
||||
|
||||
monitor_workers()
|
||||
|
||||
def register_failure_callback(self, callback: FailureCallback):
|
||||
if self.is_failed:
|
||||
@ -247,7 +295,9 @@ class MultiprocExecutor(Executor):
|
||||
) -> Any | list[Any] | Future[Any | list[Any]]:
|
||||
"""Returns single result if unique_reply_rank and/or kv_output_aggregator
|
||||
is provided, otherwise list."""
|
||||
|
||||
assert self.rpc_broadcast_mq is not None, (
|
||||
"collective_rpc should not be called on follower node"
|
||||
)
|
||||
if self.is_failed:
|
||||
raise RuntimeError("Executor failed.")
|
||||
|
||||
@ -269,20 +319,20 @@ class MultiprocExecutor(Executor):
|
||||
send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, output_rank))
|
||||
|
||||
workers = (
|
||||
(self.workers[output_rank],) if output_rank is not None else self.workers
|
||||
)
|
||||
response_mqs: Sequence[MessageQueue] = self.response_mqs
|
||||
if output_rank is not None:
|
||||
response_mqs = (response_mqs[output_rank],)
|
||||
|
||||
shutdown_event = self.shutdown_event
|
||||
|
||||
def get_response():
|
||||
responses = []
|
||||
for w in workers:
|
||||
for mq in response_mqs:
|
||||
dequeue_timeout = (
|
||||
None if deadline is None else (deadline - time.monotonic())
|
||||
)
|
||||
try:
|
||||
status, result = w.worker_response_mq.dequeue(
|
||||
status, result = mq.dequeue(
|
||||
timeout=dequeue_timeout, cancel=shutdown_event
|
||||
)
|
||||
except TimeoutError as e:
|
||||
@ -391,17 +441,26 @@ class UnreadyWorkerProcHandle:
|
||||
class WorkerProcHandle:
|
||||
proc: BaseProcess
|
||||
rank: int
|
||||
worker_response_mq: MessageQueue # The worker process writes to this MQ
|
||||
# The worker process writes to this MQ in single-node mode
|
||||
worker_response_mq: MessageQueue | None
|
||||
# This is only non empty on driver node,
|
||||
# the peer worker process i writes to MQ
|
||||
# `peer_worker_response_mqs[i]`
|
||||
peer_worker_response_mqs: list[MessageQueue | None]
|
||||
death_writer: Connection | None = None
|
||||
|
||||
@classmethod
|
||||
def from_unready_handle(
|
||||
cls, unready_handle: UnreadyWorkerProcHandle, worker_response_mq: MessageQueue
|
||||
cls,
|
||||
unready_handle: UnreadyWorkerProcHandle,
|
||||
worker_response_mq: MessageQueue | None,
|
||||
peer_worker_response_mqs: list[MessageQueue | None],
|
||||
) -> "WorkerProcHandle":
|
||||
return cls(
|
||||
proc=unready_handle.proc,
|
||||
rank=unready_handle.rank,
|
||||
worker_response_mq=worker_response_mq,
|
||||
peer_worker_response_mqs=peer_worker_response_mqs,
|
||||
death_writer=unready_handle.death_writer,
|
||||
)
|
||||
|
||||
@ -411,6 +470,38 @@ class WorkerProc:
|
||||
|
||||
READY_STR = "READY"
|
||||
|
||||
def _init_message_queues(
|
||||
self, input_shm_handle: Handle, vllm_config: VllmConfig
|
||||
) -> None:
|
||||
if vllm_config.parallel_config.nnodes_within_dp == 1:
|
||||
# Initialize MessageQueue for receiving SchedulerOutput
|
||||
self.rpc_broadcast_mq = MessageQueue.create_from_handle(
|
||||
input_shm_handle, self.worker.rank
|
||||
)
|
||||
|
||||
# Initializes a message queue for sending the model output
|
||||
self.worker_response_mq: MessageQueue = MessageQueue(1, 1)
|
||||
self.peer_response_handles = []
|
||||
else:
|
||||
# Initialize remote MessageQueue for receiving SchedulerOutput across nodes
|
||||
self.rpc_broadcast_mq = get_inner_dp_world_group().create_mq_broadcaster(
|
||||
external_writer_handle=input_shm_handle,
|
||||
# Since there is external_writer_handle from executor proc,
|
||||
# where the ready signal from actual writer is sent out of the
|
||||
# create_mq_broadcaster method and after this setup, we make it
|
||||
# non blocking. The handshake will be triggered when
|
||||
# worker.rpc_broadcast_mq.wait_until_ready() is called
|
||||
blocking=False,
|
||||
)
|
||||
# Initializes remote message queue for sending the model output to the
|
||||
# driver worker, exposing peer_response_handles for driver worker
|
||||
# that include handles for all ranks
|
||||
self.worker_response_mq, self.peer_response_handles = (
|
||||
get_inner_dp_world_group().create_single_reader_mq_broadcasters(
|
||||
reader_rank_in_group=0
|
||||
)
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
@ -421,13 +512,15 @@ class WorkerProc:
|
||||
shared_worker_lock: LockType,
|
||||
):
|
||||
self.rank = rank
|
||||
wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
|
||||
wrapper = WorkerWrapperBase(
|
||||
vllm_config=vllm_config, rpc_rank=local_rank, global_rank=rank
|
||||
)
|
||||
# TODO: move `init_worker` to executor level as a collective rpc call
|
||||
all_kwargs: list[dict] = [
|
||||
{} for _ in range(vllm_config.parallel_config.world_size)
|
||||
]
|
||||
is_driver_worker = rank % vllm_config.parallel_config.tensor_parallel_size == 0
|
||||
all_kwargs[rank] = {
|
||||
all_kwargs[local_rank] = {
|
||||
"vllm_config": vllm_config,
|
||||
"local_rank": local_rank,
|
||||
"rank": rank,
|
||||
@ -438,14 +531,6 @@ class WorkerProc:
|
||||
wrapper.init_worker(all_kwargs)
|
||||
self.worker = wrapper
|
||||
|
||||
# Initialize MessageQueue for receiving SchedulerOutput
|
||||
self.rpc_broadcast_mq = MessageQueue.create_from_handle(
|
||||
input_shm_handle, self.worker.rank
|
||||
)
|
||||
|
||||
# Initializes a message queue for sending the model output
|
||||
self.worker_response_mq = MessageQueue(1, 1)
|
||||
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
self.use_async_scheduling = scheduler_config.async_scheduling
|
||||
if self.use_async_scheduling:
|
||||
@ -466,6 +551,7 @@ class WorkerProc:
|
||||
)
|
||||
|
||||
# Load model
|
||||
self._init_message_queues(input_shm_handle, vllm_config)
|
||||
self.worker.load_model()
|
||||
|
||||
# Enable environment variable cache (e.g. assume no more
|
||||
@ -512,6 +598,27 @@ class WorkerProc:
|
||||
# death_reader in child will get EOFError
|
||||
return UnreadyWorkerProcHandle(proc, rank, reader, death_writer)
|
||||
|
||||
@staticmethod
|
||||
def wait_for_response_handle_ready(
|
||||
handles: dict[str, Any], proc_handle: UnreadyWorkerProcHandle
|
||||
) -> WorkerProcHandle:
|
||||
response_handle = handles["handle"]
|
||||
worker_response_mq: MessageQueue | None = None
|
||||
if len(response_handle.local_reader_ranks) > 0:
|
||||
worker_response_mq = MessageQueue.create_from_handle(response_handle, 0)
|
||||
peer_response_handles = handles["peer_response_handles"]
|
||||
peer_worker_response_mqs = [
|
||||
MessageQueue.create_from_handle(handle, -1)
|
||||
if handle.remote_subscribe_addr is not None
|
||||
else None
|
||||
for handle in peer_response_handles
|
||||
]
|
||||
return WorkerProcHandle.from_unready_handle(
|
||||
proc_handle,
|
||||
worker_response_mq,
|
||||
peer_worker_response_mqs=peer_worker_response_mqs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def wait_for_ready(
|
||||
unready_proc_handles: list[UnreadyWorkerProcHandle],
|
||||
@ -537,16 +644,10 @@ class WorkerProc:
|
||||
if response["status"] != "READY":
|
||||
raise e
|
||||
|
||||
# Extract the message queue handle.
|
||||
worker_response_mq = MessageQueue.create_from_handle(
|
||||
response["handle"], 0
|
||||
idx = unready_proc_handle.rank % len(ready_proc_handles)
|
||||
ready_proc_handles[idx] = WorkerProc.wait_for_response_handle_ready(
|
||||
response, unready_proc_handle
|
||||
)
|
||||
ready_proc_handles[unready_proc_handle.rank] = (
|
||||
WorkerProcHandle.from_unready_handle(
|
||||
unready_proc_handle, worker_response_mq
|
||||
)
|
||||
)
|
||||
|
||||
except EOFError:
|
||||
e.__suppress_context__ = True
|
||||
raise e from None
|
||||
@ -618,12 +719,14 @@ class WorkerProc:
|
||||
{
|
||||
"status": WorkerProc.READY_STR,
|
||||
"handle": worker.worker_response_mq.export_handle(),
|
||||
"peer_response_handles": worker.peer_response_handles,
|
||||
}
|
||||
)
|
||||
|
||||
# Ensure message queues are ready. Will deadlock if re-ordered.
|
||||
# Must be kept consistent with the Executor
|
||||
worker.rpc_broadcast_mq.wait_until_ready()
|
||||
if worker.rpc_broadcast_mq is not None:
|
||||
worker.rpc_broadcast_mq.wait_until_ready()
|
||||
worker.worker_response_mq.wait_until_ready()
|
||||
ready_writer.close()
|
||||
ready_writer = None
|
||||
|
||||
@ -189,6 +189,7 @@ class Worker(WorkerBase):
|
||||
and self.parallel_config.distributed_executor_backend
|
||||
not in ["ray", "external_launcher"]
|
||||
and self.vllm_config.parallel_config.data_parallel_backend != "ray"
|
||||
and self.vllm_config.parallel_config.nnodes_within_dp == 1
|
||||
):
|
||||
# Use local DP rank if available, otherwise use global DP rank.
|
||||
dp_local_rank = self.parallel_config.data_parallel_rank_local
|
||||
@ -205,7 +206,14 @@ class Worker(WorkerBase):
|
||||
assert self.local_rank < torch.cuda.device_count(), (
|
||||
f"DP adjusted local rank {self.local_rank} is out of bounds. "
|
||||
)
|
||||
|
||||
visible_device_count = (
|
||||
torch.cuda.device_count() if torch.cuda.is_available() else 0
|
||||
)
|
||||
assert self.parallel_config.local_world_size <= visible_device_count, (
|
||||
f"local_world_size ({self.parallel_config.local_world_size}) must be "
|
||||
f"less than or equal to the number of visible devices "
|
||||
f"({visible_device_count})."
|
||||
)
|
||||
self.device = torch.device(f"cuda:{self.local_rank}")
|
||||
current_platform.set_device(self.device)
|
||||
|
||||
|
||||
@ -180,6 +180,7 @@ class WorkerWrapperBase:
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
rpc_rank: int = 0,
|
||||
global_rank: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the worker wrapper with the given vllm_config and rpc_rank.
|
||||
@ -192,6 +193,7 @@ class WorkerWrapperBase:
|
||||
group.
|
||||
"""
|
||||
self.rpc_rank = rpc_rank
|
||||
self.global_rank = self.rpc_rank if global_rank is None else global_rank
|
||||
self.worker: WorkerBase | None = None
|
||||
|
||||
# do not store this `vllm_config`, `init_worker` will set the final
|
||||
@ -312,7 +314,7 @@ class WorkerWrapperBase:
|
||||
assert self.worker is not None
|
||||
|
||||
def initialize_from_config(self, kv_cache_configs: list[Any]) -> None:
|
||||
kv_cache_config = kv_cache_configs[self.rpc_rank]
|
||||
kv_cache_config = kv_cache_configs[self.global_rank]
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
self.worker.initialize_from_config(kv_cache_config) # type: ignore
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user